diff --git a/kv/mdbx/kv_mdbx.go b/kv/mdbx/kv_mdbx.go index 22739699b..f21d18868 100644 --- a/kv/mdbx/kv_mdbx.go +++ b/kv/mdbx/kv_mdbx.go @@ -34,6 +34,7 @@ import ( "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/log/v3" "github.com/torquem-ch/mdbx-go/mdbx" + "go.uber.org/atomic" ) const NonExistingDBI kv.DBI = 999_999_999 @@ -324,6 +325,7 @@ type MdbxKV struct { opts MdbxOpts txSize uint64 roTxsLimiter chan struct{} // does limit amount of concurrent Ro transactions - in most casess runtime.NumCPU() is good value for this channel capacity - this channel can be shared with other components (like Decompressor) + closed atomic.Bool } // openDBIs - first trying to open existing DBI's in RO transaction @@ -365,9 +367,10 @@ func (db *MdbxKV) openDBIs(buckets []string) error { // Close closes db // All transactions must be closed before closing the database. func (db *MdbxKV) Close() { - if db.env == nil { + if db.closed.Load() { return } + db.closed.Store(true) db.wg.Wait() db.env.Close() db.env = nil @@ -380,14 +383,14 @@ func (db *MdbxKV) Close() { } func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) { + if db.closed.Load() { + return nil, fmt.Errorf("db closed") + } select { case <-ctx.Done(): return nil, ctx.Err() case db.roTxsLimiter <- struct{}{}: } - if db.env == nil { - return nil, fmt.Errorf("db closed") - } defer func() { if err == nil { @@ -408,7 +411,7 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) { } func (db *MdbxKV) BeginRw(_ context.Context) (txn kv.RwTx, err error) { - if db.env == nil { + if db.closed.Load() { return nil, fmt.Errorf("db closed") } runtime.LockOSThread() @@ -621,12 +624,9 @@ func (tx *MdbxTx) ListBuckets() ([]string, error) { } func (db *MdbxKV) View(ctx context.Context, f func(tx kv.Tx) error) (err error) { - if db.env == nil { + if db.closed.Load() { return fmt.Errorf("db closed") } - db.wg.Add(1) - defer db.wg.Done() - // can't use db.evn.View method - because it calls commit for read transactions - it conflicts with write transactions. tx, err := db.BeginRo(ctx) if err != nil { @@ -638,11 +638,9 @@ func (db *MdbxKV) View(ctx context.Context, f func(tx kv.Tx) error) (err error) } func (db *MdbxKV) Update(ctx context.Context, f func(tx kv.RwTx) error) (err error) { - if db.env == nil { + if db.closed.Load() { return fmt.Errorf("db closed") } - db.wg.Add(1) - defer db.wg.Done() tx, err := db.BeginRw(ctx) if err != nil { @@ -754,9 +752,6 @@ func (tx *MdbxTx) ExistsBucket(bucket string) (bool, error) { } func (tx *MdbxTx) Commit() error { - if tx.db.env == nil { - return fmt.Errorf("db closed") - } if tx.tx == nil { return nil } @@ -815,9 +810,6 @@ func (tx *MdbxTx) Commit() error { } func (tx *MdbxTx) Rollback() { - if tx.db.env == nil { - return - } if tx.tx == nil { return }