diff --git a/sharding/database/inmemory.go b/sharding/database/inmemory.go index 02150e91c..2f1920f61 100644 --- a/sharding/database/inmemory.go +++ b/sharding/database/inmemory.go @@ -11,16 +11,16 @@ import ( // ShardKV is an in-memory mapping of hashes to RLP encoded values. type ShardKV struct { - kv map[common.Hash][]byte + kv map[common.Hash]*[]byte } // NewShardKV initializes a keyval store in memory. func NewShardKV() *ShardKV { - return &ShardKV{kv: make(map[common.Hash][]byte)} + return &ShardKV{kv: make(map[common.Hash]*[]byte)} } // Get fetches a val from the mappping by key. -func (sb *ShardKV) Get(k common.Hash) ([]byte, error) { +func (sb *ShardKV) Get(k common.Hash) (*[]byte, error) { v, ok := sb.kv[k] if !ok { return nil, fmt.Errorf("key not found: %v", k) @@ -37,7 +37,7 @@ func (sb *ShardKV) Has(k common.Hash) bool { // Put updates a key's value in the mapping. func (sb *ShardKV) Put(k common.Hash, v []byte) error { // there is no error in a simple setting of a value in a go map. - sb.kv[k] = v + sb.kv[k] = &v return nil } diff --git a/sharding/database/inmemory_test.go b/sharding/database/inmemory_test.go index d61c195f2..67c1d11d1 100644 --- a/sharding/database/inmemory_test.go +++ b/sharding/database/inmemory_test.go @@ -8,7 +8,7 @@ import ( ) // Verifies that ShardKV implements the ShardBackend interface. -var _ = sharding.ShardBackend(&shardKV{}) +var _ = sharding.ShardBackend(&ShardKV{}) func Test_ShardKVPut(t *testing.T) { kv := NewShardKV() @@ -55,7 +55,7 @@ func Test_ShardKVGet(t *testing.T) { hash2 := common.StringToHash("") val2, err := kv.Get(hash2) - if err == nil { + if val2 != nil { t.Errorf("non-existent key should not have a value. key=%v, value=%v", hash2, val2) } } diff --git a/sharding/shard.go b/sharding/shard.go index eeedc088f..1d62af965 100644 --- a/sharding/shard.go +++ b/sharding/shard.go @@ -12,7 +12,7 @@ import ( // ShardBackend defines an interface for a shardDB's necessary method // signatures. type ShardBackend interface { - Get(k common.Hash) ([]byte, error) + Get(k common.Hash) (*[]byte, error) Has(k common.Hash) bool Put(k common.Hash, val []byte) error Delete(k common.Hash) error @@ -49,12 +49,15 @@ func (s *Shard) ValidateShardID(h *CollationHeader) error { func (s *Shard) HeaderByHash(hash *common.Hash) (*CollationHeader, error) { encoded, err := s.shardDB.Get(*hash) if err != nil { - return nil, fmt.Errorf("header not found: %v", err) + return nil, fmt.Errorf("get failed: %v", err) + } + if encoded == nil { + return nil, fmt.Errorf("no value set for header hash: %v", hash.String()) } var header CollationHeader - stream := rlp.NewStream(bytes.NewReader(encoded), uint64(len(encoded))) + stream := rlp.NewStream(bytes.NewReader(*encoded), uint64(len(*encoded))) if err := header.DecodeRLP(stream); err != nil { return nil, fmt.Errorf("could not decode RLP header: %v", err) } @@ -87,13 +90,16 @@ func (s *Shard) CanonicalHeaderHash(shardID *big.Int, period *big.Int) (*common. // fetches the RLP encoded collation header corresponding to the key. encoded, err := s.shardDB.Get(key) if err != nil { + return nil, err + } + if encoded == nil { return nil, fmt.Errorf("no canonical collation header set for period=%v, shardID=%v pair: %v", shardID, period, err) } // RLP decodes the header, computes its hash. var header CollationHeader - stream := rlp.NewStream(bytes.NewReader(encoded), uint64(len(encoded))) + stream := rlp.NewStream(bytes.NewReader(*encoded), uint64(len(*encoded))) if err := header.DecodeRLP(stream); err != nil { return nil, fmt.Errorf("could not decode RLP header: %v", err) } @@ -120,18 +126,25 @@ func (s *Shard) CanonicalCollation(shardID *big.Int, period *big.Int) (*Collatio func (s *Shard) BodyByChunkRoot(chunkRoot *common.Hash) ([]byte, error) { body, err := s.shardDB.Get(*chunkRoot) if err != nil { - return nil, fmt.Errorf("no corresponding body with chunk root found: %v", err) + return nil, err } - return body, nil + if body == nil { + return nil, fmt.Errorf("no corresponding body with chunk root found: %v", chunkRoot.String()) + } + return *body, nil } // CheckAvailability is used by notaries to confirm a header's data availability. func (s *Shard) CheckAvailability(header *CollationHeader) (bool, error) { key := dataAvailabilityLookupKey(header.ChunkRoot()) - availability, err := s.shardDB.Get(key) + val, err := s.shardDB.Get(key) if err != nil { - return false, fmt.Errorf("key not found: %v", key) + return false, err } + if val == nil { + return false, fmt.Errorf("availability not set for header") + } + availability := *val // availability is a byte array of length 1. return availability[0] != 0, nil } @@ -146,7 +159,7 @@ func (s *Shard) SetAvailability(chunkRoot *common.Hash, availability bool) error encoded = []byte{0} } if err := s.shardDB.Put(key, encoded); err != nil { - return fmt.Errorf("cannot update shardDB: %v", err) + return err } return nil } @@ -179,7 +192,7 @@ func (s *Shard) SaveBody(body []byte) error { chunkRoot := common.BytesToHash(body) if err := s.shardDB.Put(chunkRoot, body); err != nil { - return fmt.Errorf("cannot update shardDB: %v", err) + return err } s.SetAvailability(&chunkRoot, true) return nil @@ -227,7 +240,7 @@ func (s *Shard) SetCanonical(header *CollationHeader) error { // sets the key to be the canonical collation lookup key and val as RLP encoded // collation header. if err := s.shardDB.Put(key, encoded); err != nil { - return fmt.Errorf("cannot update shardDB: %v", err) + return err } return nil } diff --git a/sharding/shard_test.go b/sharding/shard_test.go index 7d9b68b4b..fc2e73c24 100644 --- a/sharding/shard_test.go +++ b/sharding/shard_test.go @@ -13,10 +13,10 @@ import ( ) type mockShardDB struct { - kv map[common.Hash][]byte + kv map[common.Hash]*[]byte } -func (m *mockShardDB) Get(k common.Hash) ([]byte, error) { +func (m *mockShardDB) Get(k common.Hash) (*[]byte, error) { return nil, nil } @@ -64,7 +64,7 @@ func TestShard_HeaderByHash(t *testing.T) { header := NewCollationHeader(big.NewInt(1), &emptyHash, big.NewInt(1), &emptyAddr, []byte{}) // creates a mockDB that always returns nil values from .Get and errors in every other method. - mockDB := &mockShardDB{kv: make(map[common.Hash][]byte)} + mockDB := &mockShardDB{kv: make(map[common.Hash]*[]byte)} // creates a well-functioning shardDB. shardDB := database.NewShardKV() @@ -128,6 +128,16 @@ func TestShard_CollationByHash(t *testing.T) { t.Errorf("should not be able to fetch collation before saving first") } + // should not be able to fetch collation if the header has been saved but body has not. + if err := shard.SaveHeader(header); err != nil { + t.Fatalf("could not save header: %v", err) + } + + if _, err := shard.CollationByHash(&hash); err == nil { + t.Errorf("should not be able to fetch collation if body has not been saved") + } + + // properly saves the collation. if err := shard.SaveCollation(collation); err != nil { t.Fatalf("cannot save collation: %v", err) } @@ -277,12 +287,21 @@ func TestShard_BodyByChunkRoot(t *testing.T) { // it should throw error if fetching non-existent chunk root. emptyHash := common.StringToHash("") if _, err := shard.BodyByChunkRoot(&emptyHash); err == nil { - t.Errorf("non-existent chunk root should throw error") + t.Errorf("non-existent chunk root should throw error: %v", err) } if !bytes.Equal(body, dbBody) { t.Errorf("bodies not equal. want=%v. got=%v", body, dbBody) } + + // setting the val of the key to nil. + if err := shard.shardDB.Put(emptyHash, nil); err != nil { + t.Fatalf("could not update shardDB: %v", err) + } + if _, err := shard.BodyByChunkRoot(&emptyHash); err != nil { + t.Errorf("value set as nil in shardDB should return error from BodyByChunkRoot") + } + } func TestShard_CheckAvailability(t *testing.T) { @@ -327,7 +346,7 @@ func TestShard_SetAvailability(t *testing.T) { header := NewCollationHeader(big.NewInt(1), &chunkRoot, big.NewInt(1), nil, []byte{}) // creates a mockDB that always returns nil values from .Get and errors in every other method. - mockDB := &mockShardDB{kv: make(map[common.Hash][]byte)} + mockDB := &mockShardDB{kv: make(map[common.Hash]*[]byte)} // creates a well-functioning shardDB. shardDB := database.NewShardKV() @@ -382,7 +401,7 @@ func TestShard_SaveCollation(t *testing.T) { func TestShard_SaveHeader(t *testing.T) { // creates a mockDB that always returns nil values from .Get and errors in every other method. - mockDB := &mockShardDB{kv: make(map[common.Hash][]byte)} + mockDB := &mockShardDB{kv: make(map[common.Hash]*[]byte)} emptyHash := common.StringToHash("") errorShard := NewShard(big.NewInt(1), mockDB) @@ -394,7 +413,7 @@ func TestShard_SaveHeader(t *testing.T) { func TestShard_SaveBody(t *testing.T) { // creates a mockDB that always returns nil values from .Get and errors in every other method. - mockDB := &mockShardDB{kv: make(map[common.Hash][]byte)} + mockDB := &mockShardDB{kv: make(map[common.Hash]*[]byte)} errorShard := NewShard(big.NewInt(1), mockDB) if err := errorShard.SaveBody([]byte{1, 2, 3}); err == nil {