From fc04286ae6c996fd5567dce2045a307fc32752bd Mon Sep 17 00:00:00 2001 From: Paul Hauner Date: Fri, 15 Feb 2019 12:13:57 +1100 Subject: [PATCH] Use `int_to_bytes` in swap or not. I also fixed an error I found through strict typing. --- eth2/utils/swap_or_not_shuffle/Cargo.toml | 1 + eth2/utils/swap_or_not_shuffle/src/lib.rs | 50 ++++++++++++----------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/eth2/utils/swap_or_not_shuffle/Cargo.toml b/eth2/utils/swap_or_not_shuffle/Cargo.toml index 1c898f7b1..3dc03da82 100644 --- a/eth2/utils/swap_or_not_shuffle/Cargo.toml +++ b/eth2/utils/swap_or_not_shuffle/Cargo.toml @@ -7,6 +7,7 @@ edition = "2018" [dependencies] bytes = "0.4" hashing = { path = "../hashing" } +int_to_bytes = { path = "../int_to_bytes" } [dev-dependencies] yaml-rust = "0.4.2" diff --git a/eth2/utils/swap_or_not_shuffle/src/lib.rs b/eth2/utils/swap_or_not_shuffle/src/lib.rs index 3566ac23a..22d657bb4 100644 --- a/eth2/utils/swap_or_not_shuffle/src/lib.rs +++ b/eth2/utils/swap_or_not_shuffle/src/lib.rs @@ -1,5 +1,6 @@ -use bytes::{Buf, BufMut, BytesMut}; +use bytes::Buf; use hashing::hash; +use int_to_bytes::{int_to_bytes1, int_to_bytes4}; use std::cmp::max; use std::io::Cursor; @@ -12,14 +13,19 @@ use std::io::Cursor; /// Returns `None` under any of the following conditions: /// - `list_size == 0` /// - `index >= list_size` +/// - `list_size >= 2**24` /// - `list_size >= usize::max_value() / 2` pub fn get_permutated_index( index: usize, list_size: usize, seed: &[u8], - shuffle_round_count: usize, + shuffle_round_count: u8, ) -> Option { - if list_size == 0 || index >= list_size || list_size >= usize::max_value() / 2 { + if list_size == 0 + || index >= list_size + || list_size >= usize::max_value() / 2 + || list_size >= 2_usize.pow(24) + { return None; } @@ -28,7 +34,7 @@ pub fn get_permutated_index( let pivot = bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size; let flip = (pivot + list_size - index) % list_size; let position = max(index, flip); - let source = hash_with_round_and_position(seed, round, position); + let source = hash_with_round_and_position(seed, round, position)?; let byte = source[(position % 256) / 8]; let bit = (byte >> (position % 8)) % 2; index = if bit == 1 { flip } else { index } @@ -36,31 +42,23 @@ pub fn get_permutated_index( Some(index) } -fn hash_with_round_and_position(seed: &[u8], round: usize, position: usize) -> Vec { +fn hash_with_round_and_position(seed: &[u8], round: u8, position: usize) -> Option> { let mut seed = seed.to_vec(); - seed.append(&mut int_to_bytes1(round as u64)); - seed.append(&mut int_to_bytes4(position as u64 / 256)); - hash(&seed[..]) + seed.append(&mut int_to_bytes1(round)); + /* + * Note: the specification has an implicit assertion in `int_to_bytes4` that `position / 256 < + * 2**24`. For efficiency, we do not check for that here as it is checked in `get_permutated_index`. + */ + seed.append(&mut int_to_bytes4((position / 256) as u32)); + Some(hash(&seed[..])) } -fn hash_with_round(seed: &[u8], round: usize) -> Vec { +fn hash_with_round(seed: &[u8], round: u8) -> Vec { let mut seed = seed.to_vec(); - seed.append(&mut int_to_bytes1(round as u64)); + seed.append(&mut int_to_bytes1(round)); hash(&seed[..]) } -fn int_to_bytes1(int: u64) -> Vec { - let mut bytes = BytesMut::with_capacity(8); - bytes.put_u64_le(int); - vec![bytes[0]] -} - -fn int_to_bytes4(int: u64) -> Vec { - let mut bytes = BytesMut::with_capacity(8); - bytes.put_u64_le(int); - bytes[0..4].to_vec() -} - fn bytes_to_int64(bytes: &[u8]) -> u64 { let mut cursor = Cursor::new(bytes); cursor.get_u64_le() @@ -117,10 +115,16 @@ mod tests { let index = test_case["index"].as_i64().unwrap() as usize; let list_size = test_case["list_size"].as_i64().unwrap() as usize; let permutated_index = test_case["permutated_index"].as_i64().unwrap() as usize; - let shuffle_round_count = test_case["shuffle_round_count"].as_i64().unwrap() as usize; + let shuffle_round_count = test_case["shuffle_round_count"].as_i64().unwrap(); let seed_string = test_case["seed"].clone().into_string().unwrap(); let seed = hex::decode(seed_string.replace("0x", "")).unwrap(); + let shuffle_round_count = if shuffle_round_count < (u8::max_value() as i64) { + shuffle_round_count as u8 + } else { + panic!("shuffle_round_count must be a u8") + }; + assert_eq!( Some(permutated_index), get_permutated_index(index, list_size, &seed[..], shuffle_round_count),