diff --git a/eth/stagedsync/stage_hashcheck.go b/eth/stagedsync/stage_hashcheck.go
index 66fdec544..3477c95b2 100644
--- a/eth/stagedsync/stage_hashcheck.go
+++ b/eth/stagedsync/stage_hashcheck.go
@@ -73,7 +73,7 @@ func SpawnCheckFinalHashStage(s *StageState, stateDB ethdb.Database, datadir str
 	return s.DoneAndUpdate(stateDB, blockNr)
 }
 
-func unwindHashCheckStage(unwindPoint uint64, stateDB ethdb.Database) error {
+func unwindHashCheckStage(unwindPoint uint64, stateDB ethdb.Database, datadir string, quit chan struct{}) error {
 	// Currently it does not require unwinding because it does not create any Intermediate Hash records
 	// and recomputes the state root from scratch
 	lastProcessedBlockNumber, err := stages.GetStageProgress(stateDB, stages.HashCheck)
@@ -87,14 +87,16 @@ func unwindHashCheckStage(unwindPoint uint64, stateDB ethdb.Database) error {
 		}
 		return nil
 	}
-	mutation := stateDB.NewBatch()
-	err = stages.SaveStageUnwind(mutation, stages.HashCheck, 0)
-	if err != nil {
-		return fmt.Errorf("unwind HashCheck: reset: %v", err)
+	prom := NewPromoter(stateDB, quit)
+	prom.TempDir = datadir
+	if err = prom.Unwind(lastProcessedBlockNumber, unwindPoint, dbutils.PlainAccountChangeSetBucket); err != nil {
+		return err
 	}
-	_, err = mutation.Commit()
-	if err != nil {
-		return fmt.Errorf("unwind HashCheck: failed to write db commit: %v", err)
+	if err = prom.Unwind(lastProcessedBlockNumber, unwindPoint, dbutils.PlainStorageChangeSetBucket); err != nil {
+		return err
+	}
+	if err = stages.SaveStageUnwind(stateDB, stages.HashCheck, 0); err != nil {
+		return fmt.Errorf("unwind HashCheck: reset: %v", err)
 	}
 	return nil
 }
@@ -291,6 +293,42 @@ func (p *Promoter) writeBufferMapToTempFile(pattern string, bufferMap map[string
 	return filename, nil
 }
 
+func (p *Promoter) writeUnwindBufferMapToTempFile(pattern string, bufferMap map[string][]byte) (string, error) {
+	var filename string
+	keys := make([]string, len(bufferMap))
+	i := 0
+	for key := range bufferMap {
+		keys[i] = key
+		i++
+	}
+	sort.Strings(keys)
+	var w *bufio.Writer
+	if bufferFile, err := ioutil.TempFile(p.TempDir, pattern); err == nil {
+		//nolint:errcheck
+		defer bufferFile.Close()
+		filename = bufferFile.Name()
+		w = bufio.NewWriter(bufferFile)
+	} else {
+		return filename, fmt.Errorf("creating temp buf file %s: %v", pattern, err)
+	}
+	for _, key := range keys {
+		if _, err := w.Write([]byte(key)); err != nil {
+			return filename, err
+		}
+		value := bufferMap[key]
+		if err := w.WriteByte(byte(len(value))); err != nil {
+			return filename, err
+		}
+		if _, err := w.Write(value); err != nil {
+			return filename, err
+		}
+	}
+	if err := w.Flush(); err != nil {
+		return filename, fmt.Errorf("flushing file %s: %v", filename, err)
+	}
+	return filename, nil
+}
+
 func (p *Promoter) mergeFilesAndCollect(bufferFileNames []string, keyLength int, collector *etl.Collector) error {
 	h := &etl.Heap{}
 	heap.Init(h)
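The point of the new `writeUnwindBufferMapToTempFile` is that unwind must carry each key's old value through the temp files, not just the key: every record is a fixed-size key, a single length byte, and the value. Below is a minimal, self-contained sketch of that record layout and its reader — it is not part of the patch, `keyLen` and the helper names are invented for illustration (the real code uses `v.KeySize` and inlines this logic), and it leans on the same assumption the patch makes: a value never exceeds 255 bytes, since its length must fit in one byte.

```go
// record_format_sketch.go — illustrative only, not part of the patch.
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
)

const keyLen = 4 // stands in for v.KeySize (hashed account/storage key sizes in practice)

// writeRecord emits: key bytes | 1-byte value length | value bytes.
// The one-byte length caps values at 255 bytes, which changeset values respect.
func writeRecord(w *bufio.Writer, key, value []byte) error {
	if _, err := w.Write(key); err != nil {
		return err
	}
	if err := w.WriteByte(byte(len(value))); err != nil {
		return err
	}
	_, err := w.Write(value)
	return err
}

// readRecord returns io.EOF only on a clean end of file.
func readRecord(r *bufio.Reader) (key, value []byte, err error) {
	key = make([]byte, keyLen)
	if _, err = io.ReadFull(r, key); err != nil {
		return nil, nil, err
	}
	l, err := r.ReadByte()
	if err != nil {
		return nil, nil, err
	}
	if l > 0 {
		value = make([]byte, l)
		if _, err = io.ReadFull(r, value); err != nil {
			return nil, nil, err
		}
	}
	return key, value, nil
}

func main() {
	var buf bytes.Buffer
	w := bufio.NewWriter(&buf)
	_ = writeRecord(w, []byte("key1"), []byte("old"))
	_ = writeRecord(w, []byte("key2"), nil) // empty value: key had no previous value
	_ = w.Flush()

	r := bufio.NewReader(&buf)
	for {
		k, v, err := readRecord(r)
		if err != nil {
			break // io.EOF ends the stream
		}
		fmt.Printf("%s => %q\n", k, v)
	}
}
```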
return fmt.Errorf("init reading from account buffer file: %d %x %v", n, keyBuf[:n], err) + } + var l [1]byte + if n, err := io.ReadFull(readers[i], l[:]); err != nil || n != 1 { + return fmt.Errorf("init reading from account buffer file: %d %v", n, err) + } + var valBuf []byte + valLength := int(l[0]) + if valLength > 0 { + valBuf = make([]byte, valLength) + if n, err := io.ReadFull(readers[i], valBuf); err != nil || n != valLength { + return fmt.Errorf("init reading from account buffer file: %d %v", n, err) + } + } + heap.Push(h, etl.HeapElem{keyBuf, i, valBuf}) + } + // By now, the heap has one element for each buffer file + var prevKey []byte + for h.Len() > 0 { + if err := common.Stopped(p.quitCh); err != nil { + return err + } + element := (heap.Pop(h)).(etl.HeapElem) + if !bytes.Equal(element.Key, prevKey) { + // Ignore all the repeating keys, and take the earlist + prevKey = common.CopyBytes(element.Key) + if err := collector.Collect(element.Key, element.Value); err != nil { + return err + } + } + reader := readers[element.TimeIdx] + // Try to read the next key (reuse the element) + if n, err := io.ReadFull(reader, element.Key); err == nil && n == keyLength { + var l [1]byte + if n1, err1 := io.ReadFull(reader, l[:]); err1 != nil || n1 != 1 { + return fmt.Errorf("reading from account buffer file: %d %v", n1, err1) + } + var valBuf []byte + valLength := int(l[0]) + if valLength > 0 { + valBuf = make([]byte, valLength) + if n1, err1 := io.ReadFull(reader, valBuf); err1 != nil || n1 != valLength { + return fmt.Errorf("reading from account buffer file: %d %v", n1, err1) + } + } + element.Value = valBuf + heap.Push(h, element) + } else if err != io.EOF { + // If it is EOF, we simply do not return anything into the heap + return fmt.Errorf("next reading from account buffer file: %d %x %v", n, element.Key[:n], err) + } + } + return nil +} + func (p *Promoter) Promote(from, to uint64, changeSetBucket []byte) error { v, ok := promoterMapper[string(changeSetBucket)] if !ok { @@ -402,6 +510,71 @@ func (p *Promoter) Promote(from, to uint64, changeSetBucket []byte) error { return nil } +func (p *Promoter) Unwind(from, to uint64, changeSetBucket []byte) error { + v, ok := promoterMapper[string(changeSetBucket)] + if !ok { + return fmt.Errorf("unknown bucket type: %s", changeSetBucket) + } + log.Info("Unwinding started", "from", from, "to", to, "csbucket", string(changeSetBucket)) + var m runtime.MemStats + var bufferFileNames []string + changesets := make([]byte, p.ChangeSetBufSize) // 256 Mb buffer by default + var offsets []int + var done = false + blockNum := to + 1 + for !done { + if newDone, newBlockNum, newOffsets, err := p.fillChangeSetBuffer(changeSetBucket, blockNum, from, changesets, offsets); err == nil { + done = newDone + blockNum = newBlockNum + offsets = newOffsets + } else { + return err + } + if len(offsets) == 0 { + break + } + + bufferMap := make(map[string][]byte) + prevOffset := 0 + for _, offset := range offsets { + if err := v.WalkerAdapter(changesets[prevOffset:offset]).Walk(func(k, v []byte) error { + ks := string(k) + if _, ok := bufferMap[ks]; !ok { + // Do not replace the existing values, so we end up with the earlier possible values + bufferMap[ks] = v + } + return nil + }); err != nil { + return err + } + prevOffset = offset + } + + if filename, err := p.writeUnwindBufferMapToTempFile(v.Template, bufferMap); err == nil { + defer func() { + //nolint:errcheck + os.Remove(filename) + }() + bufferFileNames = append(bufferFileNames, filename) + runtime.ReadMemStats(&m) + 
log.Info("Created a buffer file", "name", filename, "up to block", blockNum, + "alloc", int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC)) + } else { + return err + } + } + if len(offsets) > 0 { + collector := etl.NewCollector(p.TempDir) + if err := p.mergeUnwindFilesAndCollect(bufferFileNames, v.KeySize, collector); err != nil { + return err + } + if err := collector.Load(p.db, dbutils.CurrentStateBucket, keyTransformLoadFunc, etl.TransformArgs{Quit: p.quitCh}); err != nil { + return err + } + } + return nil +} + func promoteHashedStateIncrementally(from, to uint64, db ethdb.Database, datadir string, quit chan struct{}) error { prom := NewPromoter(db, quit) prom.TempDir = datadir diff --git a/eth/stagedsync/stage_hashcheck_test.go b/eth/stagedsync/stage_hashcheck_test.go index cd36ad444..de2cd1be9 100644 --- a/eth/stagedsync/stage_hashcheck_test.go +++ b/eth/stagedsync/stage_hashcheck_test.go @@ -91,3 +91,21 @@ func TestPromoteHashedStateIncrementalMixed(t *testing.T) { compareCurrentState(t, db1, db2, dbutils.CurrentStateBucket) } + +func TestUnwindHashed(t *testing.T) { + db1 := ethdb.NewMemDatabase() + db2 := ethdb.NewMemDatabase() + + generateBlocks(t, 1, 50, hashedWriterGen(db1), changeCodeWithIncarnations) + generateBlocks(t, 1, 50, plainWriterGen(db2), changeCodeWithIncarnations) + + err := promoteHashedState(db2, 0, 100, getDataDir(), nil) + if err != nil { + t.Errorf("error while promoting state: %v", err) + } + err = unwindHashCheckStage(50, db2, getDataDir(), nil) + if err != nil { + t.Errorf("error while unwind state: %v", err) + } + compareCurrentState(t, db1, db2, dbutils.CurrentStateBucket) +} diff --git a/eth/stagedsync/stage_headers.go b/eth/stagedsync/stage_headers.go index 0f7294677..c23fb37c3 100644 --- a/eth/stagedsync/stage_headers.go +++ b/eth/stagedsync/stage_headers.go @@ -9,7 +9,7 @@ import ( "github.com/ledgerwatch/turbo-geth/log" ) -func DownloadHeaders(s *StageState, d DownloaderGlue, stateDB ethdb.Database, headersFetchers []func() error, quitCh chan struct{}) error { +func DownloadHeaders(s *StageState, d DownloaderGlue, stateDB ethdb.Database, headersFetchers []func() error, datadir string, quitCh chan struct{}) error { err := d.SpawnSync(headersFetchers) if err != nil { return err @@ -35,7 +35,7 @@ func DownloadHeaders(s *StageState, d DownloaderGlue, stateDB ethdb.Database, he case stages.Execution: err = unwindExecutionStage(unwindPoint, stateDB) case stages.HashCheck: - err = unwindHashCheckStage(unwindPoint, stateDB) + err = unwindHashCheckStage(unwindPoint, stateDB, datadir, quitCh) case stages.AccountHistoryIndex: err = unwindAccountHistoryIndex(unwindPoint, stateDB, core.UsePlainStateExecution, quitCh) case stages.StorageHistoryIndex: diff --git a/eth/stagedsync/stagedsync.go b/eth/stagedsync/stagedsync.go index 77495e496..cda3600a2 100644 --- a/eth/stagedsync/stagedsync.go +++ b/eth/stagedsync/stagedsync.go @@ -26,7 +26,7 @@ func DoStagedSyncWithFetchers( ID: stages.Headers, Description: "Downloading headers", ExecFunc: func(s *StageState) error { - return DownloadHeaders(s, d, stateDB, headersFetchers, quitCh) + return DownloadHeaders(s, d, stateDB, headersFetchers, datadir, quitCh) }, }, {