Use int_to_bytes in swap or not.

I also fixed an error I found through strict typing.
This commit is contained in:
Paul Hauner 2019-02-15 12:13:57 +11:00
parent 73484f04a1
commit fc04286ae6
No known key found for this signature in database
GPG Key ID: D362883A9218FCC6
2 changed files with 28 additions and 23 deletions

View File

@ -7,6 +7,7 @@ edition = "2018"
[dependencies] [dependencies]
bytes = "0.4" bytes = "0.4"
hashing = { path = "../hashing" } hashing = { path = "../hashing" }
int_to_bytes = { path = "../int_to_bytes" }
[dev-dependencies] [dev-dependencies]
yaml-rust = "0.4.2" yaml-rust = "0.4.2"

View File

@ -1,5 +1,6 @@
use bytes::{Buf, BufMut, BytesMut}; use bytes::Buf;
use hashing::hash; use hashing::hash;
use int_to_bytes::{int_to_bytes1, int_to_bytes4};
use std::cmp::max; use std::cmp::max;
use std::io::Cursor; use std::io::Cursor;
@ -12,14 +13,19 @@ use std::io::Cursor;
/// Returns `None` under any of the following conditions: /// Returns `None` under any of the following conditions:
/// - `list_size == 0` /// - `list_size == 0`
/// - `index >= list_size` /// - `index >= list_size`
/// - `list_size >= 2**24`
/// - `list_size >= usize::max_value() / 2` /// - `list_size >= usize::max_value() / 2`
pub fn get_permutated_index( pub fn get_permutated_index(
index: usize, index: usize,
list_size: usize, list_size: usize,
seed: &[u8], seed: &[u8],
shuffle_round_count: usize, shuffle_round_count: u8,
) -> Option<usize> { ) -> Option<usize> {
if list_size == 0 || index >= list_size || list_size >= usize::max_value() / 2 { if list_size == 0
|| index >= list_size
|| list_size >= usize::max_value() / 2
|| list_size >= 2_usize.pow(24)
{
return None; return None;
} }
@ -28,7 +34,7 @@ pub fn get_permutated_index(
let pivot = bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size; let pivot = bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size;
let flip = (pivot + list_size - index) % list_size; let flip = (pivot + list_size - index) % list_size;
let position = max(index, flip); let position = max(index, flip);
let source = hash_with_round_and_position(seed, round, position); let source = hash_with_round_and_position(seed, round, position)?;
let byte = source[(position % 256) / 8]; let byte = source[(position % 256) / 8];
let bit = (byte >> (position % 8)) % 2; let bit = (byte >> (position % 8)) % 2;
index = if bit == 1 { flip } else { index } index = if bit == 1 { flip } else { index }
@ -36,31 +42,23 @@ pub fn get_permutated_index(
Some(index) Some(index)
} }
fn hash_with_round_and_position(seed: &[u8], round: usize, position: usize) -> Vec<u8> { fn hash_with_round_and_position(seed: &[u8], round: u8, position: usize) -> Option<Vec<u8>> {
let mut seed = seed.to_vec(); let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round as u64)); seed.append(&mut int_to_bytes1(round));
seed.append(&mut int_to_bytes4(position as u64 / 256)); /*
hash(&seed[..]) * Note: the specification has an implicit assertion in `int_to_bytes4` that `position / 256 <
* 2**24`. For efficiency, we do not check for that here as it is checked in `get_permutated_index`.
*/
seed.append(&mut int_to_bytes4((position / 256) as u32));
Some(hash(&seed[..]))
} }
fn hash_with_round(seed: &[u8], round: usize) -> Vec<u8> { fn hash_with_round(seed: &[u8], round: u8) -> Vec<u8> {
let mut seed = seed.to_vec(); let mut seed = seed.to_vec();
seed.append(&mut int_to_bytes1(round as u64)); seed.append(&mut int_to_bytes1(round));
hash(&seed[..]) hash(&seed[..])
} }
fn int_to_bytes1(int: u64) -> Vec<u8> {
let mut bytes = BytesMut::with_capacity(8);
bytes.put_u64_le(int);
vec![bytes[0]]
}
fn int_to_bytes4(int: u64) -> Vec<u8> {
let mut bytes = BytesMut::with_capacity(8);
bytes.put_u64_le(int);
bytes[0..4].to_vec()
}
fn bytes_to_int64(bytes: &[u8]) -> u64 { fn bytes_to_int64(bytes: &[u8]) -> u64 {
let mut cursor = Cursor::new(bytes); let mut cursor = Cursor::new(bytes);
cursor.get_u64_le() cursor.get_u64_le()
@ -117,10 +115,16 @@ mod tests {
let index = test_case["index"].as_i64().unwrap() as usize; let index = test_case["index"].as_i64().unwrap() as usize;
let list_size = test_case["list_size"].as_i64().unwrap() as usize; let list_size = test_case["list_size"].as_i64().unwrap() as usize;
let permutated_index = test_case["permutated_index"].as_i64().unwrap() as usize; let permutated_index = test_case["permutated_index"].as_i64().unwrap() as usize;
let shuffle_round_count = test_case["shuffle_round_count"].as_i64().unwrap() as usize; let shuffle_round_count = test_case["shuffle_round_count"].as_i64().unwrap();
let seed_string = test_case["seed"].clone().into_string().unwrap(); let seed_string = test_case["seed"].clone().into_string().unwrap();
let seed = hex::decode(seed_string.replace("0x", "")).unwrap(); let seed = hex::decode(seed_string.replace("0x", "")).unwrap();
let shuffle_round_count = if shuffle_round_count < (u8::max_value() as i64) {
shuffle_round_count as u8
} else {
panic!("shuffle_round_count must be a u8")
};
assert_eq!( assert_eq!(
Some(permutated_index), Some(permutated_index),
get_permutated_index(index, list_size, &seed[..], shuffle_round_count), get_permutated_index(index, list_size, &seed[..], shuffle_round_count),