Add DeepSSZEqual and DeepNotSSZEqual (#8421)

This commit is contained in:
Ivan Martinez 2021-02-09 15:57:22 -05:00 committed by GitHub
parent 2f98e6aaaf
commit cd3851c3d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 186 additions and 11 deletions

View File

@ -7,7 +7,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/prysmaticlabs/eth2-types"
types "github.com/prysmaticlabs/eth2-types"
ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/beacon-chain/cache/depositcache"
"github.com/prysmaticlabs/prysm/beacon-chain/core/blocks"
@ -880,11 +880,11 @@ func TestUpdateJustifiedInitSync(t *testing.T) {
require.NoError(t, service.updateJustifiedInitSync(ctx, newCp))
assert.DeepEqual(t, currentCp, service.prevJustifiedCheckpt, "Incorrect previous justified checkpoint")
assert.DeepEqual(t, newCp, service.CurrentJustifiedCheckpt(), "Incorrect current justified checkpoint in cache")
assert.DeepSSZEqual(t, currentCp, service.prevJustifiedCheckpt, "Incorrect previous justified checkpoint")
assert.DeepSSZEqual(t, newCp, service.CurrentJustifiedCheckpt(), "Incorrect current justified checkpoint in cache")
cp, err := service.beaconDB.JustifiedCheckpoint(ctx)
require.NoError(t, err)
assert.DeepEqual(t, newCp, cp, "Incorrect current justified checkpoint in db")
assert.DeepSSZEqual(t, newCp, cp, "Incorrect current justified checkpoint in db")
}
func TestHandleEpochBoundary_BadMetrics(t *testing.T) {

View File

@ -4,7 +4,7 @@ import (
"testing"
"github.com/gogo/protobuf/proto"
"github.com/prysmaticlabs/eth2-types"
types "github.com/prysmaticlabs/eth2-types"
ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/beacon-chain/core/state"
pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
@ -72,8 +72,8 @@ func TestGenesisBeaconState_OK(t *testing.T) {
// Recent state checks.
assert.DeepEqual(t, make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector), newState.Slashings(), "Slashings was not correctly initialized")
assert.DeepEqual(t, []*pb.PendingAttestation{}, newState.CurrentEpochAttestations(), "CurrentEpochAttestations was not correctly initialized")
assert.DeepEqual(t, []*pb.PendingAttestation{}, newState.PreviousEpochAttestations(), "PreviousEpochAttestations was not correctly initialized")
assert.DeepSSZEqual(t, []*pb.PendingAttestation{}, newState.CurrentEpochAttestations(), "CurrentEpochAttestations was not correctly initialized")
assert.DeepSSZEqual(t, []*pb.PendingAttestation{}, newState.PreviousEpochAttestations(), "PreviousEpochAttestations was not correctly initialized")
zeroHash := params.BeaconConfig().ZeroHash[:]
// History root checks.
@ -82,7 +82,7 @@ func TestGenesisBeaconState_OK(t *testing.T) {
// Deposit root checks.
assert.DeepEqual(t, eth1Data.DepositRoot, newState.Eth1Data().DepositRoot, "Eth1Data DepositRoot was not correctly initialized")
assert.DeepEqual(t, []*ethpb.Eth1Data{}, newState.Eth1DataVotes(), "Eth1DataVotes was not correctly initialized")
assert.DeepSSZEqual(t, []*ethpb.Eth1Data{}, newState.Eth1DataVotes(), "Eth1DataVotes was not correctly initialized")
}
func TestGenesisState_HashEquality(t *testing.T) {

View File

@ -313,7 +313,7 @@ func TestPool_MarkIncludedProposerSlashing(t *testing.T) {
p.MarkIncludedProposerSlashing(tt.args.slashing)
assert.Equal(t, len(tt.want.pending), len(p.pendingProposerSlashing))
for i := range p.pendingProposerSlashing {
assert.DeepEqual(t, tt.want.pending[i], p.pendingProposerSlashing[i], "Unexpected pending proposer slashing at index %d", i)
assert.DeepSSZEqual(t, tt.want.pending[i], p.pendingProposerSlashing[i], "Unexpected pending proposer slashing at index %d", i)
}
assert.DeepEqual(t, tt.want.included, p.included)
})

View File

@ -25,6 +25,16 @@ func DeepNotEqual(tb assertions.AssertionTestingTB, expected, actual interface{}
assertions.DeepNotEqual(tb.Errorf, expected, actual, msg...)
}
// DeepSSZEqual compares values using sszutil.DeepEqual.
func DeepSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepSSZEqual(tb.Errorf, expected, actual, msg...)
}
// DeepNotSSZEqual compares values using sszutil.DeepEqual.
func DeepNotSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepNotSSZEqual(tb.Errorf, expected, actual, msg...)
}
// NoError asserts that error is nil.
func NoError(tb assertions.AssertionTestingTB, err error, msg ...interface{}) {
assertions.NoError(tb.Errorf, err, msg...)

View File

@ -7,7 +7,9 @@ go_library(
importpath = "github.com/prysmaticlabs/prysm/shared/testutil/assertions",
visibility = ["//visibility:public"],
deps = [
"//shared/sszutil:go_default_library",
"@com_github_d4l3k_messagediff//:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
],
)

View File

@ -8,6 +8,8 @@ import (
"strings"
"github.com/d4l3k/messagediff"
"github.com/gogo/protobuf/proto"
"github.com/prysmaticlabs/prysm/shared/sszutil"
"github.com/sirupsen/logrus/hooks/test"
)
@ -39,7 +41,7 @@ func NotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...i
// DeepEqual compares values using DeepEqual.
func DeepEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if !reflect.DeepEqual(expected, actual) {
if !isDeepEqual(expected, actual) {
errMsg := parseMsg("Values are not equal", msg...)
_, file, line, _ := runtime.Caller(2)
diff, _ := messagediff.PrettyDiff(expected, actual)
@ -49,7 +51,26 @@ func DeepEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...
// DeepNotEqual compares values using DeepEqual.
func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if reflect.DeepEqual(expected, actual) {
if isDeepEqual(expected, actual) {
errMsg := parseMsg("Values are equal", msg...)
_, file, line, _ := runtime.Caller(2)
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
}
}
// DeepSSZEqual compares values using sszutil.DeepEqual.
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if !sszutil.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are not equal", msg...)
_, file, line, _ := runtime.Caller(2)
diff, _ := messagediff.PrettyDiff(expected, actual)
loggerFn("%s:%d %s, want: %#v, got: %#v, diff: %s", filepath.Base(file), line, errMsg, expected, actual, diff)
}
}
// DeepNotSSZEqual compares values using sszutil.DeepEqual.
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if sszutil.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are equal", msg...)
_, file, line, _ := runtime.Caller(2)
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
@ -144,6 +165,14 @@ func parseMsg(defaultMsg string, msg ...interface{}) string {
return defaultMsg
}
func isDeepEqual(expected, actual interface{}) bool {
_, isProto := expected.(proto.Message)
if isProto {
return proto.Equal(expected.(proto.Message), actual.(proto.Message))
}
return reflect.DeepEqual(expected, actual)
}
// TBMock exposes enough testing.TB methods for assertions.
type TBMock struct {
ErrorfMsg string

View File

@ -297,6 +297,130 @@ func TestAssert_DeepNotEqual(t *testing.T) {
}
}
func TestAssert_DeepSSZEqual(t *testing.T) {
type args struct {
tb *assertions.TBMock
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
expectedResult bool
}{
{
name: "equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{42},
},
expectedResult: true,
},
{
name: "equal structs",
args: args{
tb: &assertions.TBMock{},
expected: &eth.Checkpoint{
Epoch: 5,
Root: []byte("hi there"),
},
actual: &eth.Checkpoint{
Epoch: 5,
Root: []byte("hi there"),
},
},
expectedResult: true,
},
{
name: "non-equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{41},
},
expectedResult: false,
},
}
for _, tt := range tests {
verify := func() {
if tt.expectedResult && tt.args.tb.ErrorfMsg != "" {
t.Errorf("Unexpected error: %s %v", tt.name, tt.args.tb.ErrorfMsg)
}
}
t.Run(fmt.Sprintf("Assert/%s", tt.name), func(t *testing.T) {
assert.DeepSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
t.Run(fmt.Sprintf("Require/%s", tt.name), func(t *testing.T) {
require.DeepSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
}
}
func TestAssert_DeepNotSSZEqual(t *testing.T) {
type args struct {
tb *assertions.TBMock
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
expectedResult bool
}{
{
name: "equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{42},
},
expectedResult: true,
},
{
name: "non-equal values",
args: args{
tb: &assertions.TBMock{},
expected: struct{ I uint64 }{42},
actual: struct{ I uint64 }{41},
},
expectedResult: false,
},
{
name: "not equal structs",
args: args{
tb: &assertions.TBMock{},
expected: &eth.Checkpoint{
Epoch: 5,
Root: []byte("hello there"),
},
actual: &eth.Checkpoint{
Epoch: 3,
Root: []byte("hi there"),
},
},
expectedResult: true,
},
}
for _, tt := range tests {
verify := func() {
if !tt.expectedResult && tt.args.tb.ErrorfMsg != "" {
t.Errorf("Unexpected error: %s %v", tt.name, tt.args.tb.ErrorfMsg)
}
}
t.Run(fmt.Sprintf("Assert/%s", tt.name), func(t *testing.T) {
assert.DeepNotSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
t.Run(fmt.Sprintf("Require/%s", tt.name), func(t *testing.T) {
require.DeepNotSSZEqual(tt.args.tb, tt.args.expected, tt.args.actual)
verify()
})
}
}
func TestAssert_NoError(t *testing.T) {
type args struct {
tb *assertions.TBMock

View File

@ -25,6 +25,16 @@ func DeepNotEqual(tb assertions.AssertionTestingTB, expected, actual interface{}
assertions.DeepNotEqual(tb.Fatalf, expected, actual, msg...)
}
// DeepSSZEqual compares values using DeepEqual.
func DeepSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepSSZEqual(tb.Fatalf, expected, actual, msg...)
}
// DeepNotSSZEqual compares values using DeepEqual.
func DeepNotSSZEqual(tb assertions.AssertionTestingTB, expected, actual interface{}, msg ...interface{}) {
assertions.DeepNotSSZEqual(tb.Fatalf, expected, actual, msg...)
}
// NoError asserts that error is nil.
func NoError(tb assertions.AssertionTestingTB, err error, msg ...interface{}) {
assertions.NoError(tb.Fatalf, err, msg...)