diff --git a/sharding/shard.go b/sharding/shard.go index 8a4edc710..8eca64482 100644 --- a/sharding/shard.go +++ b/sharding/shard.go @@ -77,22 +77,32 @@ func (s *Shard) CollationByHash(headerHash *common.Hash) (*Collation, error) { return &Collation{header: header, body: body}, nil } -// CanonicalCollationHash gets a collation header hash that has been set as canonical for +// CanonicalHeaderHash gets a collation header hash that has been set as canonical for // shardID/period pair -func (s *Shard) CanonicalCollationHash(shardID *big.Int, period *big.Int) (*common.Hash, error) { +func (s *Shard) CanonicalHeaderHash(shardID *big.Int, period *big.Int) (*common.Hash, error) { key := canonicalCollationLookupKey(shardID, period) - hash := common.BytesToHash(key.Bytes()) - collationHashBytes, err := s.shardDB.Get(hash) - if err != nil || len(collationHashBytes) == 0 { - return nil, fmt.Errorf("no canonical collation set for period, shardID pair: %v", err) + + // fetches the RLP encoded collation header corresponding to the key. + encoded, err := s.shardDB.Get(key) + if err != nil || len(encoded) == 0 { + return nil, fmt.Errorf("no canonical collation header set for period, shardID pair: %v", err) } - collationHash := common.BytesToHash(collationHashBytes) + + // RLP decodes the header, computes its hash. + var header CollationHeader + + stream := rlp.NewStream(bytes.NewReader(encoded), uint64(len(encoded))) + if err := header.DecodeRLP(stream); err != nil { + return nil, fmt.Errorf("could not decode RLP header: %v", err) + } + + collationHash := header.Hash() return &collationHash, nil } // CanonicalCollation fetches the collation set as canonical in the shardDB. func (s *Shard) CanonicalCollation(shardID *big.Int, period *big.Int) (*Collation, error) { - h, err := s.CanonicalCollationHash(shardID, period) + h, err := s.CanonicalHeaderHash(shardID, period) if err != nil { return nil, fmt.Errorf("hash not found: %v", err) } @@ -203,11 +213,14 @@ func (s *Shard) SetCanonical(header *CollationHeader) error { if err != nil { return err } + key := canonicalCollationLookupKey(dbHeader.ShardID(), dbHeader.Period()) encoded, err := dbHeader.EncodeRLP() if err != nil { return fmt.Errorf("cannot encode header: %v", err) } + // sets the key to be the canonical collation lookup key and val as RLP encoded + // collation header. if err := s.shardDB.Put(key, encoded); err != nil { return fmt.Errorf("cannot update shardDB: %v", err) } diff --git a/sharding/shard_test.go b/sharding/shard_test.go index 90e1f052b..8f1eae80a 100644 --- a/sharding/shard_test.go +++ b/sharding/shard_test.go @@ -90,3 +90,34 @@ func TestShard_CollationByHash(t *testing.T) { t.Errorf("collations do not match. want=%v. got=%v", collation, dbCollation) } } + +func TestShard_CanonicalHeaderHash(t *testing.T) { + shardID := big.NewInt(1) + period := big.NewInt(1) + proposerAddress := common.StringToAddress("") + proposerSignature := []byte{} + emptyHash := common.StringToHash("") + header := NewCollationHeader(shardID, &emptyHash, period, &proposerAddress, proposerSignature) + + shardDB := database.MakeShardKV() + shard := MakeShard(shardID, shardDB) + + if err := shard.SaveHeader(header); err != nil { + t.Fatalf("failed to save header to shardDB: %v", err) + } + + if err := shard.SetCanonical(header); err != nil { + t.Fatalf("failed to set header as canonical: %v", err) + } + + headerHash := header.Hash() + + canonicalHeaderHash, err := shard.CanonicalHeaderHash(shardID, period) + if err != nil { + t.Fatalf("failed to get canonical header hash from shardDB: %v", err) + } + + if canonicalHeaderHash.String() != headerHash.String() { + t.Errorf("header hashes do not match. want=%v. got=%v", headerHash.String(), canonicalHeaderHash.String()) + } +}