diff --git a/sharding/shard.go b/sharding/shard.go index fd550d761..abaac5305 100644 --- a/sharding/shard.go +++ b/sharding/shard.go @@ -25,6 +25,11 @@ func MakeShard(shardID *big.Int) *Shard { } } +// ShardID gets the shard's identifier. +func (s *Shard) ShardID() *big.Int { + return s.shardID +} + // ValidateShardID checks if header belongs to shard. func (s *Shard) ValidateShardID(h *CollationHeader) error { if s.shardID.Cmp(h.shardID) != 0 { diff --git a/sharding/shard_test.go b/sharding/shard_test.go index c67736df6..5116e00d4 100644 --- a/sharding/shard_test.go +++ b/sharding/shard_test.go @@ -1,24 +1,22 @@ package sharding import ( + "math/big" "testing" ) func TestShard_ValidateShardID(t *testing.T) { - tests := []struct { - headers []*CollationHeader - }{ - { - headers: nil, - }, { - headers: nil, - }, + header := &CollationHeader{shardID: big.NewInt(4)} + shard := MakeShard(big.NewInt(3)) + + if err := shard.ValidateShardID(header); err == nil { + t.Fatalf("Shard ID validation incorrect. Function should throw error when shardID's do not match. want=%d. got=%d", header.shardID.Int64(), shard.ShardID().Int64()) } - for _, tt := range tests { - t.Logf("val: %v", tt.headers) - if 0 == 1 { - t.Fatalf("Wrong number of transactions. want=%d. got=%d", 5, 3) - } + header2 := &CollationHeader{shardID: big.NewInt(100)} + shard2 := MakeShard(big.NewInt(100)) + + if err := shard2.ValidateShardID(header2); err != nil { + t.Fatalf("Shard ID validation incorrect. Function should not throw error when shardID's match. want=%d. got=%d", header2.shardID.Int64(), shard2.ShardID().Int64()) } }