diff --git a/erigon-lib/kv/mdbx/kv_mdbx.go b/erigon-lib/kv/mdbx/kv_mdbx.go index cd465b593..58ff2f4a0 100644 --- a/erigon-lib/kv/mdbx/kv_mdbx.go +++ b/erigon-lib/kv/mdbx/kv_mdbx.go @@ -34,7 +34,6 @@ import ( "github.com/c2h5oh/datasize" "github.com/erigontech/mdbx-go/mdbx" stack2 "github.com/go-stack/stack" - "github.com/ledgerwatch/erigon-lib/mmap" "github.com/ledgerwatch/log/v3" "golang.org/x/exp/maps" "golang.org/x/sync/semaphore" @@ -44,6 +43,7 @@ import ( "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon-lib/kv/iter" "github.com/ledgerwatch/erigon-lib/kv/order" + "github.com/ledgerwatch/erigon-lib/mmap" ) const NonExistingDBI kv.DBI = 999_999_999 @@ -385,15 +385,20 @@ func (opts MdbxOpts) Open(ctx context.Context) (kv.RwDB, error) { targetSemCount := int64(runtime.GOMAXPROCS(-1) * 16) opts.roTxsLimiter = semaphore.NewWeighted(targetSemCount) // 1 less than max to allow unlocking to happen } + + txsCountMutex := &sync.Mutex{} + db := &MdbxKV{ opts: opts, env: env, log: opts.log, - wg: &sync.WaitGroup{}, buckets: kv.TableCfg{}, txSize: dirtyPagesLimit * opts.pageSize, roTxsLimiter: opts.roTxsLimiter, + txsCountMutex: txsCountMutex, + txsAllDoneOnCloseCond: sync.NewCond(txsCountMutex), + leakDetector: dbg.NewLeakDetector("db."+opts.label.String(), dbg.SlowTx()), } @@ -457,7 +462,6 @@ func (opts MdbxOpts) MustOpen() kv.RwDB { type MdbxKV struct { log log.Logger env *mdbx.Env - wg *sync.WaitGroup buckets kv.TableCfg roTxsLimiter *semaphore.Weighted // does limit amount of concurrent Ro transactions - in most casess runtime.NumCPU() is good value for this channel capacity - this channel can be shared with other components (like Decompressor) opts MdbxOpts @@ -465,6 +469,10 @@ type MdbxKV struct { closed atomic.Bool path string + txsCount uint + txsCountMutex *sync.Mutex + txsAllDoneOnCloseCond *sync.Cond + leakDetector *dbg.LeakDetector } @@ -507,13 +515,53 @@ func (db *MdbxKV) openDBIs(buckets []string) error { }) } +func (db *MdbxKV) trackTxBegin() bool { + db.txsCountMutex.Lock() + defer db.txsCountMutex.Unlock() + + isOpen := !db.closed.Load() + if isOpen { + db.txsCount++ + } + return isOpen +} + +func (db *MdbxKV) hasTxsAllDoneAndClosed() bool { + return (db.txsCount == 0) && db.closed.Load() +} + +func (db *MdbxKV) trackTxEnd() { + db.txsCountMutex.Lock() + defer db.txsCountMutex.Unlock() + + if db.txsCount > 0 { + db.txsCount-- + } else { + panic("MdbxKV: unmatched trackTxEnd") + } + + if db.hasTxsAllDoneAndClosed() { + db.txsAllDoneOnCloseCond.Signal() + } +} + +func (db *MdbxKV) waitTxsAllDoneOnClose() { + db.txsCountMutex.Lock() + defer db.txsCountMutex.Unlock() + + for !db.hasTxsAllDoneAndClosed() { + db.txsAllDoneOnCloseCond.Wait() + } +} + // Close closes db // All transactions must be closed before closing the database. func (db *MdbxKV) Close() { if ok := db.closed.CompareAndSwap(false, true); !ok { return } - db.wg.Wait() + db.waitTxsAllDoneOnClose() + db.env.Close() db.env = nil @@ -526,10 +574,6 @@ func (db *MdbxKV) Close() { } func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) { - if db.closed.Load() { - return nil, fmt.Errorf("db closed") - } - // don't try to acquire if the context is already done select { case <-ctx.Done(): @@ -538,8 +582,13 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) { // otherwise carry on } + if !db.trackTxBegin() { + return nil, fmt.Errorf("db closed") + } + // will return nil err if context is cancelled (may appear to acquire the semaphore) if semErr := db.roTxsLimiter.Acquire(ctx, 1); semErr != nil { + db.trackTxEnd() return nil, semErr } @@ -548,6 +597,7 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) { // on error, or if there is whatever reason that we don't return a tx, // we need to free up the limiter slot, otherwise it could lead to deadlocks db.roTxsLimiter.Release(1) + db.trackTxEnd() } }() @@ -555,7 +605,7 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) { if err != nil { return nil, fmt.Errorf("%w, label: %s, trace: %s", err, db.opts.label.String(), stack2.Trace().String()) } - db.wg.Add(1) + return &MdbxTx{ ctx: ctx, db: db, @@ -579,16 +629,18 @@ func (db *MdbxKV) beginRw(ctx context.Context, flags uint) (txn kv.RwTx, err err default: } - if db.closed.Load() { + if !db.trackTxBegin() { return nil, fmt.Errorf("db closed") } + runtime.LockOSThread() tx, err := db.env.BeginTxn(nil, flags) if err != nil { runtime.UnlockOSThread() // unlock only in case of error. normal flow is "defer .Rollback()" + db.trackTxEnd() return nil, fmt.Errorf("%w, lable: %s, trace: %s", err, db.opts.label.String(), stack2.Trace().String()) } - db.wg.Add(1) + return &MdbxTx{ db: db, tx: tx, @@ -830,7 +882,7 @@ func (tx *MdbxTx) Commit() error { } defer func() { tx.tx = nil - tx.db.wg.Done() + tx.db.trackTxEnd() if tx.readOnly { tx.db.roTxsLimiter.Release(1) } else { @@ -881,7 +933,7 @@ func (tx *MdbxTx) Rollback() { } defer func() { tx.tx = nil - tx.db.wg.Done() + tx.db.trackTxEnd() if tx.readOnly { tx.db.roTxsLimiter.Release(1) } else { diff --git a/erigon-lib/kv/mdbx/kv_mdbx_test.go b/erigon-lib/kv/mdbx/kv_mdbx_test.go index e79a852da..66506ef72 100644 --- a/erigon-lib/kv/mdbx/kv_mdbx_test.go +++ b/erigon-lib/kv/mdbx/kv_mdbx_test.go @@ -18,14 +18,17 @@ package mdbx import ( "context" + "sync/atomic" "testing" + "time" "github.com/c2h5oh/datasize" - "github.com/ledgerwatch/erigon-lib/kv" - "github.com/ledgerwatch/erigon-lib/kv/order" "github.com/ledgerwatch/log/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/ledgerwatch/erigon-lib/kv" + "github.com/ledgerwatch/erigon-lib/kv/order" ) func BaseCase(t *testing.T) (kv.RwDB, kv.RwTx, kv.RwCursorDupSort) { @@ -773,3 +776,124 @@ func TestAutoConversionSeekBothRange(t *testing.T) { require.NoError(t, err) assert.Nil(t, v) } + +func TestBeginRoAfterClose(t *testing.T) { + db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen() + db.Close() + _, err := db.BeginRo(context.Background()) + require.ErrorContains(t, err, "closed") +} + +func TestBeginRwAfterClose(t *testing.T) { + db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen() + db.Close() + _, err := db.BeginRw(context.Background()) + require.ErrorContains(t, err, "closed") +} + +func TestBeginRoWithDoneContext(t *testing.T) { + db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen() + defer db.Close() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := db.BeginRo(ctx) + require.ErrorIs(t, err, context.Canceled) +} + +func TestBeginRwWithDoneContext(t *testing.T) { + db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen() + defer db.Close() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := db.BeginRw(ctx) + require.ErrorIs(t, err, context.Canceled) +} + +func testCloseWaitsAfterTxBegin( + t *testing.T, + count int, + txBeginFunc func(kv.RwDB) (kv.StatelessReadTx, error), + txEndFunc func(kv.StatelessReadTx) error, +) { + t.Helper() + db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen() + var txs []kv.StatelessReadTx + for i := 0; i < count; i++ { + tx, err := txBeginFunc(db) + require.Nil(t, err) + txs = append(txs, tx) + } + + isClosed := &atomic.Bool{} + closeDone := make(chan struct{}) + + go func() { + db.Close() + isClosed.Store(true) + close(closeDone) + }() + + for _, tx := range txs { + // arbitrary delay to give db.Close() a chance to exit prematurely + time.Sleep(time.Millisecond * 20) + assert.False(t, isClosed.Load()) + + err := txEndFunc(tx) + require.Nil(t, err) + } + + <-closeDone + assert.True(t, isClosed.Load()) +} + +func TestCloseWaitsAfterTxBegin(t *testing.T) { + ctx := context.Background() + t.Run("BeginRoAndCommit", func(t *testing.T) { + testCloseWaitsAfterTxBegin( + t, + 1, + func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) }, + func(tx kv.StatelessReadTx) error { return tx.Commit() }, + ) + }) + t.Run("BeginRoAndCommit3", func(t *testing.T) { + testCloseWaitsAfterTxBegin( + t, + 3, + func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) }, + func(tx kv.StatelessReadTx) error { return tx.Commit() }, + ) + }) + t.Run("BeginRoAndRollback", func(t *testing.T) { + testCloseWaitsAfterTxBegin( + t, + 1, + func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) }, + func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil }, + ) + }) + t.Run("BeginRoAndRollback3", func(t *testing.T) { + testCloseWaitsAfterTxBegin( + t, + 3, + func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) }, + func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil }, + ) + }) + t.Run("BeginRwAndCommit", func(t *testing.T) { + testCloseWaitsAfterTxBegin( + t, + 1, + func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRw(ctx) }, + func(tx kv.StatelessReadTx) error { return tx.Commit() }, + ) + }) + t.Run("BeginRwAndRollback", func(t *testing.T) { + testCloseWaitsAfterTxBegin( + t, + 1, + func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRw(ctx) }, + func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil }, + ) + }) +}