diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go
index 34f1cc1f3..b955a7a29 100644
--- a/p2p/enode/nodedb.go
+++ b/p2p/enode/nodedb.go
@@ -29,14 +29,13 @@ import (
 	"time"
 
 	"github.com/c2h5oh/datasize"
+	mdbx1 "github.com/erigontech/mdbx-go/mdbx"
+	"github.com/ledgerwatch/log/v3"
+
 	libcommon "github.com/ledgerwatch/erigon-lib/common"
 	"github.com/ledgerwatch/erigon-lib/kv"
 	"github.com/ledgerwatch/erigon-lib/kv/mdbx"
-	"github.com/ledgerwatch/erigon/rlp"
-	"github.com/ledgerwatch/log/v3"
-
-	mdbx1 "github.com/erigontech/mdbx-go/mdbx"
 )
 
 // Keys in the node database.
@@ -75,16 +74,18 @@ var zeroIP = make(net.IP, 16)
 // DB is the node database, storing previously seen nodes and any collected metadata about
 // them for QoS purposes.
 type DB struct {
-	kv     kv.RwDB       // Interface to the database itself
-	runner sync.Once     // Ensures we can start at most one expirer
-	quit   chan struct{} // Channel to signal the expiring thread to stop
+	kv     kv.RwDB   // Interface to the database itself
+	runner sync.Once // Ensures we can start at most one expirer
+
+	ctx       context.Context
+	ctxCancel func()
 }
 
 // OpenDB opens a node database for storing and retrieving infos about known peers in the
 // network. If no path is given an in-memory, temporary database is constructed.
 func OpenDB(ctx context.Context, path string, tmpDir string, logger log.Logger) (*DB, error) {
 	if path == "" {
-		return newMemoryDB(logger, tmpDir)
+		return newMemoryDB(ctx, logger, tmpDir)
 	}
 	return newPersistentDB(ctx, logger, path)
 }
@@ -97,22 +98,27 @@ func bucketsConfig(_ kv.TableCfg) kv.TableCfg {
 }
 
 // newMemoryNodeDB creates a new in-memory node database without a persistent backend.
-func newMemoryDB(logger log.Logger, tmpDir string) (*DB, error) {
-	db := &DB{quit: make(chan struct{})}
-	var err error
-	db.kv, err = mdbx.NewMDBX(logger).InMem(tmpDir).Label(kv.SentryDB).WithTableCfg(bucketsConfig).MapSize(1 * datasize.GB).Open(context.Background())
+func newMemoryDB(ctx context.Context, logger log.Logger, tmpDir string) (*DB, error) {
+	db, err := mdbx.NewMDBX(logger).
+		InMem(tmpDir).
+		Label(kv.SentryDB).
+		WithTableCfg(bucketsConfig).
+		MapSize(1 * datasize.GB).
+		Open(ctx)
 	if err != nil {
 		return nil, err
 	}
-	return db, nil
+
+	nodeDB := &DB{kv: db}
+	nodeDB.ctx, nodeDB.ctxCancel = context.WithCancel(ctx)
+
+	return nodeDB, nil
 }
 
 // newPersistentNodeDB creates/opens a persistent node database,
 // also flushing its contents in case of a version mismatch.
 func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB, error) {
-	var db kv.RwDB
-	var err error
-	db, err = mdbx.NewMDBX(logger).
+	db, err := mdbx.NewMDBX(logger).
 		Path(path).
 		Label(kv.SentryDB).
 		WithTableCfg(bucketsConfig).
@@ -125,13 +131,14 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
 	if err != nil {
 		return nil, err
 	}
+
 	// The nodes contained in the cache correspond to a certain protocol version.
 	// Flush all nodes if the version doesn't match.
 	currentVer := make([]byte, binary.MaxVarintLen64)
 	currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
 
 	var blob []byte
-	if err := db.Update(context.Background(), func(tx kv.RwTx) error {
+	if err := db.Update(ctx, func(tx kv.RwTx) error {
 		c, err := tx.RwCursor(kv.Inodes)
 		if err != nil {
 			return err
@@ -150,6 +157,7 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
 	}); err != nil {
 		return nil, err
 	}
+
 	if blob != nil && !bytes.Equal(blob, currentVer) {
 		db.Close()
 		if err := os.RemoveAll(path); err != nil {
@@ -157,7 +165,11 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
 		}
 		return newPersistentDB(ctx, logger, path)
 	}
-	return &DB{kv: db, quit: make(chan struct{})}, nil
+
+	nodeDB := &DB{kv: db}
+	nodeDB.ctx, nodeDB.ctxCancel = context.WithCancel(ctx)
+
+	return nodeDB, nil
 }
 
 // nodeKey returns the database key for a node record.
@@ -227,7 +239,7 @@ func localItemKey(id ID, field string) []byte {
 // fetchInt64 retrieves an integer associated with a particular key.
 func (db *DB) fetchInt64(key []byte) int64 {
 	var val int64
-	if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
+	if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
 		blob, errGet := tx.GetOne(kv.Inodes, key)
 		if errGet != nil {
 			return errGet
@@ -249,7 +261,7 @@ func (db *DB) fetchInt64(key []byte) int64 {
 func (db *DB) storeInt64(key []byte, n int64) error {
 	blob := make([]byte, binary.MaxVarintLen64)
 	blob = blob[:binary.PutVarint(blob, n)]
-	return db.kv.Update(context.Background(), func(tx kv.RwTx) error {
+	return db.kv.Update(db.ctx, func(tx kv.RwTx) error {
 		return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob)
 	})
 }
@@ -257,7 +269,7 @@ func (db *DB) storeInt64(key []byte, n int64) error {
 // fetchUint64 retrieves an integer associated with a particular key.
 func (db *DB) fetchUint64(key []byte) uint64 {
 	var val uint64
-	if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
+	if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
 		blob, errGet := tx.GetOne(kv.Inodes, key)
 		if errGet != nil {
 			return errGet
@@ -276,7 +288,7 @@ func (db *DB) fetchUint64(key []byte) uint64 {
 func (db *DB) storeUint64(key []byte, n uint64) error {
 	blob := make([]byte, binary.MaxVarintLen64)
 	blob = blob[:binary.PutUvarint(blob, n)]
-	return db.kv.Update(context.Background(), func(tx kv.RwTx) error {
+	return db.kv.Update(db.ctx, func(tx kv.RwTx) error {
 		return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob)
 	})
 }
@@ -284,7 +296,7 @@ func (db *DB) storeUint64(key []byte, n uint64) error {
 // Node retrieves a node with a given id from the database.
 func (db *DB) Node(id ID) *Node {
 	var blob []byte
-	if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
+	if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
 		v, errGet := tx.GetOne(kv.NodeRecords, nodeKey(id))
 		if errGet != nil {
 			return errGet
@@ -322,7 +334,7 @@ func (db *DB) UpdateNode(node *Node) error {
 	if err != nil {
 		return err
 	}
-	if err := db.kv.Update(context.Background(), func(tx kv.RwTx) error {
+	if err := db.kv.Update(db.ctx, func(tx kv.RwTx) error {
 		return tx.Put(kv.NodeRecords, nodeKey(node.ID()), blob)
 	}); err != nil {
 		return err
@@ -346,11 +358,11 @@ func (db *DB) Resolve(n *Node) *Node {
 
 // DeleteNode deletes all information associated with a node.
 func (db *DB) DeleteNode(id ID) {
-	deleteRange(db.kv, nodeKey(id))
+	db.deleteRange(nodeKey(id))
 }
 
-func deleteRange(db kv.RwDB, prefix []byte) {
-	if err := db.Update(context.Background(), func(tx kv.RwTx) error {
+func (db *DB) deleteRange(prefix []byte) {
+	if err := db.kv.Update(db.ctx, func(tx kv.RwTx) error {
 		for bucket := range bucketsConfig(nil) {
 			if err := deleteRangeInBucket(tx, prefix, bucket); err != nil {
 				return err
@@ -398,7 +410,7 @@ func (db *DB) expirer() {
 		select {
 		case <-tick.C:
 			db.expireNodes()
-		case <-db.quit:
+		case <-db.ctx.Done():
 			return
 		}
 	}
@@ -412,7 +424,7 @@
 		youngestPong int64
 	)
 	var toDelete [][]byte
-	if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
+	if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
 		c, err := tx.Cursor(kv.Inodes)
 		if err != nil {
 			return err
@@ -455,7 +467,7 @@
 	}); err != nil {
 		log.Warn("nodeDB.expireNodes failed", "err", err)
 	}
 	for _, td := range toDelete {
-		deleteRange(db.kv, td)
+		db.deleteRange(td)
 	}
 }
@@ -545,7 +557,7 @@ func (db *DB) QuerySeeds(n int, maxAge time.Duration) []*Node {
 		id    ID
 	)
 
-	if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
+	if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
 		c, err := tx.Cursor(kv.NodeRecords)
 		if err != nil {
 			return err
@@ -603,14 +615,6 @@
 
 // close flushes and closes the database files.
 func (db *DB) Close() {
-	select {
-	case <-db.quit:
-		return // means closed already
-	default:
-	}
-	if db.quit == nil {
-		return
-	}
-	libcommon.SafeClose(db.quit)
+	db.ctxCancel()
 	db.kv.Close()
 }