diff --git a/eth2deposit/key_handling/key_derivation/mnemonic.py b/eth2deposit/key_handling/key_derivation/mnemonic.py
index 7c8537a..d2ad84a 100644
--- a/eth2deposit/key_handling/key_derivation/mnemonic.py
+++ b/eth2deposit/key_handling/key_derivation/mnemonic.py
@@ -26,12 +26,35 @@ def _resource_path(relative_path: str) -> str:
 
 def _get_word_list(language: str, path: str) -> Sequence[str]:
     path = _resource_path(path)
-    return open(os.path.join(path, '%s.txt' % language), encoding='utf-8').readlines()
+    # Use a context manager so the word-list file handle is closed promptly.
+    with open(os.path.join(path, '%s.txt' % language), encoding='utf-8') as word_file:
+        return [word.replace('\n', '') for word in word_file]
 
 
-def _get_word(*, word_list: Sequence[str], index: int) -> str:
+def _index_to_word(word_list: Sequence[str], index: int) -> str:
+    """
+    Given the index of a word in the word list, return the corresponding word.
+    """
     assert index < 2048
-    return word_list[index][:-1]
+    return word_list[index]
+
+
+def _word_to_index(word_list: Sequence[str], word: str) -> int:
+    """
+    Given a word, return its index in the word list.
+    Raises `ValueError` if the word is not in the list.
+    """
+    try:
+        return word_list.index(word)
+    except ValueError:
+        raise ValueError('Word %s not in BIP39 word-list' % word)
+
+
+def _uint11_array_to_uint(unit_array: Sequence[int]) -> int:
+    """
+    Concatenate a sequence of 11-bit integers (most significant first) into a single integer.
+    """
+    return sum([x << i * 11 for i, x in enumerate(reversed(unit_array))])
 
 
 def get_seed(*, mnemonic: str, password: str) -> bytes:
@@ -55,6 +78,65 @@ def get_languages(path: str) -> Tuple[str, ...]:
     return languages
 
 
+def determine_mnemonic_language(mnemonic: str, words_path: str) -> Sequence[str]:
+    """
+    Given a `mnemonic` determine what language(s) it could be written in.
+
+    Word lists of different languages can share words, so every language
+    containing at least one of the mnemonic's words is returned as a candidate.
+    Raises `ValueError` if a word is not found in any language's word list.
+    """
+    languages = get_languages(words_path)
+    word_sets = {lang: set(_get_word_list(lang, words_path)) for lang in languages}
+    mnemonic_list = mnemonic.split(' ')
+    for word in mnemonic_list:
+        if not any(word in words for words in word_sets.values()):
+            raise ValueError('Word not found in mnemonic word lists for any language.')
+    # A word may appear in several languages' lists; keep every language that matches any word.
+    return sorted(lang for lang, words in word_sets.items() if not words.isdisjoint(mnemonic_list))
+
+
+def _get_checksum(entropy: bytes) -> int:
+    """
+    Determine the checksum bits for the given `entropy`: the top
+    `len(entropy) * 8 // 32` bits of `SHA256(entropy)`.
+    """
+    entropy_length = len(entropy) * 8
+    assert entropy_length in range(128, 257, 32)
+    checksum_length = (entropy_length // 32)
+    return int.from_bytes(SHA256(entropy), 'big') >> 256 - checksum_length
+
+
+def verify_mnemonic(mnemonic: str, words_path: str) -> bool:
+    """
+    Return whether `mnemonic` is made of known words and carries a valid checksum
+    in at least one of its candidate languages.
+    """
+    mnemonic_list = mnemonic.split(' ')
+    if len(mnemonic_list) not in range(12, 25, 3):
+        # Only 12/15/18/21/24-word mnemonics map onto 128..256 bits of entropy.
+        return False
+    try:
+        languages = determine_mnemonic_language(mnemonic, words_path)
+    except ValueError:
+        return False
+    for language in languages:
+        try:
+            word_list = _get_word_list(language, words_path)
+            word_indices = [_word_to_index(word_list, word) for word in mnemonic_list]
+        except ValueError:
+            # Not every word belongs to this language; try the next candidate.
+            continue
+        mnemonic_int = _uint11_array_to_uint(word_indices)
+        checksum_length = len(mnemonic_list) // 3
+        checksum = mnemonic_int & 2**checksum_length - 1
+        entropy = (mnemonic_int - checksum) >> checksum_length
+        entropy_bits = entropy.to_bytes(checksum_length * 4, 'big')
+        if _get_checksum(entropy_bits) == checksum:
+            return True
+    return False
+
+
 def get_mnemonic(*, language: str, words_path: str, entropy: Optional[bytes]=None) -> str:
     """
     Return a mnemonic string in a given `language` based on `entropy`.
@@ -64,7 +146,7 @@ def get_mnemonic(*, language: str, words_path: str, entropy: Optional[bytes]=Non
     entropy_length = len(entropy) * 8
     assert entropy_length in range(128, 257, 32)
     checksum_length = (entropy_length // 32)
-    checksum = int.from_bytes(SHA256(entropy), 'big') >> 256 - checksum_length
+    checksum = _get_checksum(entropy)
     entropy_bits = int.from_bytes(entropy, 'big') << checksum_length
     entropy_bits += checksum
     entropy_length += checksum_length
@@ -72,6 +154,6 @@ def get_mnemonic(*, language: str, words_path: str, entropy: Optional[bytes]=Non
     word_list = _get_word_list(language, words_path)
     for i in range(entropy_length // 11 - 1, -1, -1):
         index = (entropy_bits >> i * 11) & 2**11 - 1
-        word = _get_word(word_list=word_list, index=index)
+        word = _index_to_word(word_list, index)
         mnemonic.append(word)
     return ' '.join(mnemonic)
diff --git a/tests/test_key_handling/test_key_derivation/test_mnemonic.py b/tests/test_key_handling/test_key_derivation/test_mnemonic.py
index d374675..f2d58e5 100644
--- a/tests/test_key_handling/test_key_derivation/test_mnemonic.py
+++ b/tests/test_key_handling/test_key_derivation/test_mnemonic.py
@@ -8,6 +8,7 @@ from typing import (
 from eth2deposit.key_handling.key_derivation.mnemonic import (
     get_seed,
     get_mnemonic,
+    verify_mnemonic,
 )
 
 
@@ -30,3 +31,15 @@ def test_bip39(language: str, test: Sequence[str]) -> None:
 
     assert get_mnemonic(language=language, words_path=WORD_LISTS_PATH, entropy=test_entropy) == test_mnemonic
     assert get_seed(mnemonic=test_mnemonic, password='TREZOR') == test_seed
+
+
+@pytest.mark.parametrize(
+    'test_mnemonic,is_valid',
+    [
+        (test_mnemonic[1], True)
+        for language_test_vectors in test_vectors.values()
+        for test_mnemonic in language_test_vectors
+    ],
+)
+def test_verify_mnemonic(test_mnemonic: str, is_valid: bool) -> None:
+    assert verify_mnemonic(test_mnemonic, WORD_LISTS_PATH) == is_valid