From 2521f47e7b848a5a01a6bf1774aac56a041e0438 Mon Sep 17 00:00:00 2001 From: battlmonstr Date: Mon, 8 Jan 2024 15:55:43 +0100 Subject: [PATCH] polygon/sync: canonical chain builder unit tests (#9137) --- polygon/sync/canonical_chain_builder_test.go | 237 ++++++++++++++++++- polygon/sync/difficulty.go | 52 +++- polygon/sync/difficulty_test.go | 109 ++++++++- 3 files changed, 377 insertions(+), 21 deletions(-) diff --git a/polygon/sync/canonical_chain_builder_test.go b/polygon/sync/canonical_chain_builder_test.go index aa640f3c4..802009e2c 100644 --- a/polygon/sync/canonical_chain_builder_test.go +++ b/polygon/sync/canonical_chain_builder_test.go @@ -1,23 +1,246 @@ package sync import ( + "bytes" + "errors" + "math/big" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + heimdallspan "github.com/ledgerwatch/erigon/consensus/bor/heimdall/span" "github.com/ledgerwatch/erigon/core/types" ) type testDifficultyCalculator struct { } -func (*testDifficultyCalculator) HeaderDifficulty(*types.Header) (uint64, error) { - return 0, nil +func (*testDifficultyCalculator) HeaderDifficulty(header *types.Header) (uint64, error) { + if header.Difficulty == nil { + return 0, errors.New("unset header.Difficulty") + } + return header.Difficulty.Uint64(), nil } -func TestCanonicalChainBuilderConnectEmpty(t *testing.T) { - difficultyCalc := testDifficultyCalculator{} - builder := NewCanonicalChainBuilder(new(types.Header), &difficultyCalc) - err := builder.Connect([]*types.Header{}) - require.Nil(t, err) +func (*testDifficultyCalculator) SetSpan(*heimdallspan.HeimdallSpan) {} + +func makeRoot() *types.Header { + return &types.Header{ + Number: big.NewInt(0), + } +} + +func makeCCB(root *types.Header) CanonicalChainBuilder { + difficultyCalc := testDifficultyCalculator{} + builder := NewCanonicalChainBuilder(root, &difficultyCalc) + return builder +} + +type connectCCBTest struct { + t *testing.T + root *types.Header + builder CanonicalChainBuilder + + currentHeaderTime uint64 +} + +func newConnectCCBTest(t *testing.T) (*connectCCBTest, *types.Header) { + root := makeRoot() + builder := makeCCB(root) + test := &connectCCBTest{ + t: t, + root: root, + builder: builder, + } + return test, root +} + +func (test *connectCCBTest) makeHeader(parent *types.Header, difficulty uint64) *types.Header { + test.currentHeaderTime++ + return &types.Header{ + ParentHash: parent.Hash(), + Difficulty: big.NewInt(int64(difficulty)), + Number: big.NewInt(parent.Number.Int64() + 1), + Time: test.currentHeaderTime, + Extra: bytes.Repeat([]byte{0x00}, types.ExtraVanityLength+types.ExtraSealLength), + } +} + +func (test *connectCCBTest) makeHeaders(parent *types.Header, difficulties []uint64) []*types.Header { + count := len(difficulties) + headers := make([]*types.Header, 0, count) + for i := 0; i < count; i++ { + header := test.makeHeader(parent, difficulties[i]) + headers = append(headers, header) + parent = header + } + return headers +} + +func (test *connectCCBTest) testConnect( + headers []*types.Header, + expectedTip *types.Header, + expectedHeaders []*types.Header, +) { + t := test.t + builder := test.builder + + err := builder.Connect(headers) + require.Nil(t, err) + + newTip := builder.Tip() + assert.Equal(t, expectedTip.Hash(), newTip.Hash()) + + require.NotNil(t, newTip.Number) + count := uint64(len(expectedHeaders)) + start := newTip.Number.Uint64() - (count - 1) + + actualHeaders := builder.HeadersInRange(start, count) + require.Equal(t, len(expectedHeaders), len(actualHeaders)) + for i, h := range actualHeaders { + assert.Equal(t, expectedHeaders[i].Hash(), h.Hash()) + } +} + +func TestCCBEmptyState(t *testing.T) { + test, root := newConnectCCBTest(t) + + tip := test.builder.Tip() + assert.Equal(t, root.Hash(), tip.Hash()) + + headers := test.builder.HeadersInRange(0, 1) + require.Equal(t, 1, len(headers)) + assert.Equal(t, root.Hash(), headers[0].Hash()) +} + +func TestCCBConnectEmpty(t *testing.T) { + test, root := newConnectCCBTest(t) + test.testConnect([]*types.Header{}, root, []*types.Header{root}) +} + +// connect 0 to 0 +func TestCCBConnectRoot(t *testing.T) { + test, root := newConnectCCBTest(t) + test.testConnect([]*types.Header{root}, root, []*types.Header{root}) +} + +// connect 1 to 0 +func TestCCBConnectOneToRoot(t *testing.T) { + test, root := newConnectCCBTest(t) + newTip := test.makeHeader(root, 1) + test.testConnect([]*types.Header{newTip}, newTip, []*types.Header{root, newTip}) +} + +// connect 1-2-3 to 0 +func TestCCBConnectSomeToRoot(t *testing.T) { + test, root := newConnectCCBTest(t) + headers := test.makeHeaders(root, []uint64{1, 2, 3}) + test.testConnect(headers, headers[len(headers)-1], append([]*types.Header{root}, headers...)) +} + +// connect any subset of 0-1-2-3 to 0-1-2-3 +func TestCCBConnectOverlapsFull(t *testing.T) { + test, root := newConnectCCBTest(t) + headers := test.makeHeaders(root, []uint64{1, 2, 3}) + require.Nil(t, test.builder.Connect(headers)) + + expectedTip := headers[len(headers)-1] + expectedHeaders := append([]*types.Header{root}, headers...) + + for subsetLen := 1; subsetLen <= len(headers); subsetLen++ { + for i := 0; i+subsetLen-1 < len(expectedHeaders); i++ { + headers := expectedHeaders[i : i+subsetLen] + test.testConnect(headers, expectedTip, expectedHeaders) + } + } +} + +// connect 0-1 to 0 +func TestCCBConnectOverlapPartialOne(t *testing.T) { + test, root := newConnectCCBTest(t) + newTip := test.makeHeader(root, 1) + test.testConnect([]*types.Header{root, newTip}, newTip, []*types.Header{root, newTip}) +} + +// connect 2-3-4-5 to 0-1-2-3 +func TestCCBConnectOverlapPartialSome(t *testing.T) { + test, root := newConnectCCBTest(t) + headers := test.makeHeaders(root, []uint64{1, 2, 3}) + require.Nil(t, test.builder.Connect(headers)) + + overlapHeaders := append(headers[1:], test.makeHeaders(headers[len(headers)-1], []uint64{4, 5})...) + expectedTip := overlapHeaders[len(overlapHeaders)-1] + expectedHeaders := append([]*types.Header{root, headers[0]}, overlapHeaders...) + test.testConnect(overlapHeaders, expectedTip, expectedHeaders) +} + +// connect 2 to 0-1 at 0, then connect 10 to 0-1 +func TestCCBConnectAltMainBecomesFork(t *testing.T) { + test, root := newConnectCCBTest(t) + header1 := test.makeHeader(root, 1) + header2 := test.makeHeader(root, 2) + require.Nil(t, test.builder.Connect([]*types.Header{header1})) + + // the tip changes to header2 + test.testConnect([]*types.Header{header2}, header2, []*types.Header{root, header2}) + + header10 := test.makeHeader(header1, 10) + test.testConnect([]*types.Header{header10}, header10, []*types.Header{root, header1, header10}) +} + +// connect 1 to 0-2 at 0, then connect 10 to 0-1 +func TestCCBConnectAltForkBecomesMain(t *testing.T) { + test, root := newConnectCCBTest(t) + header1 := test.makeHeader(root, 1) + header2 := test.makeHeader(root, 2) + require.Nil(t, test.builder.Connect([]*types.Header{header2})) + + // the tip stays at header2 + test.testConnect([]*types.Header{header1}, header2, []*types.Header{root, header2}) + + header10 := test.makeHeader(header1, 10) + test.testConnect([]*types.Header{header10}, header10, []*types.Header{root, header1, header10}) +} + +// connect 10 and 11 to 1, then 20 and 22 to 2 one by one starting from a [0-1, 0-2] tree +func TestCCBConnectAltForksAtLevel2(t *testing.T) { + test, root := newConnectCCBTest(t) + header1 := test.makeHeader(root, 1) + header10 := test.makeHeader(header1, 10) + header11 := test.makeHeader(header1, 11) + header2 := test.makeHeader(root, 2) + header20 := test.makeHeader(header2, 20) + header22 := test.makeHeader(header2, 22) + require.Nil(t, test.builder.Connect([]*types.Header{header1})) + require.Nil(t, test.builder.Connect([]*types.Header{header2})) + + test.testConnect([]*types.Header{header10}, header10, []*types.Header{root, header1, header10}) + test.testConnect([]*types.Header{header11}, header11, []*types.Header{root, header1, header11}) + test.testConnect([]*types.Header{header20}, header20, []*types.Header{root, header2, header20}) + test.testConnect([]*types.Header{header22}, header22, []*types.Header{root, header2, header22}) +} + +// connect 11 and 10 to 1, then 22 and 20 to 2 one by one starting from a [0-1, 0-2] tree +// then connect 100 to 10, and 200 to 20 +func TestCCBConnectAltForksAtLevel2Reverse(t *testing.T) { + test, root := newConnectCCBTest(t) + header1 := test.makeHeader(root, 1) + header10 := test.makeHeader(header1, 10) + header11 := test.makeHeader(header1, 11) + header2 := test.makeHeader(root, 2) + header20 := test.makeHeader(header2, 20) + header22 := test.makeHeader(header2, 22) + header100 := test.makeHeader(header10, 100) + header200 := test.makeHeader(header20, 200) + require.Nil(t, test.builder.Connect([]*types.Header{header1})) + require.Nil(t, test.builder.Connect([]*types.Header{header2})) + + test.testConnect([]*types.Header{header11}, header11, []*types.Header{root, header1, header11}) + test.testConnect([]*types.Header{header10}, header11, []*types.Header{root, header1, header11}) + test.testConnect([]*types.Header{header22}, header22, []*types.Header{root, header2, header22}) + test.testConnect([]*types.Header{header20}, header22, []*types.Header{root, header2, header22}) + + test.testConnect([]*types.Header{header100}, header100, []*types.Header{root, header1, header10, header100}) + test.testConnect([]*types.Header{header200}, header200, []*types.Header{root, header2, header20, header200}) } diff --git a/polygon/sync/difficulty.go b/polygon/sync/difficulty.go index 7a6895c50..c63205002 100644 --- a/polygon/sync/difficulty.go +++ b/polygon/sync/difficulty.go @@ -2,9 +2,10 @@ package sync import ( lru "github.com/hashicorp/golang-lru/arc/v2" - "github.com/ledgerwatch/erigon/eth/stagedsync" "github.com/ledgerwatch/log/v3" + "github.com/ledgerwatch/erigon/eth/stagedsync" + libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon/consensus/bor" "github.com/ledgerwatch/erigon/consensus/bor/borcfg" @@ -15,32 +16,56 @@ import ( type DifficultyCalculator interface { HeaderDifficulty(header *types.Header) (uint64, error) + SetSpan(span *heimdallspan.HeimdallSpan) } type difficultyCalculatorImpl struct { - borConfig *borcfg.BorConfig - span *heimdallspan.HeimdallSpan - signaturesCache *lru.ARCCache[libcommon.Hash, libcommon.Address] + borConfig *borcfg.BorConfig + span *heimdallspan.HeimdallSpan + validatorSetFactory func() validatorSetInterface + signaturesCache *lru.ARCCache[libcommon.Hash, libcommon.Address] log log.Logger } +// valset.ValidatorSet abstraction for unit tests +type validatorSetInterface interface { + IncrementProposerPriority(times int, logger log.Logger) + Difficulty(signer libcommon.Address) (uint64, error) +} + func NewDifficultyCalculator( borConfig *borcfg.BorConfig, span *heimdallspan.HeimdallSpan, + validatorSetFactory func() validatorSetInterface, log log.Logger, ) DifficultyCalculator { signaturesCache, err := lru.NewARC[libcommon.Hash, libcommon.Address](stagedsync.InMemorySignatures) if err != nil { panic(err) } - return &difficultyCalculatorImpl{ - borConfig: borConfig, - span: span, - signaturesCache: signaturesCache, + impl := difficultyCalculatorImpl{ + borConfig: borConfig, + span: span, + validatorSetFactory: validatorSetFactory, + signaturesCache: signaturesCache, log: log, } + + if validatorSetFactory == nil { + impl.validatorSetFactory = impl.makeValidatorSet + } + + return &impl +} + +func (impl *difficultyCalculatorImpl) makeValidatorSet() validatorSetInterface { + return valset.NewValidatorSet(impl.span.ValidatorSet.Validators, impl.log) +} + +func (impl *difficultyCalculatorImpl) SetSpan(span *heimdallspan.HeimdallSpan) { + impl.span = span } func (impl *difficultyCalculatorImpl) HeaderDifficulty(header *types.Header) (uint64, error) { @@ -48,12 +73,15 @@ func (impl *difficultyCalculatorImpl) HeaderDifficulty(header *types.Header) (ui if err != nil { return 0, err } + return impl.signerDifficulty(signer, header.Number.Uint64()) +} - validatorSet := valset.NewValidatorSet(impl.span.ValidatorSet.Validators, log.New()) +func (impl *difficultyCalculatorImpl) signerDifficulty(signer libcommon.Address, headerNum uint64) (uint64, error) { + validatorSet := impl.validatorSetFactory() - sprintCount := impl.borConfig.CalculateSprintNumber(header.Number.Uint64()) - if sprintCount > 0 { - validatorSet.IncrementProposerPriority(int(sprintCount), impl.log) + sprintNum := impl.borConfig.CalculateSprintNumber(headerNum) + if sprintNum > 0 { + validatorSet.IncrementProposerPriority(int(sprintNum), impl.log) } return validatorSet.Difficulty(signer) diff --git a/polygon/sync/difficulty_test.go b/polygon/sync/difficulty_test.go index 69d684149..77b0711c3 100644 --- a/polygon/sync/difficulty_test.go +++ b/polygon/sync/difficulty_test.go @@ -1,21 +1,126 @@ package sync import ( - "github.com/ledgerwatch/log/v3" "testing" + "github.com/ledgerwatch/log/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon/consensus/bor/borcfg" heimdallspan "github.com/ledgerwatch/erigon/consensus/bor/heimdall/span" "github.com/ledgerwatch/erigon/core/types" ) +type testValidatorSetInterface struct { + signers []libcommon.Address + sprintNum int +} + +func (v *testValidatorSetInterface) IncrementProposerPriority(times int, _ log.Logger) { + v.sprintNum = times +} + +func (v *testValidatorSetInterface) Difficulty(signer libcommon.Address) (uint64, error) { + var i int + for (i < len(v.signers)) && (v.signers[i] != signer) { + i++ + } + + sprintOffset := v.sprintNum % len(v.signers) + var delta int + if i >= sprintOffset { + delta = i - sprintOffset + } else { + delta = i + len(v.signers) - sprintOffset + } + + return uint64(len(v.signers) - delta), nil +} + +func TestSignerDifficulty(t *testing.T) { + borConfig := borcfg.BorConfig{ + Sprint: map[string]uint64{"0": 16}, + } + span := heimdallspan.HeimdallSpan{} + signers := []libcommon.Address{ + libcommon.HexToAddress("00"), + libcommon.HexToAddress("01"), + libcommon.HexToAddress("02"), + } + validatorSetFactory := func() validatorSetInterface { return &testValidatorSetInterface{signers: signers} } + logger := log.New() + calc := NewDifficultyCalculator(&borConfig, &span, validatorSetFactory, logger).(*difficultyCalculatorImpl) + + var d uint64 + + // sprint 0 + d, _ = calc.signerDifficulty(signers[0], 0) + assert.Equal(t, uint64(3), d) + + d, _ = calc.signerDifficulty(signers[0], 1) + assert.Equal(t, uint64(3), d) + + d, _ = calc.signerDifficulty(signers[0], 15) + assert.Equal(t, uint64(3), d) + + d, _ = calc.signerDifficulty(signers[1], 0) + assert.Equal(t, uint64(2), d) + + d, _ = calc.signerDifficulty(signers[1], 1) + assert.Equal(t, uint64(2), d) + + d, _ = calc.signerDifficulty(signers[1], 15) + assert.Equal(t, uint64(2), d) + + d, _ = calc.signerDifficulty(signers[2], 0) + assert.Equal(t, uint64(1), d) + + d, _ = calc.signerDifficulty(signers[2], 1) + assert.Equal(t, uint64(1), d) + + d, _ = calc.signerDifficulty(signers[2], 15) + assert.Equal(t, uint64(1), d) + + // sprint 1 + d, _ = calc.signerDifficulty(signers[1], 16) + assert.Equal(t, uint64(3), d) + + d, _ = calc.signerDifficulty(signers[2], 16) + assert.Equal(t, uint64(2), d) + + d, _ = calc.signerDifficulty(signers[0], 16) + assert.Equal(t, uint64(1), d) + + // sprint 2 + d, _ = calc.signerDifficulty(signers[2], 32) + assert.Equal(t, uint64(3), d) + + d, _ = calc.signerDifficulty(signers[0], 32) + assert.Equal(t, uint64(2), d) + + d, _ = calc.signerDifficulty(signers[1], 32) + assert.Equal(t, uint64(1), d) + + // sprint 3 + d, _ = calc.signerDifficulty(signers[0], 48) + assert.Equal(t, uint64(3), d) + + d, _ = calc.signerDifficulty(signers[1], 48) + assert.Equal(t, uint64(2), d) + + d, _ = calc.signerDifficulty(signers[2], 48) + assert.Equal(t, uint64(1), d) +} + func TestHeaderDifficultyNoSignature(t *testing.T) { borConfig := borcfg.BorConfig{} span := heimdallspan.HeimdallSpan{} logger := log.New() - calc := NewDifficultyCalculator(&borConfig, &span, logger) + calc := NewDifficultyCalculator(&borConfig, &span, nil, logger) + _, err := calc.HeaderDifficulty(new(types.Header)) require.ErrorContains(t, err, "signature suffix missing") }