p2p: panic in enode DB Close on shutdown (#9237) (#9240)

If any DB method is called while Close() is blocked in db.kv.Close()
(which waits for ongoing method calls/transactions to finish), a
"WaitGroup is reused before previous Wait has returned" panic might
occur.

Use context cancellation to ensure that new method calls immediately
return during db.kv.Close().
This commit is contained in:
battlmonstr 2024-01-16 09:34:31 +01:00 committed by GitHub
parent 2793ef6ec1
commit e979d79c08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,14 +29,13 @@ import (
"time"
"github.com/c2h5oh/datasize"
mdbx1 "github.com/erigontech/mdbx-go/mdbx"
"github.com/ledgerwatch/log/v3"
libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/mdbx"
"github.com/ledgerwatch/erigon/rlp"
"github.com/ledgerwatch/log/v3"
mdbx1 "github.com/erigontech/mdbx-go/mdbx"
)
// Keys in the node database.
@ -75,16 +74,18 @@ var zeroIP = make(net.IP, 16)
// DB is the node database, storing previously seen nodes and any collected metadata about
// them for QoS purposes.
type DB struct {
kv kv.RwDB // Interface to the database itself
runner sync.Once // Ensures we can start at most one expirer
quit chan struct{} // Channel to signal the expiring thread to stop
kv kv.RwDB // Interface to the database itself
runner sync.Once // Ensures we can start at most one expirer
ctx context.Context // Derived via context.WithCancel in the constructors; passed to the kv View/Update calls and watched by the expirer loop
ctxCancel func() // Cancels ctx; invoked by Close before the underlying kv is closed
}
// OpenDB opens a node database for storing and retrieving information about known peers
// in the network. If no path is given an in-memory, temporary database is constructed.
//
// ctx is forwarded to the constructor that is selected. tmpDir is only used by the
// in-memory variant (empty path); a non-empty path selects the persistent database
// at that location.
func OpenDB(ctx context.Context, path string, tmpDir string, logger log.Logger) (*DB, error) {
if path == "" {
return newMemoryDB(logger, tmpDir)
return newMemoryDB(ctx, logger, tmpDir)
}
return newPersistentDB(ctx, logger, path)
}
@ -97,22 +98,27 @@ func bucketsConfig(_ kv.TableCfg) kv.TableCfg {
}
// newMemoryDB creates a new in-memory node database without a persistent backend.
// The underlying MDBX store is opened with InMem(tmpDir), labelled kv.SentryDB,
// configured with bucketsConfig and given a 1 GB map size. Any error from Open is
// returned unchanged together with a nil *DB.
func newMemoryDB(logger log.Logger, tmpDir string) (*DB, error) {
db := &DB{quit: make(chan struct{})}
var err error
db.kv, err = mdbx.NewMDBX(logger).InMem(tmpDir).Label(kv.SentryDB).WithTableCfg(bucketsConfig).MapSize(1 * datasize.GB).Open(context.Background())
func newMemoryDB(ctx context.Context, logger log.Logger, tmpDir string) (*DB, error) {
db, err := mdbx.NewMDBX(logger).
InMem(tmpDir).
Label(kv.SentryDB).
WithTableCfg(bucketsConfig).
MapSize(1 * datasize.GB).
Open(ctx)
if err != nil {
return nil, err
}
return db, nil
nodeDB := &DB{kv: db}
// Derive a cancellable child of the caller's context so that Close can cancel
// it (via ctxCancel) independently of the parent.
nodeDB.ctx, nodeDB.ctxCancel = context.WithCancel(ctx)
return nodeDB, nil
}
// newPersistentDB creates/opens a persistent node database,
// also flushing its contents in case of a version mismatch.
func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB, error) {
var db kv.RwDB
var err error
db, err = mdbx.NewMDBX(logger).
db, err := mdbx.NewMDBX(logger).
Path(path).
Label(kv.SentryDB).
WithTableCfg(bucketsConfig).
@ -125,13 +131,14 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
if err != nil {
return nil, err
}
// The nodes contained in the cache correspond to a certain protocol version.
// Flush all nodes if the version doesn't match.
currentVer := make([]byte, binary.MaxVarintLen64)
currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
var blob []byte
if err := db.Update(context.Background(), func(tx kv.RwTx) error {
if err := db.Update(ctx, func(tx kv.RwTx) error {
c, err := tx.RwCursor(kv.Inodes)
if err != nil {
return err
@ -150,6 +157,7 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
}); err != nil {
return nil, err
}
if blob != nil && !bytes.Equal(blob, currentVer) {
db.Close()
if err := os.RemoveAll(path); err != nil {
@ -157,7 +165,11 @@ func newPersistentDB(ctx context.Context, logger log.Logger, path string) (*DB,
}
return newPersistentDB(ctx, logger, path)
}
return &DB{kv: db, quit: make(chan struct{})}, nil
nodeDB := &DB{kv: db}
nodeDB.ctx, nodeDB.ctxCancel = context.WithCancel(ctx)
return nodeDB, nil
}
// nodeKey returns the database key for a node record.
@ -227,7 +239,7 @@ func localItemKey(id ID, field string) []byte {
// fetchInt64 retrieves an integer associated with a particular key.
func (db *DB) fetchInt64(key []byte) int64 {
var val int64
if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
blob, errGet := tx.GetOne(kv.Inodes, key)
if errGet != nil {
return errGet
@ -249,7 +261,7 @@ func (db *DB) fetchInt64(key []byte) int64 {
// storeInt64 encodes n as a signed varint and writes it under key in the
// kv.Inodes table. It returns whatever error db.kv.Update reports.
func (db *DB) storeInt64(key []byte, n int64) error {
// Allocate the maximum varint size, then trim to the bytes actually written.
blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutVarint(blob, n)]
return db.kv.Update(context.Background(), func(tx kv.RwTx) error {
return db.kv.Update(db.ctx, func(tx kv.RwTx) error {
return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob)
})
}
@ -257,7 +269,7 @@ func (db *DB) storeInt64(key []byte, n int64) error {
// fetchUint64 retrieves an integer associated with a particular key.
func (db *DB) fetchUint64(key []byte) uint64 {
var val uint64
if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
blob, errGet := tx.GetOne(kv.Inodes, key)
if errGet != nil {
return errGet
@ -276,7 +288,7 @@ func (db *DB) fetchUint64(key []byte) uint64 {
// storeUint64 encodes n as an unsigned varint and writes it under key in the
// kv.Inodes table. It returns whatever error db.kv.Update reports.
func (db *DB) storeUint64(key []byte, n uint64) error {
// Allocate the maximum varint size, then trim to the bytes actually written.
blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutUvarint(blob, n)]
return db.kv.Update(context.Background(), func(tx kv.RwTx) error {
return db.kv.Update(db.ctx, func(tx kv.RwTx) error {
return tx.Put(kv.Inodes, libcommon.CopyBytes(key), blob)
})
}
@ -284,7 +296,7 @@ func (db *DB) storeUint64(key []byte, n uint64) error {
// Node retrieves a node with a given id from the database.
func (db *DB) Node(id ID) *Node {
var blob []byte
if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
v, errGet := tx.GetOne(kv.NodeRecords, nodeKey(id))
if errGet != nil {
return errGet
@ -322,7 +334,7 @@ func (db *DB) UpdateNode(node *Node) error {
if err != nil {
return err
}
if err := db.kv.Update(context.Background(), func(tx kv.RwTx) error {
if err := db.kv.Update(db.ctx, func(tx kv.RwTx) error {
return tx.Put(kv.NodeRecords, nodeKey(node.ID()), blob)
}); err != nil {
return err
@ -346,11 +358,11 @@ func (db *DB) Resolve(n *Node) *Node {
// DeleteNode deletes all information associated with a node.
// It delegates to deleteRange, using nodeKey(id) as the key prefix to remove.
// No error is returned to the caller.
func (db *DB) DeleteNode(id ID) {
deleteRange(db.kv, nodeKey(id))
db.deleteRange(nodeKey(id))
}
func deleteRange(db kv.RwDB, prefix []byte) {
if err := db.Update(context.Background(), func(tx kv.RwTx) error {
func (db *DB) deleteRange(prefix []byte) {
if err := db.kv.Update(db.ctx, func(tx kv.RwTx) error {
for bucket := range bucketsConfig(nil) {
if err := deleteRangeInBucket(tx, prefix, bucket); err != nil {
return err
@ -398,7 +410,7 @@ func (db *DB) expirer() {
select {
case <-tick.C:
db.expireNodes()
case <-db.quit:
case <-db.ctx.Done():
return
}
}
@ -412,7 +424,7 @@ func (db *DB) expireNodes() {
youngestPong int64
)
var toDelete [][]byte
if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
c, err := tx.Cursor(kv.Inodes)
if err != nil {
return err
@ -455,7 +467,7 @@ func (db *DB) expireNodes() {
log.Warn("nodeDB.expireNodes failed", "err", err)
}
for _, td := range toDelete {
deleteRange(db.kv, td)
db.deleteRange(td)
}
}
@ -545,7 +557,7 @@ func (db *DB) QuerySeeds(n int, maxAge time.Duration) []*Node {
id ID
)
if err := db.kv.View(context.Background(), func(tx kv.Tx) error {
if err := db.kv.View(db.ctx, func(tx kv.Tx) error {
c, err := tx.Cursor(kv.NodeRecords)
if err != nil {
return err
@ -603,14 +615,6 @@ func (db *DB) QuerySeeds(n int, maxAge time.Duration) []*Node {
// Close flushes and closes the database files.
// It cancels the DB context first and then closes the underlying kv store.
func (db *DB) Close() {
select {
case <-db.quit:
return // means closed already
default:
}
if db.quit == nil {
return
}
libcommon.SafeClose(db.quit)
// Cancel before closing: the expirer loop exits on db.ctx.Done(), and DB methods
// pass db.ctx to kv.View/Update. Per the commit description, this is intended to
// make new calls return immediately while db.kv.Close() waits for ongoing ones.
db.ctxCancel()
db.kv.Close()
}