diff --git a/eth2/utils/ssz/src/cached_tree_hash.rs b/eth2/utils/ssz/src/cached_tree_hash.rs index a85da8fd9..757bfa9f7 100644 --- a/eth2/utils/ssz/src/cached_tree_hash.rs +++ b/eth2/utils/ssz/src/cached_tree_hash.rs @@ -10,6 +10,7 @@ pub struct TreeHashCache<'a> { chunk_offset: usize, cache: &'a mut [u8], chunk_modified: &'a mut [bool], + hash_count: &'a mut usize, } impl<'a> TreeHashCache<'a> { @@ -17,17 +18,20 @@ impl<'a> TreeHashCache<'a> { vec![false; bytes.len() / BYTES_PER_CHUNK] } - pub fn from_mut_slice(bytes: &'a mut [u8], changes: &'a mut [bool]) -> Option { + pub fn from_mut_slice( + bytes: &'a mut [u8], + changes: &'a mut [bool], + hash_count: &'a mut usize, + ) -> Option { if bytes.len() % BYTES_PER_CHUNK > 0 { return None; } - let chunk_modified = vec![false; bytes.len() / BYTES_PER_CHUNK]; - Some(Self { chunk_offset: 0, cache: bytes, chunk_modified: changes, + hash_count, }) } @@ -36,12 +40,13 @@ impl<'a> TreeHashCache<'a> { } pub fn modify_current_chunk(&mut self, to: &[u8]) -> Option<()> { - self.modify_chunk(0, to) + self.modify_chunk(self.chunk_offset, to) } pub fn modify_chunk(&mut self, chunk: usize, to: &[u8]) -> Option<()> { let start = chunk * BYTES_PER_CHUNK; let end = start + BYTES_PER_CHUNK; + self.cache.get_mut(start..end)?.copy_from_slice(to); self.chunk_modified[chunk] = true; @@ -79,9 +84,10 @@ impl<'a> TreeHashCache<'a> { let modified_end = modified_start + leaves; Some(TreeHashCache { - chunk_offset: self.chunk_offset + internal, + chunk_offset: 0, cache: self.cache.get_mut(leaves_start..leaves_end)?, chunk_modified: self.chunk_modified.get_mut(modified_start..modified_end)?, + hash_count: self.hash_count, }) } @@ -111,7 +117,8 @@ impl CachedTreeHash for u64 { fn cached_hash_tree_root(&self, other: &Self, cache: &mut TreeHashCache) -> Option<()> { if self != other { - cache.modify_current_chunk(&merkleize(&int_to_bytes32(*self))); + *cache.hash_count += 1; + cache.modify_current_chunk(&merkleize(&int_to_bytes32(*self)))?; } cache.increment(); @@ -156,6 +163,7 @@ impl CachedTreeHash for Inner { for chunk in (0..internal_chunks).into_iter().rev() { if cache.children_modified(chunk)? { + *cache.hash_count += 1; cache.modify_chunk(chunk, &cache.hash_children(chunk)?)?; } } @@ -197,7 +205,7 @@ pub fn merkleize(values: &[u8]) -> Vec { mod tests { use super::*; - fn join(many: Vec<&[u8]>) -> Vec { + fn join(many: Vec>) -> Vec { let mut all = vec![]; for one in many { all.extend_from_slice(&mut one.clone()) @@ -205,8 +213,7 @@ mod tests { all } - #[test] - fn cached_hash_on_inner() { + fn generic_test(index: usize) { let inner = Inner { a: 1, b: 2, @@ -216,37 +223,69 @@ mod tests { let mut cache = inner.build_cache_bytes(); - let changed_inner = Inner { - a: 42, - ..inner.clone() + let changed_inner = match index { + 0 => Inner { + a: 42, + ..inner.clone() + }, + 1 => Inner { + b: 42, + ..inner.clone() + }, + 2 => Inner { + c: 42, + ..inner.clone() + }, + 3 => Inner { + d: 42, + ..inner.clone() + }, + _ => panic!("bad index"), }; let mut changes = TreeHashCache::build_changes_vec(&cache); - let mut cache_struct = TreeHashCache::from_mut_slice(&mut cache, &mut changes).unwrap(); + let mut hash_count = 0; + let mut cache_struct = + TreeHashCache::from_mut_slice(&mut cache, &mut changes, &mut hash_count).unwrap(); - changed_inner.cached_hash_tree_root(&inner, &mut cache_struct); + changed_inner + .cached_hash_tree_root(&inner, &mut cache_struct) + .unwrap(); + + assert_eq!(*cache_struct.hash_count, 3); let new_cache = cache_struct.into_slice(); - let data1 = &int_to_bytes32(42); - let data2 = &int_to_bytes32(2); - let data3 = &int_to_bytes32(3); - let data4 = &int_to_bytes32(4); + let data1 = int_to_bytes32(1); + let data2 = int_to_bytes32(2); + let data3 = int_to_bytes32(3); + let data4 = int_to_bytes32(4); - let data = join(vec![&data1, &data2, &data3, &data4]); - let expected = merkleize(&data); + let mut data = vec![data1, data2, data3, data4]; + + data[index] = int_to_bytes32(42); + + let expected = merkleize(&join(data)); assert_eq!(expected, new_cache); } #[test] - fn build_cache_matches_merkelize() { - let data1 = &int_to_bytes32(1); - let data2 = &int_to_bytes32(2); - let data3 = &int_to_bytes32(3); - let data4 = &int_to_bytes32(4); + fn cached_hash_on_inner() { + generic_test(0); + generic_test(1); + generic_test(2); + generic_test(3); + } - let data = join(vec![&data1, &data2, &data3, &data4]); + #[test] + fn build_cache_matches_merkelize() { + let data1 = int_to_bytes32(1); + let data2 = int_to_bytes32(2); + let data3 = int_to_bytes32(3); + let data4 = int_to_bytes32(4); + + let data = join(vec![data1, data2, data3, data4]); let expected = merkleize(&data); let inner = Inner { @@ -268,7 +307,12 @@ mod tests { let data3 = hash(&int_to_bytes32(3)); let data4 = hash(&int_to_bytes32(4)); - let data = join(vec![&data1, &data2, &data3, &data4]); + let data = join(vec![ + data1.clone(), + data2.clone(), + data3.clone(), + data4.clone(), + ]); let cache = merkleize(&data);