Apply suggestions from @hwwhww's code review

* Tests are of type none.
* Parameters passed in to tests instead of using global variables
* remove debugging print left in

Co-authored-by: Hsiao-Wei Wang <hwwang156@gmail.com>
This commit is contained in:
Carl Beekhuizen 2020-06-29 11:48:59 +02:00 committed by GitHub
parent 9d2008eb4e
commit 9557f3cabb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 13 deletions

View File

@ -38,10 +38,10 @@ def _parent_SK_to_lamport_PK(*, parent_SK: int, index: int) -> bytes:
def _HKDF_mod_r(*, IKM: bytes, key_info: bytes=b'') -> int:
L = 48
L = 48 # `ceil((3 * ceil(log2(r))) / 16)`, where `r` is the order of the BLS 12-381 curve
okm = HKDF(
salt=b'BLS-SIG-KEYGEN-SALT-',
IKM=IKM + b'\x00',
IKM=IKM + b'\x00', # add postfix `I2OSP(0, 1)`
L=L,
info=key_info + L.to_bytes(2, 'big'),
)

View File

@ -39,7 +39,6 @@ def test_parent_SK_to_lamport_PK() -> None:
parent_SK = test_vector['master_SK']
index = test_vector['child_index']
lamport_PK = bytes.fromhex(test_vector['compressed_lamport_PK'])
print(_parent_SK_to_lamport_PK(parent_SK=parent_SK, index=index).hex())
assert lamport_PK == _parent_SK_to_lamport_PK(parent_SK=parent_SK, index=index)

View File

@ -16,20 +16,35 @@ with open(test_vector_filefolder, 'r') as f:
test_vectors = json.load(f)['kdf_tests']
def test_hkdf_mod_r() -> None:
for test in test_vectors:
@pytest.mark.parametrize(
'test',
test_vectors
)
def test_hkdf_mod_r(test) -> None:
seed = bytes.fromhex(test['seed'])
assert bls.KeyGen(seed) == _HKDF_mod_r(IKM=seed)
def test_derive_master_SK() -> None:
for test in test_vectors:
@pytest.mark.parametrize(
'test',
test_vectors
)
def test_derive_master_SK(test) -> None:
seed = bytes.fromhex(test['seed'])
master_SK = test['master_SK']
assert derive_master_SK(seed=seed) == master_SK
def test_derive_child_SK() -> None:
@pytest.mark.parametrize(
'test',
test_vectors
)
def test_derive_child_SK(test) -> None:
parent_SK = test['master_SK']
index = test['child_index']
child_SK = test['child_SK']
assert derive_child_SK(parent_SK=parent_SK, index=index) == child_SK
for test in test_vectors:
parent_SK = test['master_SK']
index = test['child_index']