p2p: fix panic in enode DB Close on shutdown (#9237) (#9240)

If any DB method is called while Close() is blocked in db.kv.Close()
(which waits for ongoing method calls/transactions to finish), the
panic "WaitGroup is reused before previous Wait has returned" may
occur.

Use context cancellation to ensure that new method calls immediately
return during db.kv.Close().
This commit is contained in:
battlmonstr 2024-01-16 09:34:31 +01:00 committed by GitHub
parent 2793ef6ec1
commit e979d79c08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,14 +29,13 @@ import (
"time" "time"
"github.com/c2h5oh/datasize" "github.com/c2h5oh/datasize"
mdbx1 "github.com/erigontech/mdbx-go/mdbx"
"github.com/ledgerwatch/log/v3"
libcommon "github.com/ledgerwatch/erigon-lib/common" libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/mdbx" "github.com/ledgerwatch/erigon-lib/kv/mdbx"
"github.com/ledgerwatch/erigon/rlp" "github.com/ledgerwatch/erigon/rlp"
"github.com/ledgerwatch/log/v3"
mdbx1 "github.com/erigontech/mdbx-go/mdbx"
) )
// Keys in the node database. // Keys in the node database.
@ -77,14 +76,16 @@ var zeroIP = make(net.IP, 16)
type DB struct { type DB struct {
kv kv.RwDB // Interface to the database itself kv kv.RwDB // Interface to the database itself
runner sync.Once // Ensures we can start at most one expirer runner sync.Once // Ensures we can start at most one expirer
quit chan struct{} // Channel to signal the expiring thread to stop
ctx context.Context
ctxCancel func()
} }
// OpenDB opens a node database for storing and retrieving infos about known peers in the // OpenDB opens a node database for storing and retrieving infos about known peers in the
// network. If no path is given an in-memory, temporary database is constructed. // network. If no path is given an in-memory, temporary database is constructed.
func OpenDB(ctx context.Context, path string, tmpDir string, logger log.Logger) (*DB, error) { func OpenDB(ctx context.Context, path string, tmpDir string, logger log.Logger) (*DB, error) {
if path == "" { if path == "" {
return newMemoryDB(logger, tmpDir) return newMemoryDB(ctx, logger, tmpDir)
} }
return newPersistentDB(ctx, logger, path) return newPersistentDB(ctx, logger, path)
} }
@ -97,22 +98,27 @@ func bucketsConfig(_ kv.TableCfg) kv.TableCfg {
} }
// newMemoryNodeDB creates a new in-memory node database without a persistent backend. // newMemoryNodeDB creates a new in-memory node database without a persistent backend.
func newMemoryDB(logger log.Logger, tmpDir string) (*DB, error) { func newMemoryDB(ctx context.Context, logger log.Logger, tmpDir string) (*DB, error) {
db := &DB{quit: make(chan struct{})} db, err := mdbx.NewMDBX(logger).
var err error InMem(tmpDir).
db.kv, err = mdbx.NewMDBX(logger).InMem(tmpDir).Label(kv.SentryDB).WithTableCfg(bucketsConfig).MapSize(1 * datasize.GB).Open(context.Background()) Label(kv.SentryDB).
WithTableCfg(bucketsConfig).
MapSize(1 * datasize.GB).
Open(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return db, nil
nodeDB := &DB{kv: db}
nodeDB.ctx, nodeDB.ctxCancel = context.WithCancel(ctx)
return nodeDB, nil
} }
// newPersistentNodeDB creates/opens a persistent node database, // newPersistentNodeDB creates/opens a persistent node database,
// also flushing its contents in case of a version mismatch. // also flushing its contents in case of a version mismatch.
func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB, error) { func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB, error) {
var db kv.RwDB db, err := mdbx.NewMDBX(logger).
var err error
db, err = mdbx.NewMDBX(logger).
Path(path). Path(path).
Label(kv.SentryDB). Label(kv.SentryDB).
WithTableCfg(bucketsConfig). WithTableCfg(bucketsConfig).
@ -125,13 +131,14 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
if err != nil { if err != nil {
return nil, err return nil, err
} }
// The nodes contained in the cache correspond to a certain protocol version. // The nodes contained in the cache correspond to a certain protocol version.
// Flush all nodes if the version doesn't match. // Flush all nodes if the version doesn't match.
currentVer := make([]byte, binary.MaxVarintLen64) currentVer := make([]byte, binary.MaxVarintLen64)
currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))] currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
var blob []byte var blob []byte
if err := db.Update(context.Background(), func(tx kv.RwTx) error { if err := db.Update(ctx, func(tx kv.RwTx) error {
c, err := tx.RwCursor(kv.Inodes) c, err := tx.RwCursor(kv.Inodes)
if err != nil { if err != nil {
return err return err
@ -150,6 +157,7 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
if blob != nil && !bytes.Equal(blob, currentVer) { if blob != nil && !bytes.Equal(blob, currentVer) {
db.Close() db.Close()
if err := os.RemoveAll(path); err != nil { if err := os.RemoveAll(path); err != nil {
@ -157,7 +165,11 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
} }
return newPersistentDB(ctx, logger, path) return newPersistentDB(ctx, logger, path)
} }
return &DB{kv: db, quit: make(chan struct{})}, nil
nodeDB := &DB{kv: db}
nodeDB.ctx, nodeDB.ctxCancel = context.WithCancel(ctx)
return nodeDB, nil
} }
// nodeKey returns the database key for a node record. // nodeKey returns the database key for a node record.
@ -227,7 +239,7 @@ func localItemKey(id ID, field string) []byte {
// fetchInt64 retrieves an integer associated with a particular key. // fetchInt64 retrieves an integer associated with a particular key.
func (db *DB) fetchInt64(key []byte) int64 { func (db *DB) fetchInt64(key []byte) int64 {
var val int64 var val int64
if err := db.kv.View(context.Background(), func(tx kv.Tx) error { if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
blob, errGet := tx.GetOne(kv.Inodes, key) blob, errGet := tx.GetOne(kv.Inodes, key)
if errGet != nil { if errGet != nil {
return errGet return errGet
@ -249,7 +261,7 @@ func (db *DB) fetchInt64(key []byte) int64 {
func (db *DB) storeInt64(key []byte, n int64) error { func (db *DB) storeInt64(key []byte, n int64) error {
blob := make([]byte, binary.MaxVarintLen64) blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutVarint(blob, n)] blob = blob[:binary.PutVarint(blob, n)]
return db.kv.Update(context.Background(), func(tx kv.RwTx) error { return db.kv.Update(db.ctx, func(tx kv.RwTx) error {
return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob) return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob)
}) })
} }
@ -257,7 +269,7 @@ func (db *DB) storeInt64(key []byte, n int64) error {
// fetchUint64 retrieves an integer associated with a particular key. // fetchUint64 retrieves an integer associated with a particular key.
func (db *DB) fetchUint64(key []byte) uint64 { func (db *DB) fetchUint64(key []byte) uint64 {
var val uint64 var val uint64
if err := db.kv.View(context.Background(), func(tx kv.Tx) error { if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
blob, errGet := tx.GetOne(kv.Inodes, key) blob, errGet := tx.GetOne(kv.Inodes, key)
if errGet != nil { if errGet != nil {
return errGet return errGet
@ -276,7 +288,7 @@ func (db *DB) fetchUint64(key []byte) uint64 {
func (db *DB) storeUint64(key []byte, n uint64) error { func (db *DB) storeUint64(key []byte, n uint64) error {
blob := make([]byte, binary.MaxVarintLen64) blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutUvarint(blob, n)] blob = blob[:binary.PutUvarint(blob, n)]
return db.kv.Update(context.Background(), func(tx kv.RwTx) error { return db.kv.Update(db.ctx, func(tx kv.RwTx) error {
return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob) return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob)
}) })
} }
@ -284,7 +296,7 @@ func (db *DB) storeUint64(key []byte, n uint64) error {
// Node retrieves a node with a given id from the database. // Node retrieves a node with a given id from the database.
func (db *DB) Node(id ID) *Node { func (db *DB) Node(id ID) *Node {
var blob []byte var blob []byte
if err := db.kv.View(context.Background(), func(tx kv.Tx) error { if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
v, errGet := tx.GetOne(kv.NodeRecords, nodeKey(id)) v, errGet := tx.GetOne(kv.NodeRecords, nodeKey(id))
if errGet != nil { if errGet != nil {
return errGet return errGet
@ -322,7 +334,7 @@ func (db *DB) UpdateNode(node *Node) error {
if err != nil { if err != nil {
return err return err
} }
if err := db.kv.Update(context.Background(), func(tx kv.RwTx) error { if err := db.kv.Update(db.ctx, func(tx kv.RwTx) error {
return tx.Put(kv.NodeRecords, nodeKey(node.ID()), blob) return tx.Put(kv.NodeRecords, nodeKey(node.ID()), blob)
}); err != nil { }); err != nil {
return err return err
@ -346,11 +358,11 @@ func (db *DB) Resolve(n *Node) *Node {
// DeleteNode deletes all information associated with a node. // DeleteNode deletes all information associated with a node.
func (db *DB) DeleteNode(id ID) { func (db *DB) DeleteNode(id ID) {
deleteRange(db.kv, nodeKey(id)) db.deleteRange(nodeKey(id))
} }
func deleteRange(db kv.RwDB, prefix []byte) { func (db *DB) deleteRange(prefix []byte) {
if err := db.Update(context.Background(), func(tx kv.RwTx) error { if err := db.kv.Update(db.ctx, func(tx kv.RwTx) error {
for bucket := range bucketsConfig(nil) { for bucket := range bucketsConfig(nil) {
if err := deleteRangeInBucket(tx, prefix, bucket); err != nil { if err := deleteRangeInBucket(tx, prefix, bucket); err != nil {
return err return err
@ -398,7 +410,7 @@ func (db *DB) expirer() {
select { select {
case <-tick.C: case <-tick.C:
db.expireNodes() db.expireNodes()
case <-db.quit: case <-db.ctx.Done():
return return
} }
} }
@ -412,7 +424,7 @@ func (db *DB) expireNodes() {
youngestPong int64 youngestPong int64
) )
var toDelete [][]byte var toDelete [][]byte
if err := db.kv.View(context.Background(), func(tx kv.Tx) error { if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
c, err := tx.Cursor(kv.Inodes) c, err := tx.Cursor(kv.Inodes)
if err != nil { if err != nil {
return err return err
@ -455,7 +467,7 @@ func (db *DB) expireNodes() {
log.Warn("nodeDB.expireNodes failed", "err", err) log.Warn("nodeDB.expireNodes failed", "err", err)
} }
for _, td := range toDelete { for _, td := range toDelete {
deleteRange(db.kv, td) db.deleteRange(td)
} }
} }
@ -545,7 +557,7 @@ func (db *DB) QuerySeeds(n int, maxAge time.Duration) []*Node {
id ID id ID
) )
if err := db.kv.View(context.Background(), func(tx kv.Tx) error { if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
c, err := tx.Cursor(kv.NodeRecords) c, err := tx.Cursor(kv.NodeRecords)
if err != nil { if err != nil {
return err return err
@ -603,14 +615,6 @@ func (db *DB) QuerySeeds(n int, maxAge time.Duration) []*Node {
// close flushes and closes the database files. // close flushes and closes the database files.
func (db *DB) Close() { func (db *DB) Close() {
select { db.ctxCancel()
case <-db.quit:
return // means closed already
default:
}
if db.quit == nil {
return
}
libcommon.SafeClose(db.quit)
db.kv.Close() db.kv.Close()
} }