From c5b75d00ca469b0fe84eadaa3437cd897d32b44d Mon Sep 17 00:00:00 2001
From: battlmonstr <battlmonstr@users.noreply.github.com>
Date: Mon, 15 Jan 2024 12:36:25 +0100
Subject: [PATCH] polygon/sync: span updates (#9229)

It is possible that a span update happens during a milestone.
A headers slice might cross to the new span.
Also if 2 forks evolve simulaneously, a shorter fork can still be in the
previous span.
In these cases we need access to the previous span to calculate
difficulty and validate header times.

SpansCache will keep recent spans.
The cache will be updated on new span events from the heimdall.
The cache is pruned on new milestone events and in practice no more than
2 spans are kept.

The header difficulty calculation and time validation depends on having
a span for that header in the cache.
---
 polygon/sync/canonical_chain_builder.go      | 17 ++++++---
 polygon/sync/canonical_chain_builder_test.go |  2 +-
 polygon/sync/difficulty.go                   | 32 +++++++++--------
 polygon/sync/difficulty_test.go              | 17 ++++++---
 polygon/sync/header_time_validator.go        | 33 ++++++++++--------
 polygon/sync/spans_cache.go                  | 36 ++++++++++++++++++++
 6 files changed, 96 insertions(+), 41 deletions(-)
 create mode 100644 polygon/sync/spans_cache.go

diff --git a/polygon/sync/canonical_chain_builder.go b/polygon/sync/canonical_chain_builder.go
index 2c6d63285..b6575544b 100644
--- a/polygon/sync/canonical_chain_builder.go
+++ b/polygon/sync/canonical_chain_builder.go
@@ -37,20 +37,21 @@ type canonicalChainBuilderImpl struct {
 	root *forkTreeNode
 	tip  *forkTreeNode
 
-	difficultyCalc DifficultyCalculator
-
+	difficultyCalc  DifficultyCalculator
 	headerValidator HeaderValidator
+	spansCache      *SpansCache
 }
 
 func NewCanonicalChainBuilder(
 	root *types.Header,
 	difficultyCalc DifficultyCalculator,
 	headerValidator HeaderValidator,
+	spansCache *SpansCache,
 ) CanonicalChainBuilder {
 	impl := &canonicalChainBuilderImpl{
-		difficultyCalc: difficultyCalc,
-
+		difficultyCalc:  difficultyCalc,
 		headerValidator: headerValidator,
+		spansCache:      spansCache,
 	}
 	impl.Reset(root)
 	return impl
@@ -63,6 +64,9 @@ func (impl *canonicalChainBuilderImpl) Reset(root *types.Header) {
 		headerHash: root.Hash(),
 	}
 	impl.tip = impl.root
+	if impl.spansCache != nil {
+		impl.spansCache.Prune(root.Number.Uint64())
+	}
 }
 
 // depth-first search
@@ -138,8 +142,11 @@ func (impl *canonicalChainBuilderImpl) Prune(newRootNum uint64) error {
 	for newRoot.header.Number.Uint64() > newRootNum {
 		newRoot = newRoot.parent
 	}
-
 	impl.root = newRoot
+
+	if impl.spansCache != nil {
+		impl.spansCache.Prune(newRootNum)
+	}
 	return nil
 }
 
diff --git a/polygon/sync/canonical_chain_builder_test.go b/polygon/sync/canonical_chain_builder_test.go
index 86392a8c6..0fb061732 100644
--- a/polygon/sync/canonical_chain_builder_test.go
+++ b/polygon/sync/canonical_chain_builder_test.go
@@ -34,7 +34,7 @@ func makeRoot() *types.Header {
 
 func makeCCB(root *types.Header) CanonicalChainBuilder {
 	difficultyCalc := testDifficultyCalculator{}
-	builder := NewCanonicalChainBuilder(root, &difficultyCalc, nil)
+	builder := NewCanonicalChainBuilder(root, &difficultyCalc, nil, nil)
 	return builder
 }
 
diff --git a/polygon/sync/difficulty.go b/polygon/sync/difficulty.go
index 1ccb4c0b1..7880ade67 100644
--- a/polygon/sync/difficulty.go
+++ b/polygon/sync/difficulty.go
@@ -1,12 +1,12 @@
 package sync
 
 import (
+	"fmt"
+
 	lru "github.com/hashicorp/golang-lru/arc/v2"
 
 	"github.com/ledgerwatch/erigon/eth/stagedsync"
 
-	heimdallspan "github.com/ledgerwatch/erigon/polygon/heimdall/span"
-
 	libcommon "github.com/ledgerwatch/erigon-lib/common"
 	"github.com/ledgerwatch/erigon/core/types"
 	"github.com/ledgerwatch/erigon/polygon/bor"
@@ -16,20 +16,19 @@ import (
 
 type DifficultyCalculator interface {
 	HeaderDifficulty(header *types.Header) (uint64, error)
-	SetSpan(span *heimdallspan.HeimdallSpan)
 }
 
 type difficultyCalculatorImpl struct {
 	borConfig           *borcfg.BorConfig
-	span                *heimdallspan.HeimdallSpan
-	validatorSetFactory func() validatorSetInterface
+	spans               *SpansCache
+	validatorSetFactory func(headerNum uint64) validatorSetInterface
 	signaturesCache     *lru.ARCCache[libcommon.Hash, libcommon.Address]
 }
 
 func NewDifficultyCalculator(
 	borConfig *borcfg.BorConfig,
-	span *heimdallspan.HeimdallSpan,
-	validatorSetFactory func() validatorSetInterface,
+	spans *SpansCache,
+	validatorSetFactory func(headerNum uint64) validatorSetInterface,
 	signaturesCache *lru.ARCCache[libcommon.Hash, libcommon.Address],
 ) DifficultyCalculator {
 	if signaturesCache == nil {
@@ -42,7 +41,7 @@ func NewDifficultyCalculator(
 
 	impl := difficultyCalculatorImpl{
 		borConfig:           borConfig,
-		span:                span,
+		spans:               spans,
 		validatorSetFactory: validatorSetFactory,
 		signaturesCache:     signaturesCache,
 	}
@@ -54,12 +53,12 @@ func NewDifficultyCalculator(
 	return &impl
 }
 
-func (impl *difficultyCalculatorImpl) makeValidatorSet() validatorSetInterface {
-	return valset.NewValidatorSet(impl.span.ValidatorSet.Validators)
-}
-
-func (impl *difficultyCalculatorImpl) SetSpan(span *heimdallspan.HeimdallSpan) {
-	impl.span = span
+func (impl *difficultyCalculatorImpl) makeValidatorSet(headerNum uint64) validatorSetInterface {
+	span := impl.spans.SpanAt(headerNum)
+	if span == nil {
+		return nil
+	}
+	return valset.NewValidatorSet(span.ValidatorSet.Validators)
 }
 
 func (impl *difficultyCalculatorImpl) HeaderDifficulty(header *types.Header) (uint64, error) {
@@ -71,7 +70,10 @@ func (impl *difficultyCalculatorImpl) HeaderDifficulty(header *types.Header) (ui
 }
 
 func (impl *difficultyCalculatorImpl) signerDifficulty(signer libcommon.Address, headerNum uint64) (uint64, error) {
-	validatorSet := impl.validatorSetFactory()
+	validatorSet := impl.validatorSetFactory(headerNum)
+	if validatorSet == nil {
+		return 0, fmt.Errorf("difficultyCalculatorImpl.signerDifficulty: no span at %d", headerNum)
+	}
 
 	sprintNum := impl.borConfig.CalculateSprintNumber(headerNum)
 	if sprintNum > 0 {
diff --git a/polygon/sync/difficulty_test.go b/polygon/sync/difficulty_test.go
index 4a3f7f39f..669b8dbcf 100644
--- a/polygon/sync/difficulty_test.go
+++ b/polygon/sync/difficulty_test.go
@@ -7,8 +7,6 @@ import (
 
 	"github.com/stretchr/testify/require"
 
-	heimdallspan "github.com/ledgerwatch/erigon/polygon/heimdall/span"
-
 	libcommon "github.com/ledgerwatch/erigon-lib/common"
 	"github.com/ledgerwatch/erigon/core/types"
 	"github.com/ledgerwatch/erigon/polygon/bor/borcfg"
@@ -57,7 +55,7 @@ func TestSignerDifficulty(t *testing.T) {
 		libcommon.HexToAddress("01"),
 		libcommon.HexToAddress("02"),
 	}
-	validatorSetFactory := func() validatorSetInterface { return &testValidatorSetInterface{signers: signers} }
+	validatorSetFactory := func(uint64) validatorSetInterface { return &testValidatorSetInterface{signers: signers} }
 	calc := NewDifficultyCalculator(&borConfig, nil, validatorSetFactory, nil).(*difficultyCalculatorImpl)
 
 	var d uint64
@@ -123,9 +121,18 @@ func TestSignerDifficulty(t *testing.T) {
 
 func TestHeaderDifficultyNoSignature(t *testing.T) {
 	borConfig := borcfg.BorConfig{}
-	span := heimdallspan.HeimdallSpan{}
-	calc := NewDifficultyCalculator(&borConfig, &span, nil, nil)
+	spans := NewSpansCache()
+	calc := NewDifficultyCalculator(&borConfig, spans, nil, nil)
 
 	_, err := calc.HeaderDifficulty(new(types.Header))
 	require.ErrorContains(t, err, "signature suffix missing")
 }
+
+func TestSignerDifficultyNoSpan(t *testing.T) {
+	borConfig := borcfg.BorConfig{}
+	spans := NewSpansCache()
+	calc := NewDifficultyCalculator(&borConfig, spans, nil, nil).(*difficultyCalculatorImpl)
+
+	_, err := calc.signerDifficulty(libcommon.HexToAddress("00"), 0)
+	require.ErrorContains(t, err, "no span")
+}
diff --git a/polygon/sync/header_time_validator.go b/polygon/sync/header_time_validator.go
index 8b84e36c0..d2da61764 100644
--- a/polygon/sync/header_time_validator.go
+++ b/polygon/sync/header_time_validator.go
@@ -1,6 +1,7 @@
 package sync
 
 import (
+	"fmt"
 	"time"
 
 	lru "github.com/hashicorp/golang-lru/arc/v2"
@@ -11,25 +12,23 @@ import (
 	"github.com/ledgerwatch/erigon/polygon/bor"
 	"github.com/ledgerwatch/erigon/polygon/bor/borcfg"
 	"github.com/ledgerwatch/erigon/polygon/bor/valset"
-	heimdallspan "github.com/ledgerwatch/erigon/polygon/heimdall/span"
 )
 
 type HeaderTimeValidator interface {
 	ValidateHeaderTime(header *types.Header, now time.Time, parent *types.Header) error
-	SetSpan(span *heimdallspan.HeimdallSpan)
 }
 
 type headerTimeValidatorImpl struct {
 	borConfig           *borcfg.BorConfig
-	span                *heimdallspan.HeimdallSpan
-	validatorSetFactory func() validatorSetInterface
+	spans               *SpansCache
+	validatorSetFactory func(headerNum uint64) validatorSetInterface
 	signaturesCache     *lru.ARCCache[libcommon.Hash, libcommon.Address]
 }
 
 func NewHeaderTimeValidator(
 	borConfig *borcfg.BorConfig,
-	span *heimdallspan.HeimdallSpan,
-	validatorSetFactory func() validatorSetInterface,
+	spans *SpansCache,
+	validatorSetFactory func(headerNum uint64) validatorSetInterface,
 	signaturesCache *lru.ARCCache[libcommon.Hash, libcommon.Address],
 ) HeaderTimeValidator {
 	if signaturesCache == nil {
@@ -42,7 +41,7 @@ func NewHeaderTimeValidator(
 
 	impl := headerTimeValidatorImpl{
 		borConfig:           borConfig,
-		span:                span,
+		spans:               spans,
 		validatorSetFactory: validatorSetFactory,
 		signaturesCache:     signaturesCache,
 	}
@@ -54,18 +53,22 @@ func NewHeaderTimeValidator(
 	return &impl
 }
 
-func (impl *headerTimeValidatorImpl) makeValidatorSet() validatorSetInterface {
-	return valset.NewValidatorSet(impl.span.ValidatorSet.Validators)
-}
-
-func (impl *headerTimeValidatorImpl) SetSpan(span *heimdallspan.HeimdallSpan) {
-	impl.span = span
+func (impl *headerTimeValidatorImpl) makeValidatorSet(headerNum uint64) validatorSetInterface {
+	span := impl.spans.SpanAt(headerNum)
+	if span == nil {
+		return nil
+	}
+	return valset.NewValidatorSet(span.ValidatorSet.Validators)
 }
 
 func (impl *headerTimeValidatorImpl) ValidateHeaderTime(header *types.Header, now time.Time, parent *types.Header) error {
-	validatorSet := impl.validatorSetFactory()
+	headerNum := header.Number.Uint64()
+	validatorSet := impl.validatorSetFactory(headerNum)
+	if validatorSet == nil {
+		return fmt.Errorf("headerTimeValidatorImpl.ValidateHeaderTime: no span at %d", headerNum)
+	}
 
-	sprintNum := impl.borConfig.CalculateSprintNumber(header.Number.Uint64())
+	sprintNum := impl.borConfig.CalculateSprintNumber(headerNum)
 	if sprintNum > 0 {
 		validatorSet.IncrementProposerPriority(int(sprintNum))
 	}
diff --git a/polygon/sync/spans_cache.go b/polygon/sync/spans_cache.go
new file mode 100644
index 000000000..5c03a4229
--- /dev/null
+++ b/polygon/sync/spans_cache.go
@@ -0,0 +1,36 @@
+package sync
+
+import heimdallspan "github.com/ledgerwatch/erigon/polygon/heimdall/span"
+
+type SpansCache struct {
+	spans map[uint64]*heimdallspan.HeimdallSpan
+}
+
+func NewSpansCache() *SpansCache {
+	return &SpansCache{
+		spans: make(map[uint64]*heimdallspan.HeimdallSpan),
+	}
+}
+
+func (cache *SpansCache) Add(span *heimdallspan.HeimdallSpan) {
+	cache.spans[span.StartBlock] = span
+}
+
+// SpanAt finds a span that contains blockNum.
+func (cache *SpansCache) SpanAt(blockNum uint64) *heimdallspan.HeimdallSpan {
+	for _, span := range cache.spans {
+		if (span.StartBlock <= blockNum) && (blockNum <= span.EndBlock) {
+			return span
+		}
+	}
+	return nil
+}
+
+// Prune removes spans that ended before blockNum.
+func (cache *SpansCache) Prune(blockNum uint64) {
+	for key, span := range cache.spans {
+		if span.EndBlock < blockNum {
+			delete(cache.spans, key)
+		}
+	}
+}