diff --git a/sharding/collation.go b/sharding/collation.go index c0e976974..245558404 100644 --- a/sharding/collation.go +++ b/sharding/collation.go @@ -18,11 +18,27 @@ type Collation struct { // CollationHeader base struct. type CollationHeader struct { - shardID *big.Int //the shard ID of the shard. - chunkRoot *common.Hash //the root of the chunk tree which identifies collation body. - period *big.Int //the period number in which collation to be included. - proposerAddress *common.Address //address of the collation proposer. - proposerSignature []byte //the proposer's signature for calculating collation hash. + data collationHeaderData +} + +type collationHeaderData struct { + ShardID *big.Int // the shard ID of the shard. + ChunkRoot *common.Hash // the root of the chunk tree which identifies collation body. + Period *big.Int // the period number in which collation to be Pncluded. + ProposerAddress *common.Address // address of the collation proposer. + ProposerSignature []byte // the proposer's signature for calculating collation hash. +} + +// NewCollationHeader initializes a collation header struct. +func NewCollationHeader(shardID *big.Int, chunkRoot *common.Hash, period *big.Int, proposerAddress *common.Address, proposerSignature []byte) *CollationHeader { + data := collationHeaderData{ + ShardID: shardID, + ChunkRoot: chunkRoot, + Period: period, + ProposerAddress: proposerAddress, + ProposerSignature: proposerSignature, + } + return &CollationHeader{data: data} } // Hash takes the keccak256 of the collation header's contents. @@ -34,13 +50,28 @@ func (h *CollationHeader) Hash() (hash common.Hash) { } // ShardID is the identifier for a shard. -func (h *CollationHeader) ShardID() *big.Int { return h.shardID } +func (h *CollationHeader) ShardID() *big.Int { return h.data.ShardID } // Period the collation corresponds to. -func (h *CollationHeader) Period() *big.Int { return h.period } +func (h *CollationHeader) Period() *big.Int { return h.data.Period } // ChunkRoot of the serialized collation body. -func (h *CollationHeader) ChunkRoot() *common.Hash { return h.chunkRoot } +func (h *CollationHeader) ChunkRoot() *common.Hash { return h.data.ChunkRoot } + +// EncodeRLP gives an encoded representation of the collation header. +func (h *CollationHeader) EncodeRLP() ([]byte, error) { + encoded, err := rlp.EncodeToBytes(&h.data) + if err != nil { + return nil, err + } + return encoded, nil +} + +// DecodeRLP uses an RLP Stream to populate the data field of a collation header. +func (h *CollationHeader) DecodeRLP(s *rlp.Stream) error { + err := s.Decode(&h.data) + return err +} // Header returns the collation's header. func (c *Collation) Header() *CollationHeader { return c.header } @@ -52,13 +83,18 @@ func (c *Collation) Body() []byte { return c.body } func (c *Collation) Transactions() []*types.Transaction { return c.transactions } // ProposerAddress is the coinbase addr of the creator for the collation. -func (c *Collation) ProposerAddress() *common.Address { return c.header.proposerAddress } - -// SetHeader updates the collation's header. -func (c *Collation) SetHeader(h *CollationHeader) { c.header = h } +func (c *Collation) ProposerAddress() *common.Address { + return c.header.data.ProposerAddress +} // AddTransaction adds to the collation's body of tx blobs. func (c *Collation) AddTransaction(tx *types.Transaction) { // TODO: Include blob serialization instead. c.transactions = append(c.transactions, tx) } + +// SetChunkRoot updates the collation header's chunk root. +func (c *Collation) SetChunkRoot() { + chunkRoot := common.BytesToHash(c.body) + c.header.data.ChunkRoot = &chunkRoot +} diff --git a/sharding/shard.go b/sharding/shard.go index 22fcb5148..0cde3b2f6 100644 --- a/sharding/shard.go +++ b/sharding/shard.go @@ -1,8 +1,8 @@ package sharding import ( + "bytes" "fmt" - "log" "math/big" "github.com/ethereum/go-ethereum/common" @@ -49,12 +49,14 @@ func (s *Shard) HeaderByHash(hash *common.Hash) (*CollationHeader, error) { if err != nil { return nil, fmt.Errorf("header not found: %v", err) } - log.Printf("encoded header in func: %v", encoded) + var header CollationHeader - if err := rlp.DecodeBytes(encoded, &header); err != nil { - return nil, fmt.Errorf("could not decode header: %v", err) + + stream := rlp.NewStream(bytes.NewReader(encoded), uint64(len(encoded))) + if err := header.DecodeRLP(stream); err != nil { + return nil, fmt.Errorf("could not decode RLP header: %v", err) } - log.Printf("decoded header in func: %v", header) + return &header, nil } @@ -64,9 +66,7 @@ func (s *Shard) CollationByHash(headerHash *common.Hash) (*Collation, error) { if err != nil { return nil, err } - if header.ChunkRoot() == nil { - return nil, fmt.Errorf("invalid header fetched: %v", header) - } + body, err := s.BodyByChunkRoot(header.ChunkRoot()) if err != nil { return nil, err @@ -102,7 +102,6 @@ func (s *Shard) CanonicalCollation(shardID *big.Int, period *big.Int) (*Collatio // BodyByChunkRoot fetches a collation body. func (s *Shard) BodyByChunkRoot(chunkRoot *common.Hash) ([]byte, error) { - log.Printf("Chunk Root: %v", chunkRoot) body, err := s.shardDB.Get(*chunkRoot) if err != nil { return nil, fmt.Errorf("no corresponding body with chunk root found: %v", err) @@ -148,10 +147,11 @@ func (s *Shard) SetAvailability(chunkRoot *common.Hash, availability bool) error // SaveHeader adds the collation header to shardDB. func (s *Shard) SaveHeader(header *CollationHeader) error { - encoded, err := rlp.EncodeToBytes(header) + encoded, err := header.EncodeRLP() if err != nil { return fmt.Errorf("cannot encode header: %v", err) } + // Uses the hash of the header as the key. hash := header.Hash() s.shardDB.Put(hash, encoded) diff --git a/sharding/shard_test.go b/sharding/shard_test.go index 98585fd78..ee7cd6103 100644 --- a/sharding/shard_test.go +++ b/sharding/shard_test.go @@ -18,7 +18,9 @@ func (c *Collation) Hash() (hash common.Hash) { return hash } func TestShard_ValidateShardID(t *testing.T) { - header := &CollationHeader{shardID: big.NewInt(4)} + emptyHash := common.StringToHash("") + emptyAddr := common.StringToAddress("") + header := NewCollationHeader(big.NewInt(1), &emptyHash, big.NewInt(1), &emptyAddr, []byte{}) shardDB := makeShardKV() shard := MakeShard(big.NewInt(3), shardDB) @@ -26,7 +28,7 @@ func TestShard_ValidateShardID(t *testing.T) { t.Errorf("ShardID validation incorrect. Function should throw error when ShardID's do not match. want=%d. got=%d", header.ShardID().Int64(), shard.ShardID().Int64()) } - header2 := &CollationHeader{shardID: big.NewInt(100)} + header2 := NewCollationHeader(big.NewInt(100), &emptyHash, big.NewInt(1), &emptyAddr, []byte{}) shard2 := MakeShard(big.NewInt(100), shardDB) if err := shard2.ValidateShardID(header2); err != nil { @@ -35,8 +37,9 @@ func TestShard_ValidateShardID(t *testing.T) { } func TestShard_HeaderByHash(t *testing.T) { - root := common.StringToHash("hi") - header := &CollationHeader{shardID: big.NewInt(1), chunkRoot: &root} + emptyHash := common.StringToHash("") + emptyAddr := common.StringToAddress("") + header := NewCollationHeader(big.NewInt(1), &emptyHash, big.NewInt(1), &emptyAddr, []byte{}) shardDB := makeShardKV() shard := MakeShard(big.NewInt(1), shardDB) @@ -49,8 +52,6 @@ func TestShard_HeaderByHash(t *testing.T) { if err != nil { t.Fatalf("could not fetch collation header by hash: %v", err) } - log.Printf("header in first test: %v", header.ChunkRoot().String()) - log.Printf("db header in first test: %v", dbHeader.ChunkRoot().String()) // Compare the hashes. if header.Hash() != dbHeader.Hash() { t.Errorf("headers do not match. want=%v. got=%v", header, dbHeader) @@ -58,10 +59,19 @@ func TestShard_HeaderByHash(t *testing.T) { } func TestShard_CollationByHash(t *testing.T) { + emptyAddr := common.StringToAddress("") + + // Empty chunk root. + header := NewCollationHeader(big.NewInt(1), nil, big.NewInt(1), &emptyAddr, []byte{}) + collation := &Collation{ - header: &CollationHeader{shardID: big.NewInt(1)}, + header: header, body: []byte{1, 2, 3}, } + + // We set the chunk root. + collation.SetChunkRoot() + shardDB := makeShardKV() shard := MakeShard(big.NewInt(1), shardDB) @@ -69,7 +79,7 @@ func TestShard_CollationByHash(t *testing.T) { t.Fatalf("cannot save collation: %v", err) } hash := collation.Header().Hash() - + log.Printf("header hash: %v", hash) dbCollation, err := shard.CollationByHash(&hash) if err != nil { t.Fatalf("could not fetch collation by hash: %v", err)