From 1a6b83b82c783575f176a44ed3d15dedad2df198 Mon Sep 17 00:00:00 2001
From: milen <94537774+taratorio@users.noreply.github.com>
Date: Thu, 14 Dec 2023 20:50:59 +0000
Subject: [PATCH] borheimdall: add test for span persistence (#8988)

1. Adds an eth/stagedsync/test package which provides a test Harness
object
2. Adds the first automated test to the bor-heimdall stage regarding
span persistence (more to come in subsequent PRs)
3. Fixes a bug in the bor-heimdall stage which was uncovered with the
test - we do not fetch span 0 when we sync straight from blockNum=0
without snapshots
4. Reorganises all mocks to be placed under ./mock sub-package within
their respective packages
---
 .../bor/{genesis.go => genesis_contract.go}   |   2 +-
 consensus/bor/heimdall/heimall.go             |   2 +-
 .../bor/heimdall/mock/heimdall_client_mock.go |  80 ++-
 .../bor/{ => mock}/genesis_contract_mock.go   |   8 +-
 .../{span_mock.go => mock/spanner_mock.go}    |   8 +-
 consensus/bor/{span.go => spanner.go}         |   2 +-
 consensus/consensus.go                        |   5 +-
 consensus/mock/chain_header_reader_mock.go    | 150 ++++++
 eth/stagedsync/stage_bor_heimdall.go          |  18 +-
 eth/stagedsync/stage_bor_heimdall_test.go     |  44 ++
 eth/stagedsync/test/chain_configs.go          |  20 +
 eth/stagedsync/test/harness.go                | 459 ++++++++++++++++++
 12 files changed, 775 insertions(+), 23 deletions(-)
 rename consensus/bor/{genesis.go => genesis_contract.go} (74%)
 rename tests/bor/mocks/IHeimdallClient.go => consensus/bor/heimdall/mock/heimdall_client_mock.go (54%)
 rename consensus/bor/{ => mock}/genesis_contract_mock.go (90%)
 rename consensus/bor/{span_mock.go => mock/spanner_mock.go} (93%)
 rename consensus/bor/{span.go => spanner.go} (88%)
 create mode 100644 consensus/mock/chain_header_reader_mock.go
 create mode 100644 eth/stagedsync/stage_bor_heimdall_test.go
 create mode 100644 eth/stagedsync/test/chain_configs.go
 create mode 100644 eth/stagedsync/test/harness.go

diff --git a/consensus/bor/genesis.go b/consensus/bor/genesis_contract.go
similarity index 74%
rename from consensus/bor/genesis.go
rename to consensus/bor/genesis_contract.go
index 24b0964f4..7a232733b 100644
--- a/consensus/bor/genesis.go
+++ b/consensus/bor/genesis_contract.go
@@ -7,7 +7,7 @@ import (
 	"github.com/ledgerwatch/erigon/rlp"
 )
 
-//go:generate mockgen -destination=./genesis_contract_mock.go -package=bor . GenesisContract
+//go:generate mockgen -destination=./mock/genesis_contract_mock.go -package=mock . GenesisContract
 type GenesisContract interface {
 	CommitState(event rlp.RawValue, syscall consensus.SystemCall) error
 	LastStateId(syscall consensus.SystemCall) (*big.Int, error)
diff --git a/consensus/bor/heimdall/heimall.go b/consensus/bor/heimdall/heimall.go
index 2ef405290..ea98ab5ba 100644
--- a/consensus/bor/heimdall/heimall.go
+++ b/consensus/bor/heimdall/heimall.go
@@ -14,7 +14,7 @@ func MilestoneRewindPending() bool {
 	return generics.BorMilestoneRewind.Load() != nil && *generics.BorMilestoneRewind.Load() != 0
 }
 
-//go:generate mockgen -destination=../../tests/bor/mocks/IHeimdallClient.go -package=mocks . IHeimdallClient
+//go:generate mockgen -destination=./mock/heimdall_client_mock.go -package=mock . IHeimdallClient
 type IHeimdallClient interface {
 	StateSyncEvents(ctx context.Context, fromID uint64, to int64) ([]*clerk.EventRecordWithTime, error)
 	Span(ctx context.Context, spanID uint64) (*span.HeimdallSpan, error)
diff --git a/tests/bor/mocks/IHeimdallClient.go b/consensus/bor/heimdall/mock/heimdall_client_mock.go
similarity index 54%
rename from tests/bor/mocks/IHeimdallClient.go
rename to consensus/bor/heimdall/mock/heimdall_client_mock.go
index 1737cae88..b098f0558 100644
--- a/tests/bor/mocks/IHeimdallClient.go
+++ b/consensus/bor/heimdall/mock/heimdall_client_mock.go
@@ -1,8 +1,8 @@
 // Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/ledgerwatch/erigon/consensus/bor (interfaces: IHeimdallClient)
+// Source: github.com/ledgerwatch/erigon/consensus/bor/heimdall (interfaces: IHeimdallClient)
 
-// Package mocks is a generated GoMock package.
-package mocks
+// Package mock is a generated GoMock package.
+package mock
 
 import (
 	context "context"
@@ -11,6 +11,7 @@ import (
 	gomock "github.com/golang/mock/gomock"
 	clerk "github.com/ledgerwatch/erigon/consensus/bor/clerk"
 	checkpoint "github.com/ledgerwatch/erigon/consensus/bor/heimdall/checkpoint"
+	milestone "github.com/ledgerwatch/erigon/consensus/bor/heimdall/milestone"
 	span "github.com/ledgerwatch/erigon/consensus/bor/heimdall/span"
 )
 
@@ -79,6 +80,79 @@ func (mr *MockIHeimdallClientMockRecorder) FetchCheckpointCount(arg0 interface{}
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchCheckpointCount", reflect.TypeOf((*MockIHeimdallClient)(nil).FetchCheckpointCount), arg0)
 }
 
+// FetchLastNoAckMilestone mocks base method.
+func (m *MockIHeimdallClient) FetchLastNoAckMilestone(arg0 context.Context) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FetchLastNoAckMilestone", arg0)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// FetchLastNoAckMilestone indicates an expected call of FetchLastNoAckMilestone.
+func (mr *MockIHeimdallClientMockRecorder) FetchLastNoAckMilestone(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchLastNoAckMilestone", reflect.TypeOf((*MockIHeimdallClient)(nil).FetchLastNoAckMilestone), arg0)
+}
+
+// FetchMilestone mocks base method.
+func (m *MockIHeimdallClient) FetchMilestone(arg0 context.Context) (*milestone.Milestone, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FetchMilestone", arg0)
+	ret0, _ := ret[0].(*milestone.Milestone)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// FetchMilestone indicates an expected call of FetchMilestone.
+func (mr *MockIHeimdallClientMockRecorder) FetchMilestone(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchMilestone", reflect.TypeOf((*MockIHeimdallClient)(nil).FetchMilestone), arg0)
+}
+
+// FetchMilestoneCount mocks base method.
+func (m *MockIHeimdallClient) FetchMilestoneCount(arg0 context.Context) (int64, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FetchMilestoneCount", arg0)
+	ret0, _ := ret[0].(int64)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// FetchMilestoneCount indicates an expected call of FetchMilestoneCount.
+func (mr *MockIHeimdallClientMockRecorder) FetchMilestoneCount(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchMilestoneCount", reflect.TypeOf((*MockIHeimdallClient)(nil).FetchMilestoneCount), arg0)
+}
+
+// FetchMilestoneID mocks base method.
+func (m *MockIHeimdallClient) FetchMilestoneID(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FetchMilestoneID", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// FetchMilestoneID indicates an expected call of FetchMilestoneID.
+func (mr *MockIHeimdallClientMockRecorder) FetchMilestoneID(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchMilestoneID", reflect.TypeOf((*MockIHeimdallClient)(nil).FetchMilestoneID), arg0, arg1)
+}
+
+// FetchNoAckMilestone mocks base method.
+func (m *MockIHeimdallClient) FetchNoAckMilestone(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FetchNoAckMilestone", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// FetchNoAckMilestone indicates an expected call of FetchNoAckMilestone.
+func (mr *MockIHeimdallClientMockRecorder) FetchNoAckMilestone(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchNoAckMilestone", reflect.TypeOf((*MockIHeimdallClient)(nil).FetchNoAckMilestone), arg0, arg1)
+}
+
 // Span mocks base method.
 func (m *MockIHeimdallClient) Span(arg0 context.Context, arg1 uint64) (*span.HeimdallSpan, error) {
 	m.ctrl.T.Helper()
diff --git a/consensus/bor/genesis_contract_mock.go b/consensus/bor/mock/genesis_contract_mock.go
similarity index 90%
rename from consensus/bor/genesis_contract_mock.go
rename to consensus/bor/mock/genesis_contract_mock.go
index 6cba2c64b..9ad12ae63 100644
--- a/consensus/bor/genesis_contract_mock.go
+++ b/consensus/bor/mock/genesis_contract_mock.go
@@ -1,8 +1,8 @@
 // Code generated by MockGen. DO NOT EDIT.
 // Source: github.com/ledgerwatch/erigon/consensus/bor (interfaces: GenesisContract)
 
-// Package bor is a generated GoMock package.
-package bor
+// Package mock is a generated GoMock package.
+package mock
 
 import (
 	big "math/big"
@@ -10,7 +10,7 @@ import (
 
 	gomock "github.com/golang/mock/gomock"
 	consensus "github.com/ledgerwatch/erigon/consensus"
-	clerk "github.com/ledgerwatch/erigon/consensus/bor/clerk"
+	rlp "github.com/ledgerwatch/erigon/rlp"
 )
 
 // MockGenesisContract is a mock of GenesisContract interface.
@@ -37,7 +37,7 @@ func (m *MockGenesisContract) EXPECT() *MockGenesisContractMockRecorder {
 }
 
 // CommitState mocks base method.
-func (m *MockGenesisContract) CommitState(arg0 *clerk.EventRecordWithTime, arg1 consensus.SystemCall) error {
+func (m *MockGenesisContract) CommitState(arg0 rlp.RawValue, arg1 consensus.SystemCall) error {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "CommitState", arg0, arg1)
 	ret0, _ := ret[0].(error)
diff --git a/consensus/bor/span_mock.go b/consensus/bor/mock/spanner_mock.go
similarity index 93%
rename from consensus/bor/span_mock.go
rename to consensus/bor/mock/spanner_mock.go
index ced3dee6a..70db933ed 100644
--- a/consensus/bor/span_mock.go
+++ b/consensus/bor/mock/spanner_mock.go
@@ -1,8 +1,8 @@
 // Code generated by MockGen. DO NOT EDIT.
 // Source: github.com/ledgerwatch/erigon/consensus/bor (interfaces: Spanner)
 
-// Package bor is a generated GoMock package.
-package bor
+// Package mock is a generated GoMock package.
+package mock
 
 import (
 	reflect "reflect"
@@ -52,7 +52,7 @@ func (mr *MockSpannerMockRecorder) CommitSpan(arg0, arg1 interface{}) *gomock.Ca
 }
 
 // GetCurrentProducers mocks base method.
-func (m *MockSpanner) GetCurrentProducers(arg0 uint64, arg1 common.Address, arg2 func(uint64) (*span.HeimdallSpan, error)) ([]*valset.Validator, error) {
+func (m *MockSpanner) GetCurrentProducers(arg0 uint64, arg1 common.Address, arg2 consensus.ChainHeaderReader) ([]*valset.Validator, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "GetCurrentProducers", arg0, arg1, arg2)
 	ret0, _ := ret[0].([]*valset.Validator)
@@ -82,7 +82,7 @@ func (mr *MockSpannerMockRecorder) GetCurrentSpan(arg0 interface{}) *gomock.Call
 }
 
 // GetCurrentValidators mocks base method.
-func (m *MockSpanner) GetCurrentValidators(arg0 uint64, arg1 common.Address, arg2 func(uint64) (*span.HeimdallSpan, error)) ([]*valset.Validator, error) {
+func (m *MockSpanner) GetCurrentValidators(arg0 uint64, arg1 common.Address, arg2 consensus.ChainHeaderReader) ([]*valset.Validator, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "GetCurrentValidators", arg0, arg1, arg2)
 	ret0, _ := ret[0].([]*valset.Validator)
diff --git a/consensus/bor/span.go b/consensus/bor/spanner.go
similarity index 88%
rename from consensus/bor/span.go
rename to consensus/bor/spanner.go
index 41e8abec8..77769ea83 100644
--- a/consensus/bor/span.go
+++ b/consensus/bor/spanner.go
@@ -7,7 +7,7 @@ import (
 	"github.com/ledgerwatch/erigon/consensus/bor/valset"
 )
 
-//go:generate mockgen -destination=./span_mock.go -package=bor . Spanner
+//go:generate mockgen -destination=./mock/spanner_mock.go -package=mock . Spanner
 type Spanner interface {
 	GetCurrentSpan(syscall consensus.SystemCall) (*span.Span, error)
 	GetCurrentValidators(spanId uint64, signer libcommon.Address, chain consensus.ChainHeaderReader) ([]*valset.Validator, error)
diff --git a/consensus/consensus.go b/consensus/consensus.go
index 0a98706fa..d9ba40fff 100644
--- a/consensus/consensus.go
+++ b/consensus/consensus.go
@@ -21,19 +21,20 @@ import (
 	"math/big"
 
 	"github.com/holiman/uint256"
+	"github.com/ledgerwatch/log/v3"
 
 	"github.com/ledgerwatch/erigon-lib/chain"
 	libcommon "github.com/ledgerwatch/erigon-lib/common"
-
 	"github.com/ledgerwatch/erigon/core/state"
 	"github.com/ledgerwatch/erigon/core/types"
 	"github.com/ledgerwatch/erigon/rlp"
 	"github.com/ledgerwatch/erigon/rpc"
-	"github.com/ledgerwatch/log/v3"
 )
 
 // ChainHeaderReader defines a small collection of methods needed to access the local
 // blockchain during header verification.
+//
+//go:generate mockgen -destination=./mock/chain_header_reader_mock.go -package=mock . ChainHeaderReader
 type ChainHeaderReader interface {
 	// Config retrieves the blockchain's chain configuration.
 	Config() *chain.Config
diff --git a/consensus/mock/chain_header_reader_mock.go b/consensus/mock/chain_header_reader_mock.go
new file mode 100644
index 000000000..5131b49e3
--- /dev/null
+++ b/consensus/mock/chain_header_reader_mock.go
@@ -0,0 +1,150 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/ledgerwatch/erigon/consensus (interfaces: ChainHeaderReader)
+
+// Package mock is a generated GoMock package.
+package mock
+
+import (
+	big "math/big"
+	reflect "reflect"
+
+	gomock "github.com/golang/mock/gomock"
+	chain "github.com/ledgerwatch/erigon-lib/chain"
+	common "github.com/ledgerwatch/erigon-lib/common"
+	types "github.com/ledgerwatch/erigon/core/types"
+)
+
+// MockChainHeaderReader is a mock of ChainHeaderReader interface.
+type MockChainHeaderReader struct {
+	ctrl     *gomock.Controller
+	recorder *MockChainHeaderReaderMockRecorder
+}
+
+// MockChainHeaderReaderMockRecorder is the mock recorder for MockChainHeaderReader.
+type MockChainHeaderReaderMockRecorder struct {
+	mock *MockChainHeaderReader
+}
+
+// NewMockChainHeaderReader creates a new mock instance.
+func NewMockChainHeaderReader(ctrl *gomock.Controller) *MockChainHeaderReader {
+	mock := &MockChainHeaderReader{ctrl: ctrl}
+	mock.recorder = &MockChainHeaderReaderMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockChainHeaderReader) EXPECT() *MockChainHeaderReaderMockRecorder {
+	return m.recorder
+}
+
+// BorSpan mocks base method.
+func (m *MockChainHeaderReader) BorSpan(arg0 uint64) []byte {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "BorSpan", arg0)
+	ret0, _ := ret[0].([]byte)
+	return ret0
+}
+
+// BorSpan indicates an expected call of BorSpan.
+func (mr *MockChainHeaderReaderMockRecorder) BorSpan(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BorSpan", reflect.TypeOf((*MockChainHeaderReader)(nil).BorSpan), arg0)
+}
+
+// Config mocks base method.
+func (m *MockChainHeaderReader) Config() *chain.Config {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "Config")
+	ret0, _ := ret[0].(*chain.Config)
+	return ret0
+}
+
+// Config indicates an expected call of Config.
+func (mr *MockChainHeaderReaderMockRecorder) Config() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Config", reflect.TypeOf((*MockChainHeaderReader)(nil).Config))
+}
+
+// CurrentHeader mocks base method.
+func (m *MockChainHeaderReader) CurrentHeader() *types.Header {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CurrentHeader")
+	ret0, _ := ret[0].(*types.Header)
+	return ret0
+}
+
+// CurrentHeader indicates an expected call of CurrentHeader.
+func (mr *MockChainHeaderReaderMockRecorder) CurrentHeader() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CurrentHeader", reflect.TypeOf((*MockChainHeaderReader)(nil).CurrentHeader))
+}
+
+// FrozenBlocks mocks base method.
+func (m *MockChainHeaderReader) FrozenBlocks() uint64 {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FrozenBlocks")
+	ret0, _ := ret[0].(uint64)
+	return ret0
+}
+
+// FrozenBlocks indicates an expected call of FrozenBlocks.
+func (mr *MockChainHeaderReaderMockRecorder) FrozenBlocks() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FrozenBlocks", reflect.TypeOf((*MockChainHeaderReader)(nil).FrozenBlocks))
+}
+
+// GetHeader mocks base method.
+func (m *MockChainHeaderReader) GetHeader(arg0 common.Hash, arg1 uint64) *types.Header {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetHeader", arg0, arg1)
+	ret0, _ := ret[0].(*types.Header)
+	return ret0
+}
+
+// GetHeader indicates an expected call of GetHeader.
+func (mr *MockChainHeaderReaderMockRecorder) GetHeader(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeader", reflect.TypeOf((*MockChainHeaderReader)(nil).GetHeader), arg0, arg1)
+}
+
+// GetHeaderByHash mocks base method.
+func (m *MockChainHeaderReader) GetHeaderByHash(arg0 common.Hash) *types.Header {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetHeaderByHash", arg0)
+	ret0, _ := ret[0].(*types.Header)
+	return ret0
+}
+
+// GetHeaderByHash indicates an expected call of GetHeaderByHash.
+func (mr *MockChainHeaderReaderMockRecorder) GetHeaderByHash(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeaderByHash", reflect.TypeOf((*MockChainHeaderReader)(nil).GetHeaderByHash), arg0)
+}
+
+// GetHeaderByNumber mocks base method.
+func (m *MockChainHeaderReader) GetHeaderByNumber(arg0 uint64) *types.Header {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetHeaderByNumber", arg0)
+	ret0, _ := ret[0].(*types.Header)
+	return ret0
+}
+
+// GetHeaderByNumber indicates an expected call of GetHeaderByNumber.
+func (mr *MockChainHeaderReaderMockRecorder) GetHeaderByNumber(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeaderByNumber", reflect.TypeOf((*MockChainHeaderReader)(nil).GetHeaderByNumber), arg0)
+}
+
+// GetTd mocks base method.
+func (m *MockChainHeaderReader) GetTd(arg0 common.Hash, arg1 uint64) *big.Int {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetTd", arg0, arg1)
+	ret0, _ := ret[0].(*big.Int)
+	return ret0
+}
+
+// GetTd indicates an expected call of GetTd.
+func (mr *MockChainHeaderReaderMockRecorder) GetTd(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTd", reflect.TypeOf((*MockChainHeaderReader)(nil).GetTd), arg0, arg1)
+}
diff --git a/eth/stagedsync/stage_bor_heimdall.go b/eth/stagedsync/stage_bor_heimdall.go
index 41c5645b6..718b40e22 100644
--- a/eth/stagedsync/stage_bor_heimdall.go
+++ b/eth/stagedsync/stage_bor_heimdall.go
@@ -12,6 +12,9 @@ import (
 	"time"
 
 	lru "github.com/hashicorp/golang-lru/arc/v2"
+	"github.com/ledgerwatch/log/v3"
+	"golang.org/x/sync/errgroup"
+
 	"github.com/ledgerwatch/erigon-lib/chain"
 	"github.com/ledgerwatch/erigon-lib/common"
 	libcommon "github.com/ledgerwatch/erigon-lib/common"
@@ -32,8 +35,6 @@ import (
 	"github.com/ledgerwatch/erigon/rlp"
 	"github.com/ledgerwatch/erigon/turbo/services"
 	"github.com/ledgerwatch/erigon/turbo/stages/headerdownload"
-	"github.com/ledgerwatch/log/v3"
-	"golang.org/x/sync/errgroup"
 )
 
 const (
@@ -201,13 +202,17 @@ func BorHeimdallForward(
 	if err != nil {
 		return err
 	}
-	var nextSpanId uint64
+	var lastSpanId uint64
 	if k != nil {
-		nextSpanId = binary.BigEndian.Uint64(k) + 1
+		lastSpanId = binary.BigEndian.Uint64(k)
 	}
 	snapshotLastSpanId := cfg.blockReader.(LastFrozen).LastFrozenSpanID()
-	if snapshotLastSpanId+1 > nextSpanId {
-		nextSpanId = snapshotLastSpanId + 1
+	if snapshotLastSpanId > lastSpanId {
+		lastSpanId = snapshotLastSpanId
+	}
+	var nextSpanId uint64
+	if lastSpanId > 0 {
+		nextSpanId = lastSpanId + 1
 	}
 	var endSpanID uint64
 	if headNumber > zerothSpanEnd {
@@ -231,7 +236,6 @@ func BorHeimdallForward(
 	var blockNum uint64
 	var fetchTime time.Duration
 	var eventRecords int
-	var lastSpanId uint64
 
 	logTimer := time.NewTicker(logInterval)
 	defer logTimer.Stop()
diff --git a/eth/stagedsync/stage_bor_heimdall_test.go b/eth/stagedsync/stage_bor_heimdall_test.go
new file mode 100644
index 000000000..579807103
--- /dev/null
+++ b/eth/stagedsync/stage_bor_heimdall_test.go
@@ -0,0 +1,44 @@
+package stagedsync_test
+
+import (
+	"context"
+	"testing"
+
+	"github.com/ledgerwatch/log/v3"
+	"github.com/stretchr/testify/require"
+
+	"github.com/ledgerwatch/erigon/eth/stagedsync/stages"
+	"github.com/ledgerwatch/erigon/eth/stagedsync/test"
+	"github.com/ledgerwatch/erigon/turbo/testlog"
+)
+
+func TestBorHeimdallForwardPersistsSpans(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	logger := testlog.Logger(t, log.LvlInfo)
+	numBlocks := 6640
+	testHarness := test.InitHarness(ctx, t, logger, test.HarnessCfg{
+		ChainConfig:            test.BorDevnetChainConfigWithNoBlockSealDelays(),
+		GenerateChainNumBlocks: numBlocks,
+	})
+	// pretend-update previous stage progress
+	testHarness.SaveStageProgress(ctx, t, stages.Headers, uint64(numBlocks))
+
+	// run stage under test
+	testHarness.RunStageForward(t, stages.BorHeimdall)
+
+	// asserts
+	spans, err := testHarness.ReadSpansFromDb(ctx)
+	require.NoError(t, err)
+	require.Len(t, spans, 3)
+	require.Equal(t, uint64(0), spans[0].ID)
+	require.Equal(t, uint64(0), spans[0].StartBlock)
+	require.Equal(t, uint64(255), spans[0].EndBlock)
+	require.Equal(t, uint64(1), spans[1].ID)
+	require.Equal(t, uint64(256), spans[1].StartBlock)
+	require.Equal(t, uint64(6655), spans[1].EndBlock)
+	require.Equal(t, uint64(2), spans[2].ID)
+	require.Equal(t, uint64(6656), spans[2].StartBlock)
+	require.Equal(t, uint64(13055), spans[2].EndBlock)
+}
diff --git a/eth/stagedsync/test/chain_configs.go b/eth/stagedsync/test/chain_configs.go
new file mode 100644
index 000000000..db274245e
--- /dev/null
+++ b/eth/stagedsync/test/chain_configs.go
@@ -0,0 +1,20 @@
+package test
+
+import (
+	"github.com/ledgerwatch/erigon-lib/chain"
+	"github.com/ledgerwatch/erigon/params"
+)
+
+func BorDevnetChainConfigWithNoBlockSealDelays() *chain.Config {
+	// take care not to mutate global var (shallow copy)
+	chainConfigCopy := *params.BorDevnetChainConfig
+	borConfigCopy := *chainConfigCopy.Bor
+	borConfigCopy.Period = map[string]uint64{
+		"0": 0,
+	}
+	borConfigCopy.ProducerDelay = map[string]uint64{
+		"0": 0,
+	}
+	chainConfigCopy.Bor = &borConfigCopy
+	return &chainConfigCopy
+}
diff --git a/eth/stagedsync/test/harness.go b/eth/stagedsync/test/harness.go
new file mode 100644
index 000000000..9ac3d5412
--- /dev/null
+++ b/eth/stagedsync/test/harness.go
@@ -0,0 +1,459 @@
+package test
+
+import (
+	"context"
+	"crypto/ecdsa"
+	"encoding/binary"
+	"encoding/json"
+	"fmt"
+	"math/big"
+	"testing"
+
+	"github.com/golang/mock/gomock"
+	"github.com/holiman/uint256"
+	"github.com/ledgerwatch/log/v3"
+	"github.com/stretchr/testify/require"
+
+	"github.com/ledgerwatch/erigon-lib/chain"
+	libcommon "github.com/ledgerwatch/erigon-lib/common"
+	"github.com/ledgerwatch/erigon-lib/kv"
+	"github.com/ledgerwatch/erigon-lib/kv/memdb"
+	"github.com/ledgerwatch/erigon/consensus"
+	"github.com/ledgerwatch/erigon/consensus/bor"
+	"github.com/ledgerwatch/erigon/consensus/bor/clerk"
+	"github.com/ledgerwatch/erigon/consensus/bor/contract"
+	heimdallmock "github.com/ledgerwatch/erigon/consensus/bor/heimdall/mock"
+	"github.com/ledgerwatch/erigon/consensus/bor/heimdall/span"
+	bormock "github.com/ledgerwatch/erigon/consensus/bor/mock"
+	"github.com/ledgerwatch/erigon/consensus/bor/valset"
+	consensusmock "github.com/ledgerwatch/erigon/consensus/mock"
+	"github.com/ledgerwatch/erigon/core"
+	"github.com/ledgerwatch/erigon/core/rawdb"
+	"github.com/ledgerwatch/erigon/core/types"
+	"github.com/ledgerwatch/erigon/crypto"
+	"github.com/ledgerwatch/erigon/eth/ethconfig"
+	"github.com/ledgerwatch/erigon/eth/stagedsync"
+	"github.com/ledgerwatch/erigon/eth/stagedsync/stages"
+	"github.com/ledgerwatch/erigon/turbo/snapshotsync/freezeblocks"
+)
+
+func InitHarness(ctx context.Context, t *testing.T, logger log.Logger, cfg HarnessCfg) Harness {
+	chainDataDb := memdb.NewTestDB(t)
+	borConsensusDb := memdb.NewTestDB(t)
+	ctrl := gomock.NewController(t)
+	heimdallClient := heimdallmock.NewMockIHeimdallClient(ctrl)
+	snapshotsDir := t.TempDir()
+	blocksFreezingCfg := ethconfig.NewSnapCfg(true, true, true)
+	allRoSnapshots := freezeblocks.NewRoSnapshots(blocksFreezingCfg, snapshotsDir, logger)
+	allRoSnapshots.OptimisticalyReopenWithDB(chainDataDb)
+	allBorRoSnapshots := freezeblocks.NewBorRoSnapshots(blocksFreezingCfg, snapshotsDir, logger)
+	allBorRoSnapshots.OptimisticalyReopenWithDB(chainDataDb)
+	blockReader := freezeblocks.NewBlockReader(allRoSnapshots, allBorRoSnapshots)
+	bhCfg := stagedsync.StageBorHeimdallCfg(
+		chainDataDb,
+		borConsensusDb,
+		stagedsync.NewProposingState(&ethconfig.Defaults.Miner),
+		*cfg.ChainConfig,
+		heimdallClient,
+		blockReader,
+		nil, // headerDownloader
+		nil, // penalize
+		nil, // not used
+		nil, // not used
+	)
+	stateSyncStages := stagedsync.DefaultStages(
+		ctx,
+		stagedsync.SnapshotsCfg{},
+		stagedsync.HeadersCfg{},
+		bhCfg,
+		stagedsync.BlockHashesCfg{},
+		stagedsync.BodiesCfg{},
+		stagedsync.SendersCfg{},
+		stagedsync.ExecuteBlockCfg{},
+		stagedsync.HashStateCfg{},
+		stagedsync.TrieCfg{},
+		stagedsync.HistoryCfg{},
+		stagedsync.LogIndexCfg{},
+		stagedsync.CallTracesCfg{},
+		stagedsync.TxLookupCfg{},
+		stagedsync.FinishCfg{},
+		true,
+	)
+	stateSync := stagedsync.New(stateSyncStages, stagedsync.DefaultUnwindOrder, stagedsync.DefaultPruneOrder, logger)
+	validatorKey, err := crypto.GenerateKey()
+	require.NoError(t, err)
+	validatorAddress := crypto.PubkeyToAddress(validatorKey.PublicKey)
+	h := Harness{
+		logger:           logger,
+		chainDataDb:      chainDataDb,
+		borConsensusDb:   borConsensusDb,
+		chainConfig:      cfg.ChainConfig,
+		blockReader:      blockReader,
+		stateSyncStages:  stateSyncStages,
+		stateSync:        stateSync,
+		bhCfg:            bhCfg,
+		heimdallClient:   heimdallClient,
+		sealedHeaders:    make(map[uint64]*types.Header),
+		borSpanner:       bormock.NewMockSpanner(ctrl),
+		validatorAddress: validatorAddress,
+		validatorKey:     validatorKey,
+	}
+
+	if cfg.ChainConfig.Bor != nil {
+		h.setHeimdallNextMockSpan(logger)
+		h.mockBorSpanner()
+		h.mockHeimdallClient()
+	}
+
+	h.generateChain(ctx, t, ctrl, cfg)
+
+	return h
+}
+
+type genesisInitData struct {
+	genesis                 *types.Genesis
+	genesisAllocPrivateKeys map[libcommon.Address]*ecdsa.PrivateKey
+	fundedAddresses         []libcommon.Address
+}
+
+type HarnessCfg struct {
+	ChainConfig            *chain.Config
+	GenerateChainNumBlocks int
+}
+
+type Harness struct {
+	logger               log.Logger
+	chainDataDb          kv.RwDB
+	borConsensusDb       kv.RwDB
+	chainConfig          *chain.Config
+	blockReader          *freezeblocks.BlockReader
+	stateSyncStages      []*stagedsync.Stage
+	stateSync            *stagedsync.Sync
+	bhCfg                stagedsync.BorHeimdallCfg
+	heimdallClient       *heimdallmock.MockIHeimdallClient
+	heimdallNextMockSpan *span.HeimdallSpan
+	sealedHeaders        map[uint64]*types.Header
+	borSpanner           *bormock.MockSpanner
+	validatorAddress     libcommon.Address
+	validatorKey         *ecdsa.PrivateKey
+	genesisInitData      *genesisInitData
+}
+
+func (h *Harness) SaveStageProgress(ctx context.Context, t *testing.T, stageId stages.SyncStage, progress uint64) {
+	rwTx, err := h.chainDataDb.BeginRw(ctx)
+	require.NoError(t, err)
+	defer rwTx.Rollback()
+
+	err = stages.SaveStageProgress(rwTx, stageId, progress)
+	require.NoError(t, err)
+	err = rwTx.Commit()
+	require.NoError(t, err)
+}
+
+func (h *Harness) RunStageForward(t *testing.T, id stages.SyncStage) {
+	err := h.stateSync.SetCurrentStage(id)
+	require.NoError(t, err)
+
+	stage, found := h.findStateSyncStageById(id)
+	require.True(t, found)
+
+	stageState, err := h.stateSync.StageState(id, nil, h.chainDataDb)
+	require.NoError(t, err)
+
+	err = stage.Forward(true, false, stageState, h.stateSync, nil, h.logger)
+	require.NoError(t, err)
+}
+
+func (h *Harness) ReadSpansFromDb(ctx context.Context) (spans []*span.HeimdallSpan, err error) {
+	err = h.chainDataDb.View(ctx, func(tx kv.Tx) error {
+		spanIter, err := tx.Range(kv.BorSpans, nil, nil)
+		if err != nil {
+			return err
+		}
+
+		for spanIter.HasNext() {
+			keyBytes, spanBytes, err := spanIter.Next()
+			if err != nil {
+				return err
+			}
+
+			spanKey := binary.BigEndian.Uint64(keyBytes)
+			var heimdallSpan span.HeimdallSpan
+			if err = json.Unmarshal(spanBytes, &heimdallSpan); err != nil {
+				return err
+			}
+
+			if spanKey != heimdallSpan.ID {
+				return fmt.Errorf("span key and id mismatch %d!=%d", spanKey, heimdallSpan.ID)
+			}
+
+			spans = append(spans, &heimdallSpan)
+		}
+
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	return spans, nil
+}
+
+func (h *Harness) createGenesisInitData(t *testing.T) *genesisInitData {
+	accountPrivateKey, err := crypto.GenerateKey()
+	require.NoError(t, err)
+	accountAddress := crypto.PubkeyToAddress(accountPrivateKey.PublicKey)
+
+	h.genesisInitData = &genesisInitData{
+		genesis: &types.Genesis{
+			Config: h.chainConfig,
+			Alloc: types.GenesisAlloc{
+				accountAddress: {
+					Balance: new(big.Int).Exp(big.NewInt(1_000), big.NewInt(18), nil),
+				},
+			},
+		},
+		genesisAllocPrivateKeys: map[libcommon.Address]*ecdsa.PrivateKey{
+			accountAddress: accountPrivateKey,
+		},
+		fundedAddresses: []libcommon.Address{
+			accountAddress,
+		},
+	}
+
+	return h.genesisInitData
+}
+
+func (h *Harness) generateChain(ctx context.Context, t *testing.T, ctrl *gomock.Controller, cfg HarnessCfg) {
+	genInitData := h.createGenesisInitData(t)
+	consensusEngine := h.consensusEngine(t, cfg)
+	genesisTmpDbDir := t.TempDir()
+	_, parentBlock, err := core.CommitGenesisBlock(h.chainDataDb, genInitData.genesis, genesisTmpDbDir, h.logger)
+	require.NoError(t, err)
+	h.sealedHeaders[parentBlock.Number().Uint64()] = parentBlock.Header()
+	mockChainHR := h.mockChainHeaderReader(ctrl)
+
+	chainPack, err := core.GenerateChain(
+		h.chainConfig,
+		parentBlock,
+		consensusEngine,
+		h.chainDataDb,
+		cfg.GenerateChainNumBlocks,
+		func(i int, gen *core.BlockGen) {
+			// seal parent block first so that we can Prepare the current header
+			if gen.GetParent().Number().Uint64() > 0 {
+				h.seal(t, mockChainHR, consensusEngine, gen.GetParent())
+			}
+
+			h.logger.Info("Preparing mock header", "headerNum", gen.GetHeader().Number)
+			gen.GetHeader().ParentHash = h.sealedHeaders[gen.GetParent().Number().Uint64()].Hash()
+			if err := consensusEngine.Prepare(mockChainHR, gen.GetHeader(), nil); err != nil {
+				t.Fatal(err)
+			}
+
+			h.logger.Info("Adding 1 mock tx to block", "blockNum", gen.GetHeader().Number)
+			chainId := uint256.Int{}
+			overflow := chainId.SetFromBig(h.chainConfig.ChainID)
+			require.False(t, overflow)
+			from := h.genesisInitData.fundedAddresses[0]
+			tx, err := types.SignTx(
+				types.NewEIP1559Transaction(
+					chainId,
+					gen.TxNonce(from),
+					from, // send to itself
+					new(uint256.Int),
+					21000,
+					new(uint256.Int),
+					new(uint256.Int),
+					uint256.NewInt(937500001),
+					nil,
+				),
+				*types.LatestSignerForChainID(h.chainConfig.ChainID),
+				h.genesisInitData.genesisAllocPrivateKeys[from],
+			)
+			require.NoError(t, err)
+			gen.AddTx(tx)
+		},
+	)
+	require.NoError(t, err)
+
+	h.seal(t, mockChainHR, consensusEngine, chainPack.TopBlock)
+	sealedHeadersList := make([]*types.Header, len(h.sealedHeaders))
+	for num, header := range h.sealedHeaders {
+		sealedHeadersList[num] = header
+	}
+
+	h.saveHeaders(ctx, t, sealedHeadersList)
+}
+
+func (h *Harness) seal(t *testing.T, chr consensus.ChainHeaderReader, eng consensus.Engine, block *types.Block) {
+	h.logger.Info("Sealing mock block", "blockNum", block.Number())
+	sealRes, sealStop := make(chan *types.Block, 1), make(chan struct{}, 1)
+	if err := eng.Seal(chr, block, sealRes, sealStop); err != nil {
+		t.Fatal(err)
+	}
+
+	sealedParentBlock := <-sealRes
+	h.sealedHeaders[sealedParentBlock.Number().Uint64()] = sealedParentBlock.Header()
+}
+
+func (h *Harness) consensusEngine(t *testing.T, cfg HarnessCfg) consensus.Engine {
+	if h.chainConfig.Bor != nil {
+		genesisContracts := contract.NewGenesisContractsClient(
+			h.chainConfig,
+			h.chainConfig.Bor.ValidatorContract,
+			h.chainConfig.Bor.StateReceiverContract,
+			h.logger,
+		)
+
+		borConsensusEng := bor.New(
+			h.chainConfig,
+			h.borConsensusDb,
+			nil,
+			h.borSpanner,
+			h.heimdallClient,
+			genesisContracts,
+			h.logger,
+		)
+
+		borConsensusEng.Authorize(h.validatorAddress, func(_ libcommon.Address, _ string, msg []byte) ([]byte, error) {
+			return crypto.Sign(crypto.Keccak256(msg), h.validatorKey)
+		})
+
+		return borConsensusEng
+	}
+
+	t.Fatal(fmt.Sprintf("unimplmented consensus engine init for cfg %v", cfg.ChainConfig))
+	return nil
+}
+
+func (h *Harness) saveHeaders(ctx context.Context, t *testing.T, headers []*types.Header) {
+	rwTx, err := h.chainDataDb.BeginRw(ctx)
+	require.NoError(t, err)
+	defer rwTx.Rollback()
+
+	for _, header := range headers {
+		err = rawdb.WriteHeader(rwTx, header)
+		require.NoError(t, err)
+
+		err = rawdb.WriteCanonicalHash(rwTx, header.Hash(), header.Number.Uint64())
+		require.NoError(t, err)
+	}
+
+	err = rwTx.Commit()
+	require.NoError(t, err)
+}
+
+func (h *Harness) mockChainHeaderReader(ctrl *gomock.Controller) consensus.ChainHeaderReader {
+	mockChainHR := consensusmock.NewMockChainHeaderReader(ctrl)
+	mockChainHR.
+		EXPECT().
+		GetHeader(gomock.Any(), gomock.Any()).
+		DoAndReturn(func(_ libcommon.Hash, number uint64) *types.Header {
+			return h.sealedHeaders[number]
+		}).
+		AnyTimes()
+
+	mockChainHR.
+		EXPECT().
+		GetHeaderByNumber(gomock.Any()).
+		DoAndReturn(func(number uint64) *types.Header {
+			return h.sealedHeaders[number]
+		}).
+		AnyTimes()
+
+	mockChainHR.
+		EXPECT().
+		FrozenBlocks().
+		Return(uint64(0)).
+		AnyTimes()
+
+	return mockChainHR
+}
+
+func (h *Harness) setHeimdallNextMockSpan(logger log.Logger) {
+	validators := []*valset.Validator{
+		{
+			ID:               1,
+			Address:          h.validatorAddress,
+			VotingPower:      1000,
+			ProposerPriority: 1,
+		},
+	}
+
+	validatorSet := valset.NewValidatorSet(validators, logger)
+	selectedProducers := make([]valset.Validator, len(validators))
+	for i := range validators {
+		selectedProducers[i] = *validators[i]
+	}
+
+	h.heimdallNextMockSpan = &span.HeimdallSpan{
+		Span: span.Span{
+			ID:         0,
+			StartBlock: 0,
+			EndBlock:   255,
+		},
+		ValidatorSet:      *validatorSet,
+		SelectedProducers: selectedProducers,
+	}
+}
+
+func (h *Harness) mockBorSpanner() {
+	h.borSpanner.
+		EXPECT().
+		GetCurrentValidators(gomock.Any(), gomock.Any(), gomock.Any()).
+		Return(h.heimdallNextMockSpan.ValidatorSet.Validators, nil).
+		AnyTimes()
+
+	h.borSpanner.
+		EXPECT().
+		GetCurrentProducers(gomock.Any(), gomock.Any(), gomock.Any()).
+		DoAndReturn(func(_ uint64, _ libcommon.Address, _ consensus.ChainHeaderReader) ([]*valset.Validator, error) {
+			res := make([]*valset.Validator, len(h.heimdallNextMockSpan.SelectedProducers))
+			for i := range h.heimdallNextMockSpan.SelectedProducers {
+				res[i] = &h.heimdallNextMockSpan.SelectedProducers[i]
+			}
+
+			return res, nil
+		}).
+		AnyTimes()
+}
+
+func (h *Harness) mockHeimdallClient() {
+	h.heimdallClient.
+		EXPECT().
+		Span(gomock.Any(), gomock.Any()).
+		DoAndReturn(func(ctx context.Context, spanID uint64) (*span.HeimdallSpan, error) {
+			res := h.heimdallNextMockSpan
+			h.heimdallNextMockSpan = &span.HeimdallSpan{
+				Span: span.Span{
+					ID:         res.ID + 1,
+					StartBlock: res.EndBlock + 1,
+					EndBlock:   res.EndBlock + 6400,
+				},
+				ValidatorSet:      res.ValidatorSet,
+				SelectedProducers: res.SelectedProducers,
+			}
+
+			return res, nil
+		}).
+		AnyTimes()
+
+	h.heimdallClient.
+		EXPECT().
+		StateSyncEvents(gomock.Any(), gomock.Any(), gomock.Any()).
+		DoAndReturn(func(ctx context.Context, fromID uint64, to int64) ([]*clerk.EventRecordWithTime, error) {
+			return nil, nil
+		}).
+		AnyTimes()
+}
+
+func (h *Harness) findStateSyncStageById(id stages.SyncStage) (*stagedsync.Stage, bool) {
+	for _, s := range h.stateSyncStages {
+		if s.ID == id {
+			return s, true
+		}
+	}
+
+	return nil, false
+}