mdbx: race conditions in MdbxKV.Close (#8409) (#9244)

In the previous code WaitGroup db.wg.Add(), Wait() and db.closed were
not treated in sync. In particular, it was theoretically possible to
first check closed, then set closed and Wait, and then call wg.Add()
while waiting (leading to WaitGroup panic).
In theory it was also possible that db.env.BeginTxn() is called on a
closed or nil db.env, because db.wg.Add() was called only after BeginTxn
(db.wg.Wait() could already return).

WaitGroup is replaced with a Cond variable.
Now it is not possible to increase the active transactions count on a
closed database. It is also not possible to call BeginTxn on a closed
database.
This commit is contained in:
battlmonstr 2024-01-17 15:28:37 +01:00 committed by GitHub
parent 5e5d8490b1
commit 1914b52de0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 191 additions and 15 deletions

View File

@ -34,7 +34,6 @@ import (
"github.com/c2h5oh/datasize"
"github.com/erigontech/mdbx-go/mdbx"
stack2 "github.com/go-stack/stack"
"github.com/ledgerwatch/erigon-lib/mmap"
"github.com/ledgerwatch/log/v3"
"golang.org/x/exp/maps"
"golang.org/x/sync/semaphore"
@ -44,6 +43,7 @@ import (
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/iter"
"github.com/ledgerwatch/erigon-lib/kv/order"
"github.com/ledgerwatch/erigon-lib/mmap"
)
const NonExistingDBI kv.DBI = 999_999_999
@ -385,15 +385,20 @@ func (opts MdbxOpts) Open(ctx context.Context) (kv.RwDB, error) {
targetSemCount := int64(runtime.GOMAXPROCS(-1) * 16)
opts.roTxsLimiter = semaphore.NewWeighted(targetSemCount) // 1 less than max to allow unlocking to happen
}
txsCountMutex := &sync.Mutex{}
db := &MdbxKV{
opts: opts,
env: env,
log: opts.log,
wg: &sync.WaitGroup{},
buckets: kv.TableCfg{},
txSize: dirtyPagesLimit * opts.pageSize,
roTxsLimiter: opts.roTxsLimiter,
txsCountMutex: txsCountMutex,
txsAllDoneOnCloseCond: sync.NewCond(txsCountMutex),
leakDetector: dbg.NewLeakDetector("db."+opts.label.String(), dbg.SlowTx()),
}
@ -457,7 +462,6 @@ func (opts MdbxOpts) MustOpen() kv.RwDB {
type MdbxKV struct {
log log.Logger
env *mdbx.Env
wg *sync.WaitGroup
buckets kv.TableCfg
roTxsLimiter *semaphore.Weighted // does limit amount of concurrent Ro transactions - in most casess runtime.NumCPU() is good value for this channel capacity - this channel can be shared with other components (like Decompressor)
opts MdbxOpts
@ -465,6 +469,10 @@ type MdbxKV struct {
closed atomic.Bool
path string
txsCount uint
txsCountMutex *sync.Mutex
txsAllDoneOnCloseCond *sync.Cond
leakDetector *dbg.LeakDetector
}
@ -507,13 +515,53 @@ func (db *MdbxKV) openDBIs(buckets []string) error {
})
}
func (db *MdbxKV) trackTxBegin() bool {
db.txsCountMutex.Lock()
defer db.txsCountMutex.Unlock()
isOpen := !db.closed.Load()
if isOpen {
db.txsCount++
}
return isOpen
}
func (db *MdbxKV) hasTxsAllDoneAndClosed() bool {
return (db.txsCount == 0) && db.closed.Load()
}
func (db *MdbxKV) trackTxEnd() {
db.txsCountMutex.Lock()
defer db.txsCountMutex.Unlock()
if db.txsCount > 0 {
db.txsCount--
} else {
panic("MdbxKV: unmatched trackTxEnd")
}
if db.hasTxsAllDoneAndClosed() {
db.txsAllDoneOnCloseCond.Signal()
}
}
func (db *MdbxKV) waitTxsAllDoneOnClose() {
db.txsCountMutex.Lock()
defer db.txsCountMutex.Unlock()
for !db.hasTxsAllDoneAndClosed() {
db.txsAllDoneOnCloseCond.Wait()
}
}
// Close closes db
// All transactions must be closed before closing the database.
func (db *MdbxKV) Close() {
if ok := db.closed.CompareAndSwap(false, true); !ok {
return
}
db.wg.Wait()
db.waitTxsAllDoneOnClose()
db.env.Close()
db.env = nil
@ -526,10 +574,6 @@ func (db *MdbxKV) Close() {
}
func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) {
if db.closed.Load() {
return nil, fmt.Errorf("db closed")
}
// don't try to acquire if the context is already done
select {
case <-ctx.Done():
@ -538,8 +582,13 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) {
// otherwise carry on
}
if !db.trackTxBegin() {
return nil, fmt.Errorf("db closed")
}
// will return nil err if context is cancelled (may appear to acquire the semaphore)
if semErr := db.roTxsLimiter.Acquire(ctx, 1); semErr != nil {
db.trackTxEnd()
return nil, semErr
}
@ -548,6 +597,7 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) {
// on error, or if there is whatever reason that we don't return a tx,
// we need to free up the limiter slot, otherwise it could lead to deadlocks
db.roTxsLimiter.Release(1)
db.trackTxEnd()
}
}()
@ -555,7 +605,7 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) {
if err != nil {
return nil, fmt.Errorf("%w, label: %s, trace: %s", err, db.opts.label.String(), stack2.Trace().String())
}
db.wg.Add(1)
return &MdbxTx{
ctx: ctx,
db: db,
@ -579,16 +629,18 @@ func (db *MdbxKV) beginRw(ctx context.Context, flags uint) (txn kv.RwTx, err err
default:
}
if db.closed.Load() {
if !db.trackTxBegin() {
return nil, fmt.Errorf("db closed")
}
runtime.LockOSThread()
tx, err := db.env.BeginTxn(nil, flags)
if err != nil {
runtime.UnlockOSThread() // unlock only in case of error. normal flow is "defer .Rollback()"
db.trackTxEnd()
return nil, fmt.Errorf("%w, lable: %s, trace: %s", err, db.opts.label.String(), stack2.Trace().String())
}
db.wg.Add(1)
return &MdbxTx{
db: db,
tx: tx,
@ -830,7 +882,7 @@ func (tx *MdbxTx) Commit() error {
}
defer func() {
tx.tx = nil
tx.db.wg.Done()
tx.db.trackTxEnd()
if tx.readOnly {
tx.db.roTxsLimiter.Release(1)
} else {
@ -881,7 +933,7 @@ func (tx *MdbxTx) Rollback() {
}
defer func() {
tx.tx = nil
tx.db.wg.Done()
tx.db.trackTxEnd()
if tx.readOnly {
tx.db.roTxsLimiter.Release(1)
} else {

View File

@ -18,14 +18,17 @@ package mdbx
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/c2h5oh/datasize"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/order"
"github.com/ledgerwatch/log/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/order"
)
func BaseCase(t *testing.T) (kv.RwDB, kv.RwTx, kv.RwCursorDupSort) {
@ -773,3 +776,124 @@ func TestAutoConversionSeekBothRange(t *testing.T) {
require.NoError(t, err)
assert.Nil(t, v)
}
func TestBeginRoAfterClose(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
db.Close()
_, err := db.BeginRo(context.Background())
require.ErrorContains(t, err, "closed")
}
func TestBeginRwAfterClose(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
db.Close()
_, err := db.BeginRw(context.Background())
require.ErrorContains(t, err, "closed")
}
func TestBeginRoWithDoneContext(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := db.BeginRo(ctx)
require.ErrorIs(t, err, context.Canceled)
}
func TestBeginRwWithDoneContext(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := db.BeginRw(ctx)
require.ErrorIs(t, err, context.Canceled)
}
func testCloseWaitsAfterTxBegin(
t *testing.T,
count int,
txBeginFunc func(kv.RwDB) (kv.StatelessReadTx, error),
txEndFunc func(kv.StatelessReadTx) error,
) {
t.Helper()
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
var txs []kv.StatelessReadTx
for i := 0; i < count; i++ {
tx, err := txBeginFunc(db)
require.Nil(t, err)
txs = append(txs, tx)
}
isClosed := &atomic.Bool{}
closeDone := make(chan struct{})
go func() {
db.Close()
isClosed.Store(true)
close(closeDone)
}()
for _, tx := range txs {
// arbitrary delay to give db.Close() a chance to exit prematurely
time.Sleep(time.Millisecond * 20)
assert.False(t, isClosed.Load())
err := txEndFunc(tx)
require.Nil(t, err)
}
<-closeDone
assert.True(t, isClosed.Load())
}
func TestCloseWaitsAfterTxBegin(t *testing.T) {
ctx := context.Background()
t.Run("BeginRoAndCommit", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { return tx.Commit() },
)
})
t.Run("BeginRoAndCommit3", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
3,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { return tx.Commit() },
)
})
t.Run("BeginRoAndRollback", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil },
)
})
t.Run("BeginRoAndRollback3", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
3,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil },
)
})
t.Run("BeginRwAndCommit", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRw(ctx) },
func(tx kv.StatelessReadTx) error { return tx.Commit() },
)
})
t.Run("BeginRwAndRollback", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRw(ctx) },
func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil },
)
})
}