From e43d27f3e4ad54f43e890d84b1f0b2ea9fb2e36f Mon Sep 17 00:00:00 2001 From: Paul Hauner Date: Fri, 24 May 2019 13:56:17 +1000 Subject: [PATCH] Add extra level of comparisons to `CompareFields` --- eth2/utils/compare_fields/src/lib.rs | 170 +++++++++++++++++++- eth2/utils/compare_fields/tests/tests.rs | 46 ------ eth2/utils/compare_fields_derive/src/lib.rs | 57 ++++--- 3 files changed, 207 insertions(+), 66 deletions(-) delete mode 100644 eth2/utils/compare_fields/tests/tests.rs diff --git a/eth2/utils/compare_fields/src/lib.rs b/eth2/utils/compare_fields/src/lib.rs index 75f20b3c5..a0166eb50 100644 --- a/eth2/utils/compare_fields/src/lib.rs +++ b/eth2/utils/compare_fields/src/lib.rs @@ -1,3 +1,152 @@ +//! Provides field-by-field comparisons for structs and vecs. +//! +//! Returns comparisons as data, without making assumptions about the desired equality (e.g., +//! does not `panic!` on inequality). +//! +//! Note: `compare_fields_derive` requires `PartialEq` and `Debug` implementations. +//! +//! ## Example +//! +//! ```rust +//! use compare_fields::{CompareFields, Comparison, FieldComparison}; +//! use compare_fields_derive::CompareFields; +//! +//! #[derive(PartialEq, Debug, CompareFields)] +//! pub struct Bar { +//! a: u64, +//! b: u16, +//! #[compare_fields(as_slice)] +//! c: Vec +//! } +//! +//! #[derive(Clone, PartialEq, Debug, CompareFields)] +//! pub struct Foo { +//! d: String +//! } +//! +//! let cat = Foo {d: "cat".to_string()}; +//! let dog = Foo {d: "dog".to_string()}; +//! let chicken = Foo {d: "chicken".to_string()}; +//! +//! let mut bar_a = Bar { +//! a: 42, +//! b: 12, +//! c: vec![ cat.clone(), dog.clone() ], +//! }; +//! +//! let mut bar_b = Bar { +//! a: 42, +//! b: 99, +//! c: vec![ chicken.clone(), dog.clone()] +//! }; +//! +//! let cat_dog = Comparison::Child(FieldComparison { +//! field_name: "d".to_string(), +//! equal: false, +//! a: "\"cat\"".to_string(), +//! b: "\"dog\"".to_string(), +//! }); +//! assert_eq!(cat.compare_fields(&dog), vec![cat_dog]); +//! +//! let bar_a_b = vec![ +//! Comparison::Child(FieldComparison { +//! field_name: "a".to_string(), +//! equal: true, +//! a: "42".to_string(), +//! b: "42".to_string(), +//! }), +//! Comparison::Child(FieldComparison { +//! field_name: "b".to_string(), +//! equal: false, +//! a: "12".to_string(), +//! b: "99".to_string(), +//! }), +//! Comparison::Parent{ +//! field_name: "c".to_string(), +//! equal: false, +//! children: vec![ +//! FieldComparison { +//! field_name: "0".to_string(), +//! equal: false, +//! a: "Some(Foo { d: \"cat\" })".to_string(), +//! b: "Some(Foo { d: \"chicken\" })".to_string(), +//! }, +//! FieldComparison { +//! field_name: "1".to_string(), +//! equal: true, +//! a: "Some(Foo { d: \"dog\" })".to_string(), +//! b: "Some(Foo { d: \"dog\" })".to_string(), +//! } +//! ] +//! } +//! ]; +//! assert_eq!(bar_a.compare_fields(&bar_b), bar_a_b); +//! +//! +//! +//! // TODO: +//! ``` +use std::fmt::Debug; + +#[derive(Debug, PartialEq, Clone)] +pub enum Comparison { + Child(FieldComparison), + Parent { + field_name: String, + equal: bool, + children: Vec, + }, +} + +impl Comparison { + pub fn child>(field_name: String, a: &T, b: &T) -> Self { + Comparison::Child(FieldComparison::new(field_name, a, b)) + } + + pub fn parent(field_name: String, equal: bool, children: Vec) -> Self { + Comparison::Parent { + field_name, + equal, + children, + } + } + + pub fn from_slice>(field_name: String, a: &[T], b: &[T]) -> Self { + let mut children = vec![]; + + for i in 0..std::cmp::max(a.len(), b.len()) { + children.push(FieldComparison::new( + format!("{:}", i), + &a.get(i), + &b.get(i), + )); + } + + Self::parent(field_name, a == b, children) + } + + pub fn retain_children(&mut self, f: F) + where + F: FnMut(&FieldComparison) -> bool, + { + match self { + Comparison::Child(_) => (), + Comparison::Parent { children, .. } => children.retain(f), + } + } + + pub fn equal(&self) -> bool { + match self { + Comparison::Child(fc) => fc.equal, + Comparison::Parent { equal, .. } => *equal, + } + } + + pub fn not_equal(&self) -> bool { + !self.equal() + } +} + #[derive(Debug, PartialEq, Clone)] pub struct FieldComparison { pub field_name: String, @@ -7,5 +156,24 @@ pub struct FieldComparison { } pub trait CompareFields { - fn compare_fields(&self, b: &Self) -> Vec; + fn compare_fields(&self, b: &Self) -> Vec; +} + +impl FieldComparison { + pub fn new>(field_name: String, a: &T, b: &T) -> Self { + Self { + field_name, + equal: a == b, + a: format!("{:?}", a), + b: format!("{:?}", b), + } + } + + pub fn equal(&self) -> bool { + self.equal + } + + pub fn not_equal(&self) -> bool { + !self.equal() + } } diff --git a/eth2/utils/compare_fields/tests/tests.rs b/eth2/utils/compare_fields/tests/tests.rs deleted file mode 100644 index 96ea94810..000000000 --- a/eth2/utils/compare_fields/tests/tests.rs +++ /dev/null @@ -1,46 +0,0 @@ -use compare_fields::{CompareFields, FieldComparison}; -use compare_fields_derive::CompareFields; - -#[derive(Clone, Debug, CompareFields)] -pub struct Simple { - a: u64, - b: u16, - c: Vec, -} - -#[test] -fn compare() { - let foo = Simple { - a: 42, - b: 12, - c: vec![1, 2], - }; - - let mut bar = foo.clone(); - - let comparisons = foo.compare_fields(&bar); - - assert!(!comparisons.iter().any(|c| c.equal == false)); - - assert_eq!( - comparisons[0], - FieldComparison { - equal: true, - field_name: "a".to_string(), - a: "42".to_string(), - b: "42".to_string(), - } - ); - - bar.a = 30; - - assert_eq!( - foo.compare_fields(&bar)[0], - FieldComparison { - equal: false, - field_name: "a".to_string(), - a: "42".to_string(), - b: "30".to_string(), - } - ); -} diff --git a/eth2/utils/compare_fields_derive/src/lib.rs b/eth2/utils/compare_fields_derive/src/lib.rs index 89c61796c..c4ca3d64c 100644 --- a/eth2/utils/compare_fields_derive/src/lib.rs +++ b/eth2/utils/compare_fields_derive/src/lib.rs @@ -5,7 +5,16 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, DeriveInput}; -#[proc_macro_derive(CompareFields)] +fn is_slice(field: &syn::Field) -> bool { + for attr in &field.attrs { + if attr.tts.to_string() == "( as_slice )" { + return true; + } + } + false +} + +#[proc_macro_derive(CompareFields, attributes(compare_fields))] pub fn compare_fields_derive(input: TokenStream) -> TokenStream { let item = parse_macro_input!(input as DeriveInput); @@ -17,37 +26,47 @@ pub fn compare_fields_derive(input: TokenStream) -> TokenStream { _ => panic!("compare_fields_derive only supports structs."), }; - let mut idents_a = vec![]; - let mut field_names = vec![]; + let mut quotes = vec![]; for field in struct_data.fields.iter() { - let ident = match &field.ident { + let ident_a = match &field.ident { Some(ref ident) => ident, _ => panic!("compare_fields_derive only supports named struct fields."), }; - field_names.push(format!("{:}", ident)); - idents_a.push(ident); - } + let field_name = format!("{:}", ident_a); + let ident_b = ident_a.clone(); - let idents_b = idents_a.clone(); - let idents_c = idents_a.clone(); - let idents_d = idents_a.clone(); + let quote = if is_slice(field) { + quote! { + comparisons.push(compare_fields::Comparison::from_slice( + #field_name.to_string(), + &self.#ident_a, + &b.#ident_b) + ); + } + } else { + quote! { + comparisons.push( + compare_fields::Comparison::child( + #field_name.to_string(), + &self.#ident_a, + &b.#ident_b + ) + ); + } + }; + + quotes.push(quote); + } let output = quote! { impl #impl_generics compare_fields::CompareFields for #name #ty_generics #where_clause { - fn compare_fields(&self, b: &Self) -> Vec { + fn compare_fields(&self, b: &Self) -> Vec { let mut comparisons = vec![]; #( - comparisons.push( - compare_fields::FieldComparison { - equal: self.#idents_a == b.#idents_b, - field_name: #field_names.to_string(), - a: format!("{:?}", self.#idents_c), - b: format!("{:?}", b.#idents_d), - } - ); + #quotes )* comparisons