diff --git a/kv/memdb/memory_database.go b/kv/memdb/memory_database.go index 73bf24818..337a12e3d 100644 --- a/kv/memdb/memory_database.go +++ b/kv/memdb/memory_database.go @@ -20,10 +20,9 @@ import ( "context" "testing" - "github.com/ledgerwatch/log/v3" - "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon-lib/kv/mdbx" + "github.com/ledgerwatch/log/v3" ) func New() kv.RwDB { @@ -47,6 +46,26 @@ func NewTestDB(tb testing.TB) kv.RwDB { return db } +func BeginRw(tb testing.TB, db kv.RwDB) kv.RwTx { + tb.Helper() + tx, err := db.BeginRw(context.Background()) + if err != nil { + tb.Fatal(err) + } + tb.Cleanup(tx.Rollback) + return tx +} + +func BeginRo(tb testing.TB, db kv.RoDB) kv.Tx { + tb.Helper() + tx, err := db.BeginRo(context.Background()) + if err != nil { + tb.Fatal(err) + } + tb.Cleanup(tx.Rollback) + return tx +} + func NewTestPoolDB(tb testing.TB) kv.RwDB { tb.Helper() db := NewPoolDB()