diff --git a/eth/stagedsync/stage_headers.go b/eth/stagedsync/stage_headers.go index 66c0243fd..b21fd32ad 100644 --- a/eth/stagedsync/stage_headers.go +++ b/eth/stagedsync/stage_headers.go @@ -264,14 +264,14 @@ func HeadersUnwind(u *UnwindState, s *StageState, tx ethdb.RwTx, cfg HeadersCfg) if hash, err = rawdb.ReadCanonicalHash(tx, blockHeight); err != nil { return err } - cfg.hd.BadHeaders[hash] = struct{}{} + cfg.hd.ReportBadHeader(hash) } if err = rawdb.DeleteCanonicalHash(tx, blockHeight); err != nil { return err } } if u.BadBlock != (common.Hash{}) { - cfg.hd.BadHeaders[u.BadBlock] = struct{}{} + cfg.hd.ReportBadHeader(u.BadBlock) // Find header with biggest TD tdCursor, cErr := tx.Cursor(dbutils.HeaderTDBucket) if cErr != nil { @@ -292,7 +292,7 @@ func HeadersUnwind(u *UnwindState, s *StageState, tx ethdb.RwTx, cfg HeadersCfg) } var hash common.Hash copy(hash[:], k[8:]) - if _, bad := cfg.hd.BadHeaders[hash]; bad { + if cfg.hd.IsBadHeader(hash) { continue } var td big.Int diff --git a/turbo/stages/headerdownload/header_algos.go b/turbo/stages/headerdownload/header_algos.go index a0852ac2e..03d6fa23e 100644 --- a/turbo/stages/headerdownload/header_algos.go +++ b/turbo/stages/headerdownload/header_algos.go @@ -54,7 +54,7 @@ func (hd *HeaderDownload) SplitIntoSegments(headersRaw [][]byte, msg []*types.He dedupMap := make(map[common.Hash]struct{}) // Map used for detecting duplicate headers for i, header := range msg { headerHash := header.Hash() - if _, bad := hd.BadHeaders[headerHash]; bad { + if _, bad := hd.badHeaders[headerHash]; bad { return nil, BadBlockPenalty, nil } if _, duplicate := dedupMap[headerHash]; duplicate { @@ -100,12 +100,26 @@ func (hd *HeaderDownload) SingleHeaderAsSegment(headerRaw []byte, header *types. hd.lock.RLock() defer hd.lock.RUnlock() headerHash := header.Hash() - if _, bad := hd.BadHeaders[headerHash]; bad { + if _, bad := hd.badHeaders[headerHash]; bad { return nil, BadBlockPenalty, nil } return []*ChainSegment{{HeadersRaw: [][]byte{headerRaw}, Headers: []*types.Header{header}}}, NoPenalty, nil } +// ReportBadHeader - +func (hd *HeaderDownload) ReportBadHeader(headerHash common.Hash) { + hd.lock.Lock() + defer hd.lock.Unlock() + hd.badHeaders[headerHash] = struct{}{} +} + +func (hd *HeaderDownload) IsBadHeader(headerHash common.Hash) bool { + hd.lock.RLock() + defer hd.lock.RUnlock() + _, ok := hd.badHeaders[headerHash] + return ok +} + // FindAnchors attempts to find anchors to which given chain segment can be attached to func (hd *HeaderDownload) findAnchors(segment *ChainSegment) (found bool, start int) { // Walk the segment from children towards parents @@ -561,7 +575,7 @@ func (hd *HeaderDownload) InsertHeaders(hf func(header *types.Header, blockHeigh hd.insertList = hd.insertList[:len(hd.insertList)-1] skip := false if !link.preverified { - if _, bad := hd.BadHeaders[link.hash]; bad { + if _, bad := hd.badHeaders[link.hash]; bad { skip = true } else if err := hd.engine.VerifyHeader(hd.headerReader, link.header, true /* seal */); err != nil { log.Warn("Verification failed for header", "hash", link.header.Hash(), "height", link.blockHeight, "error", err) diff --git a/turbo/stages/headerdownload/header_data_struct.go b/turbo/stages/headerdownload/header_data_struct.go index 8b88b01c9..aa2d900af 100644 --- a/turbo/stages/headerdownload/header_data_struct.go +++ b/turbo/stages/headerdownload/header_data_struct.go @@ -167,7 +167,7 @@ type CalcDifficultyFunc func(childTimestamp uint64, parentTime uint64, parentDif type HeaderDownload struct { lock sync.RWMutex - BadHeaders map[common.Hash]struct{} + badHeaders map[common.Hash]struct{} anchors map[common.Hash]*Anchor // Mapping from parentHash to collection of anchors preverifiedHashes map[common.Hash]struct{} // Set of hashes that are known to belong to canonical chain preverifiedHeight uint64 // Block height corresponding to the last preverified hash @@ -203,7 +203,7 @@ func NewHeaderDownload( ) *HeaderDownload { persistentLinkLimit := linkLimit / 16 hd := &HeaderDownload{ - BadHeaders: make(map[common.Hash]struct{}), + badHeaders: make(map[common.Hash]struct{}), anchors: make(map[common.Hash]*Anchor), persistedLinkLimit: persistentLinkLimit, linkLimit: linkLimit - persistentLinkLimit, diff --git a/turbo/stages/headerdownload/header_test.go b/turbo/stages/headerdownload/header_test.go index fbeaed2ec..f78070c6f 100644 --- a/turbo/stages/headerdownload/header_test.go +++ b/turbo/stages/headerdownload/header_test.go @@ -51,7 +51,7 @@ func TestSplitIntoSegments(t *testing.T) { } // Single header with a bad hash - hd.BadHeaders[h.Hash()] = struct{}{} + hd.ReportBadHeader(h.Hash()) if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}}, []*types.Header{&h}); err == nil { if penalty != BadBlockPenalty { t.Errorf("expected BadBlock penalty, got %s", penalty) @@ -166,7 +166,7 @@ func TestSingleHeaderAsSegment(t *testing.T) { } // Same header with a bad hash - hd.BadHeaders[h.Hash()] = struct{}{} + hd.ReportBadHeader(h.Hash()) if chainSegments, penalty, err := hd.SingleHeaderAsSegment([]byte{}, &h); err == nil { if penalty != BadBlockPenalty { t.Errorf("expected BadBlock penalty, got %s", penalty) diff --git a/turbo/stages/mock_sentry.go b/turbo/stages/mock_sentry.go index 2f53b6a43..c3a43f9d3 100644 --- a/turbo/stages/mock_sentry.go +++ b/turbo/stages/mock_sentry.go @@ -384,7 +384,7 @@ func (ms *MockSentry) InsertChain(chain *core.ChainPack) error { }); err != nil { return err } - if _, bad := ms.downloader.Hd.BadHeaders[chain.TopBlock.Hash()]; bad { + if ms.downloader.Hd.IsBadHeader(chain.TopBlock.Hash()) { return fmt.Errorf("block %d %x was invalid", chain.TopBlock.NumberU64(), chain.TopBlock.Hash()) } return nil