package migrations

import (
	"context"
	"errors"
	"testing"

	"github.com/ledgerwatch/erigon-lib/kv"
	"github.com/ledgerwatch/erigon-lib/kv/memdb"
	"github.com/ledgerwatch/erigon/eth/stagedsync/stages"
	"github.com/ledgerwatch/erigon/node/nodecfg/datadir"
	"github.com/stretchr/testify/require"
)

// commitOnlyUp is a migration body that performs no work of its own: it opens
// a read-write transaction, invokes the beforeCommit callback on it and then
// commits. The deferred Rollback covers every early-return path.
//
// NOTE(review): the callback is invoked with (tx, nil, true); presumably that
// means "no saved progress, migration finished" — confirm against Callback.
func commitOnlyUp(db kv.RwDB, _ datadir.Dirs, _ []byte, beforeCommit Callback) error {
	tx, err := db.BeginRw(context.Background())
	if err != nil {
		return err
	}
	defer tx.Rollback()

	if err := beforeCommit(tx, nil, true); err != nil {
		return err
	}
	return tx.Commit()
}

// mustNotRunUp returns a migration body that fails the test t as soon as it is
// invoked. It is used for migrations the test has pre-marked as applied.
func mustNotRunUp(t *testing.T) func(kv.RwDB, datadir.Dirs, []byte, Callback) error {
	return func(kv.RwDB, datadir.Dirs, []byte, Callback) error {
		t.Fatal("shouldn't been executed")
		return nil
	}
}

// readApplied returns the result of AppliedMigrations(tx, false) read inside a
// single read-only transaction on db, failing the test on any error.
func readApplied(t *testing.T, db kv.RwDB) map[string][]byte {
	t.Helper()

	var applied map[string][]byte
	err := db.View(context.Background(), func(tx kv.Tx) error {
		var err error
		applied, err = AppliedMigrations(tx, false)
		return err
	})
	require.NoError(t, err)
	return applied
}

// TestApplyWithInit applies two migrations to a fresh test DB, checks that both
// are recorded as applied, and checks that applying again leaves the recorded
// set unchanged.
func TestApplyWithInit(t *testing.T) {
	require, db := require.New(t), memdb.NewTestDB(t)
	m := []Migration{
		{Name: "one", Up: commitOnlyUp},
		{Name: "two", Up: commitOnlyUp},
	}

	migrator := NewMigrator(kv.ChainDB)
	migrator.Migrations = m
	require.NoError(migrator.Apply(db, ""))

	applied := readApplied(t, db)
	require.Contains(applied, m[0].Name)
	require.Contains(applied, m[1].Name)

	// apply again
	require.NoError(migrator.Apply(db, ""))
	require.Equal(applied, readApplied(t, db))
}

// TestApplyWithoutInit pre-marks the first migration as applied and verifies
// that Apply does not execute it, still runs the second one, and ends with both
// recorded. A second Apply must leave the recorded set unchanged.
func TestApplyWithoutInit(t *testing.T) {
	require, db := require.New(t), memdb.NewTestDB(t)
	m := []Migration{
		{Name: "one", Up: mustNotRunUp(t)},
		{Name: "two", Up: commitOnlyUp},
	}

	err := db.Update(context.Background(), func(tx kv.RwTx) error {
		return tx.Put(kv.Migrations, []byte(m[0].Name), []byte{1})
	})
	require.NoError(err)

	migrator := NewMigrator(kv.ChainDB)
	migrator.Migrations = m
	require.NoError(migrator.Apply(db, ""))

	applied := readApplied(t, db)
	require.Len(applied, 2)
	require.Contains(applied, m[1].Name)
	require.Contains(applied, m[0].Name)

	// apply again
	require.NoError(migrator.Apply(db, ""))
	require.Equal(applied, readApplied(t, db))
}

// TestWhenNonFirstMigrationAlreadyApplied pre-marks the second migration as
// applied and verifies that Apply still runs the first one, does not execute
// the second, and ends with both recorded. A second Apply must leave the
// recorded set unchanged.
func TestWhenNonFirstMigrationAlreadyApplied(t *testing.T) {
	require, db := require.New(t), memdb.NewTestDB(t)
	m := []Migration{
		{Name: "one", Up: commitOnlyUp},
		{Name: "two", Up: mustNotRunUp(t)},
	}

	err := db.Update(context.Background(), func(tx kv.RwTx) error {
		return tx.Put(kv.Migrations, []byte(m[1].Name), []byte{1}) // apply non-first migration
	})
	require.NoError(err)

	migrator := NewMigrator(kv.ChainDB)
	migrator.Migrations = m
	require.NoError(migrator.Apply(db, ""))

	applied := readApplied(t, db)
	require.Len(applied, 2)
	require.Contains(applied, m[1].Name)
	require.Contains(applied, m[0].Name)

	// apply again
	require.NoError(migrator.Apply(db, ""))
	require.Equal(applied, readApplied(t, db))
}

// TestMarshalStages saves progress for the Execution stage, round-trips the
// payload through MarshalMigrationPayload / UnmarshalMigrationPayload, and
// checks that the result holds exactly one non-nil entry keyed by that stage.
func TestMarshalStages(t *testing.T) {
	require := require.New(t)
	_, tx := memdb.NewTestTx(t)

	require.NoError(stages.SaveStageProgress(tx, stages.Execution, 42))

	data, err := MarshalMigrationPayload(tx)
	require.NoError(err)

	res, err := UnmarshalMigrationPayload(data)
	require.NoError(err)

	require.Len(res, 1)
	v, ok := res[string(stages.Execution)]
	require.True(ok)
	require.NotNil(v)
}

// TestValidation registers two migrations sharing one name and expects Apply to
// fail with ErrMigrationNonUniqueName without recording anything as applied.
func TestValidation(t *testing.T) {
	require, db := require.New(t), memdb.NewTestDB(t)
	m := []Migration{
		{Name: "repeated_name", Up: commitOnlyUp},
		{Name: "repeated_name", Up: commitOnlyUp},
	}

	migrator := NewMigrator(kv.ChainDB)
	migrator.Migrations = m
	err := migrator.Apply(db, "")
	require.True(errors.Is(err, ErrMigrationNonUniqueName), "unexpected error: %v", err)

	require.Empty(readApplied(t, db))
}

// TestCommitCallRequired runs a migration that returns nil without ever calling
// BeforeCommit and expects Apply to fail with ErrMigrationCommitNotCalled
// without recording anything as applied.
func TestCommitCallRequired(t *testing.T) {
	require, db := require.New(t), memdb.NewTestDB(t)
	m := []Migration{
		{
			Name: "one",
			Up: func(db kv.RwDB, dirs datadir.Dirs, progress []byte, BeforeCommit Callback) (err error) {
				//don't call BeforeCommit
				return nil
			},
		},
	}

	migrator := NewMigrator(kv.ChainDB)
	migrator.Migrations = m
	err := migrator.Apply(db, "")
	require.True(errors.Is(err, ErrMigrationCommitNotCalled), "unexpected error: %v", err)

	require.Empty(readApplied(t, db))
}