diff --git a/eth2deposit/key_handling/key_derivation/tree.py b/eth2deposit/key_handling/key_derivation/tree.py index 10582b6..e1359dc 100644 --- a/eth2deposit/key_handling/key_derivation/tree.py +++ b/eth2deposit/key_handling/key_derivation/tree.py @@ -38,10 +38,10 @@ def _parent_SK_to_lamport_PK(*, parent_SK: int, index: int) -> bytes: def _HKDF_mod_r(*, IKM: bytes, key_info: bytes=b'') -> int: - L = 48 + L = 48 # `ceil((3 * ceil(log2(r))) / 16)`, where `r` is the order of the BLS 12-381 curve okm = HKDF( salt=b'BLS-SIG-KEYGEN-SALT-', - IKM=IKM + b'\x00', + IKM=IKM + b'\x00', # add postfix `I2OSP(0, 1)` L=L, info=key_info + L.to_bytes(2, 'big'), ) diff --git a/tests/test_key_handling/test_key_derivation/test_path.py b/tests/test_key_handling/test_key_derivation/test_path.py index faaac4b..8443e9f 100644 --- a/tests/test_key_handling/test_key_derivation/test_path.py +++ b/tests/test_key_handling/test_key_derivation/test_path.py @@ -39,7 +39,6 @@ def test_parent_SK_to_lamport_PK() -> None: parent_SK = test_vector['master_SK'] index = test_vector['child_index'] lamport_PK = bytes.fromhex(test_vector['compressed_lamport_PK']) - print(_parent_SK_to_lamport_PK(parent_SK=parent_SK, index=index).hex()) assert lamport_PK == _parent_SK_to_lamport_PK(parent_SK=parent_SK, index=index) diff --git a/tests/test_key_handling/test_key_derivation/test_tree.py b/tests/test_key_handling/test_key_derivation/test_tree.py index 8e06540..9b3463f 100644 --- a/tests/test_key_handling/test_key_derivation/test_tree.py +++ b/tests/test_key_handling/test_key_derivation/test_tree.py @@ -16,20 +16,35 @@ with open(test_vector_filefolder, 'r') as f: test_vectors = json.load(f)['kdf_tests'] -def test_hkdf_mod_r() -> None: - for test in test_vectors: - seed = bytes.fromhex(test['seed']) - assert bls.KeyGen(seed) == _HKDF_mod_r(IKM=seed) +@pytest.mark.parametrize( + 'test', + test_vectors +) +def test_hkdf_mod_r(test) -> None: + seed = bytes.fromhex(test['seed']) + assert bls.KeyGen(seed) == _HKDF_mod_r(IKM=seed) -def test_derive_master_SK() -> None: - for test in test_vectors: - seed = bytes.fromhex(test['seed']) - master_SK = test['master_SK'] - assert derive_master_SK(seed=seed) == master_SK + +@pytest.mark.parametrize( + 'test', + test_vectors +) +def test_derive_master_SK(test) -> None: + seed = bytes.fromhex(test['seed']) + master_SK = test['master_SK'] + assert derive_master_SK(seed=seed) == master_SK -def test_derive_child_SK() -> None: +@pytest.mark.parametrize( + 'test', + test_vectors +) +def test_derive_child_SK(test) -> None: + parent_SK = test['master_SK'] + index = test['child_index'] + child_SK = test['child_SK'] + assert derive_child_SK(parent_SK=parent_SK, index=index) == child_SK for test in test_vectors: parent_SK = test['master_SK'] index = test['child_index']