prysm-pulse/beacon-chain/sync/initial-sync/fsm_test.go
Victor Farazdagi 8a3b75e9e3
Update init-sync FSM (#7640)
* update fsm

* more tests
2020-10-26 05:26:35 +00:00

463 lines
14 KiB
Go

package initialsync
import (
"errors"
"fmt"
"testing"
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
"github.com/prysmaticlabs/prysm/shared/testutil/require"
)
func TestStateMachineManager_String(t *testing.T) {
tests := []struct {
name string
machines map[uint64]*stateMachine
want string
}{
{
"empty epoch state list",
map[uint64]*stateMachine{},
"map[]",
},
{
"newly created state machine",
map[uint64]*stateMachine{
0: {start: 64 * 0, state: stateNew},
64: {start: 64 * 1, state: stateScheduled},
128: {start: 64 * 2, state: stateDataParsed},
196: {start: 64 * 3, state: stateSkipped},
256: {start: 64 * 4, state: stateSent},
},
"map[0:{0:new} 64:{2:scheduled} 128:{4:dataParsed} 196:{6:skipped} 256:{8:sent}]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smm := &stateMachineManager{
machines: tt.machines,
}
assert.Equal(t, tt.want, smm.String())
})
}
}
func TestStateMachine_StateIDString(t *testing.T) {
stateIDs := []stateID{stateNew, stateScheduled, stateDataParsed, stateSkipped, stateSent}
assert.Equal(t, "[new scheduled dataParsed skipped sent]", fmt.Sprintf("%v", stateIDs))
assert.Equal(t, "stateUnknown", stateID(15).String())
}
func TestStateMachine_EventIDString(t *testing.T) {
eventIDs := []eventID{eventTick, eventDataReceived}
assert.Equal(t, "[tick dataReceived]", fmt.Sprintf("%v", eventIDs))
assert.Equal(t, "eventUnknown", eventID(15).String())
}
func TestStateMachineManager_addEventHandler(t *testing.T) {
smm := newStateMachineManager()
smm.addEventHandler(eventTick, stateNew, func(m *stateMachine, i interface{}) (id stateID, err error) {
return stateScheduled, nil
})
assert.Equal(t, 1, len(smm.handlers[stateNew]), "Unexpected size")
state, err := smm.handlers[stateNew][eventTick](nil, nil)
assert.NoError(t, err)
assert.Equal(t, stateScheduled, state, "Unexpected state")
// Add second handler to the same event
smm.addEventHandler(eventTick, stateSent, func(m *stateMachine, i interface{}) (id stateID, err error) {
return stateDataParsed, nil
})
assert.Equal(t, 1, len(smm.handlers[stateSent]), "Unexpected size")
state, err = smm.handlers[stateSent][eventTick](nil, nil)
assert.NoError(t, err)
assert.Equal(t, stateDataParsed, state, "Unexpected state")
// Add another handler to existing event/state pair. Should have no effect.
smm.addEventHandler(eventTick, stateSent, func(m *stateMachine, i interface{}) (id stateID, err error) {
return stateSkipped, nil
})
assert.Equal(t, 1, len(smm.handlers[stateSent]), "Unexpected size")
state, err = smm.handlers[stateSent][eventTick](nil, nil)
assert.NoError(t, err)
// No effect, previous handler worked.
assert.Equal(t, stateDataParsed, state, "Unexpected state")
}
func TestStateMachine_trigger(t *testing.T) {
type event struct {
state stateID
event eventID
returnState stateID
err bool
}
type args struct {
name eventID
returnState stateID
epoch uint64
data interface{}
}
tests := []struct {
name string
events []event
epochs []uint64
args args
err error
}{
{
name: "event not found",
events: []event{},
epochs: []uint64{12, 13},
args: args{name: eventTick, epoch: 12, data: nil, returnState: stateNew},
err: errors.New("no event handlers registered for event: tick, state: new"),
},
{
name: "single action",
events: []event{
{stateNew, eventTick, stateScheduled, false},
},
epochs: []uint64{12, 13},
args: args{name: eventTick, epoch: 12, data: nil, returnState: stateScheduled},
err: nil,
},
{
name: "multiple actions, has error",
events: []event{
{stateNew, eventTick, stateScheduled, false},
{stateScheduled, eventTick, stateSent, true},
{stateSent, eventTick, stateSkipped, false},
},
epochs: []uint64{12, 13},
args: args{name: eventTick, epoch: 12, data: nil, returnState: stateScheduled},
err: nil,
},
{
name: "multiple actions, no error, can cascade",
events: []event{
{stateNew, eventTick, stateScheduled, false},
{stateScheduled, eventTick, stateSent, false},
{stateSent, eventTick, stateSkipped, false},
},
epochs: []uint64{12, 13},
args: args{name: eventTick, epoch: 12, data: nil, returnState: stateScheduled},
err: nil,
},
{
name: "multiple actions, no error, no cascade",
events: []event{
{stateNew, eventTick, stateScheduled, false},
{stateScheduled, eventTick, stateSent, false},
{stateNew, eventTick, stateSkipped, false},
},
epochs: []uint64{12, 13},
args: args{name: eventTick, epoch: 12, data: nil, returnState: stateScheduled},
err: nil,
},
}
fn := func(e event) eventHandlerFn {
return func(m *stateMachine, in interface{}) (stateID, error) {
if e.err {
return m.state, errors.New("invalid")
}
return e.returnState, nil
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smm := newStateMachineManager()
expectHandlerError := false
for _, event := range tt.events {
smm.addEventHandler(event.event, event.state, fn(event))
if event.err {
expectHandlerError = true
}
}
for _, epoch := range tt.epochs {
smm.addStateMachine(epoch * 32)
}
state := smm.machines[tt.args.epoch*32]
err := state.trigger(tt.args.name, tt.args.data)
if tt.err != nil && (err == nil || tt.err.Error() != err.Error()) {
t.Errorf("unexpected error = '%v', want '%v'", err, tt.err)
}
if tt.err == nil {
if err != nil && !expectHandlerError {
t.Error(err)
}
ind := tt.args.epoch * 32
if smm.machines[ind].state != tt.args.returnState {
t.Errorf("unexpected final state: %v, want: %v (%v)",
smm.machines[ind].state, tt.args.returnState, smm.machines)
}
}
})
}
}
func TestStateMachineManager_QueueLoop(t *testing.T) {
smm := newStateMachineManager()
smm.addEventHandler(eventTick, stateNew, func(m *stateMachine, data interface{}) (stateID, error) {
return stateScheduled, nil
})
smm.addEventHandler(eventTick, stateScheduled, func(m *stateMachine, data interface{}) (stateID, error) {
if m.start < 256 {
return stateDataParsed, nil
}
return stateSkipped, nil
})
smm.addEventHandler(eventTick, stateDataParsed, func(m *stateMachine, data interface{}) (stateID, error) {
return stateSent, nil
})
smm.addEventHandler(eventTick, stateSkipped, func(m *stateMachine, data interface{}) (stateID, error) {
dataParsed, ok := data.(int)
if !ok {
return m.state, errors.New("invalid data type")
}
if dataParsed > 41 {
return stateNew, nil
}
return stateScheduled, nil
})
assert.Equal(t, 4, len(smm.handlers), "Unexpected number of state events")
smm.addStateMachine(64)
smm.addStateMachine(512)
assertState := func(startSlot uint64, state stateID) {
fsm, ok := smm.findStateMachine(startSlot)
require.Equal(t, true, ok, "State machine not found")
assert.Equal(t, state, fsm.state, "Unexpected state machine state")
}
triggerTickEvent := func() {
for _, fsm := range smm.machines {
data := 42
assert.NoError(t, fsm.trigger(eventTick, data))
}
}
assertState(64, stateNew)
assertState(512, stateNew)
triggerTickEvent()
assertState(64, stateScheduled)
assertState(512, stateScheduled)
triggerTickEvent()
assertState(64, stateDataParsed)
assertState(512, stateSkipped)
triggerTickEvent()
assertState(64, stateSent)
assertState(512, stateNew)
}
func TestStateMachineManager_removeStateMachine(t *testing.T) {
smm := newStateMachineManager()
if _, ok := smm.findStateMachine(64); ok {
t.Error("unexpected machine found")
}
smm.addStateMachine(64)
if _, ok := smm.findStateMachine(64); !ok {
t.Error("expected machine not found")
}
expectedError := fmt.Sprintf("state for machine %v is not found", 65)
assert.ErrorContains(t, expectedError, smm.removeStateMachine(65))
assert.NoError(t, smm.removeStateMachine(64))
if _, ok := smm.findStateMachine(64); ok {
t.Error("unexpected machine found")
}
}
func TestStateMachineManager_removeAllStateMachines(t *testing.T) {
smm := newStateMachineManager()
smm.addStateMachine(64)
smm.addStateMachine(128)
smm.addStateMachine(196)
keys := []uint64{64, 128, 196}
assert.DeepEqual(t, smm.keys, keys, "Keys not sorted")
assert.Equal(t, 3, len(smm.machines), "Unexpected list size")
assert.NoError(t, smm.removeAllStateMachines())
keys = []uint64{}
assert.DeepEqual(t, smm.keys, keys, "Unexpected keys")
assert.Equal(t, 0, len(smm.machines), "Expected empty list")
}
func TestStateMachineManager_findStateMachine(t *testing.T) {
smm := newStateMachineManager()
if _, ok := smm.findStateMachine(64); ok {
t.Errorf("unexpected returned value: want: %v, got: %v", false, ok)
}
smm.addStateMachine(64)
if fsm, ok := smm.findStateMachine(64); !ok || fsm == nil {
t.Errorf("unexpected returned value: want: %v, got: %v", true, ok)
}
smm.addStateMachine(512)
smm.addStateMachine(196)
smm.addStateMachine(256)
smm.addStateMachine(128)
if fsm, ok := smm.findStateMachine(128); !ok || fsm.start != 128 {
t.Errorf("unexpected start slot: %v, want: %v", fsm.start, 122)
}
if fsm, ok := smm.findStateMachine(512); !ok || fsm.start != 512 {
t.Errorf("unexpected start slot: %v, want: %v", fsm.start, 512)
}
keys := []uint64{64, 128, 196, 256, 512}
assert.DeepEqual(t, smm.keys, keys, "Keys not sorted")
}
func TestStateMachineManager_highestStartSlot(t *testing.T) {
smm := newStateMachineManager()
_, err := smm.highestStartSlot()
assert.ErrorContains(t, "no state machine exist", err)
smm.addStateMachine(64)
smm.addStateMachine(128)
smm.addStateMachine(196)
start, err := smm.highestStartSlot()
assert.NoError(t, err)
assert.Equal(t, uint64(196), start, "Incorrect highest start slot")
assert.NoError(t, smm.removeStateMachine(196))
start, err = smm.highestStartSlot()
assert.NoError(t, err)
assert.Equal(t, uint64(128), start, "Incorrect highest start slot")
}
func TestStateMachineManager_allMachinesInState(t *testing.T) {
tests := []struct {
name string
smmGen func() *stateMachineManager
expectedStates []stateID
unexpectedStates []stateID
}{
{
name: "empty manager",
smmGen: newStateMachineManager,
expectedStates: []stateID{},
unexpectedStates: []stateID{stateNew, stateScheduled, stateDataParsed, stateSkipped, stateSent},
},
{
name: "single machine default state",
smmGen: func() *stateMachineManager {
smm := newStateMachineManager()
smm.addStateMachine(64)
return smm
},
expectedStates: []stateID{stateNew},
unexpectedStates: []stateID{stateScheduled, stateDataParsed, stateSkipped, stateSent},
},
{
name: "single machine updated state",
smmGen: func() *stateMachineManager {
smm := newStateMachineManager()
m1 := smm.addStateMachine(64)
m1.setState(stateSkipped)
return smm
},
expectedStates: []stateID{stateSkipped},
unexpectedStates: []stateID{stateNew, stateScheduled, stateDataParsed, stateSent},
},
{
name: "multiple machines false",
smmGen: func() *stateMachineManager {
smm := newStateMachineManager()
smm.addStateMachine(64)
smm.addStateMachine(128)
smm.addStateMachine(196)
for _, fsm := range smm.machines {
fsm.setState(stateSkipped)
}
smm.addStateMachine(256)
return smm
},
expectedStates: []stateID{},
unexpectedStates: []stateID{stateNew, stateScheduled, stateDataParsed, stateSkipped, stateSent},
},
{
name: "multiple machines true",
smmGen: func() *stateMachineManager {
smm := newStateMachineManager()
smm.addStateMachine(64)
smm.addStateMachine(128)
smm.addStateMachine(196)
for _, fsm := range smm.machines {
fsm.setState(stateSkipped)
}
return smm
},
expectedStates: []stateID{stateSkipped},
unexpectedStates: []stateID{stateNew, stateScheduled, stateDataParsed, stateSent},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smm := tt.smmGen()
for _, state := range tt.expectedStates {
if !smm.allMachinesInState(state) {
t.Errorf("expected all machines be in state: %v", state)
}
}
for _, state := range tt.unexpectedStates {
if smm.allMachinesInState(state) {
t.Errorf("unexpected state: %v", state)
}
}
})
}
}
func TestStateMachine_isFirstLast(t *testing.T) {
checkFirst := func(m *stateMachine, want bool) {
assert.Equal(t, want, m.isFirst(), "isFirst() returned unexpected value")
}
checkLast := func(m *stateMachine, want bool) {
assert.Equal(t, want, m.isLast(), "isLast() returned unexpected value")
}
smm := newStateMachineManager()
m1 := smm.addStateMachine(64)
checkFirst(m1, true)
checkLast(m1, true)
m2 := smm.addStateMachine(128)
checkFirst(m1, true)
checkLast(m1, false)
checkFirst(m2, false)
checkLast(m2, true)
m3 := smm.addStateMachine(512)
checkFirst(m1, true)
checkLast(m1, false)
checkFirst(m2, false)
checkLast(m2, false)
checkFirst(m3, false)
checkLast(m3, true)
// Add machine with lower start slot - shouldn't be marked as last.
m4 := smm.addStateMachine(196)
checkFirst(m1, true)
checkLast(m1, false)
checkFirst(m2, false)
checkLast(m2, false)
checkFirst(m3, false)
checkLast(m3, true)
checkFirst(m4, false)
checkLast(m4, false)
// Add machine with lowest start slot - should be marked as first.
m5 := smm.addStateMachine(32)
checkFirst(m1, false)
checkLast(m1, false)
checkFirst(m2, false)
checkLast(m2, false)
checkFirst(m3, false)
checkLast(m3, true)
checkFirst(m4, false)
checkLast(m4, false)
checkFirst(m5, true)
checkLast(m5, false)
keys := []uint64{32, 64, 128, 196, 512}
assert.DeepEqual(t, smm.keys, keys, "Keys not sorted")
}