diff --git a/Lib/test/support/hashlib_helper.py b/Lib/test/support/hashlib_helper.py index 96be74e4105c18..c9962a0ae14862 100644 --- a/Lib/test/support/hashlib_helper.py +++ b/Lib/test/support/hashlib_helper.py @@ -9,15 +9,64 @@ from types import MappingProxyType -def try_import_module(module_name): - """Try to import a module and return None on failure.""" +def _parse_fullname(fullname, *, strict=False): + """Parse a fully-qualified name .. + + The ``module_name`` component contains one or more dots. + The ``member_name`` component does not contain any dot. + """ + if fullname is None: + assert not strict + return None, None + assert isinstance(fullname, str), fullname + assert fullname.count(".") >= 1, fullname + module_name, member_name = fullname.rsplit(".", maxsplit=1) + return module_name, member_name + + +def _import_module(module_name, *, strict=False): + """Import a module from its fully-qualified name. + + If *strict* is false, import failures are suppressed and None is returned. + """ + if module_name is None: + # To prevent a TypeError in importlib.import_module + if strict: + raise ImportError("no module to import") + return None try: return importlib.import_module(module_name) - except ImportError: + except ImportError as exc: + if strict: + raise exc return None -class HID(enum.StrEnum): +def _import_member(module_name, member_name, *, strict=False): + """Import a member from a module. + + If *strict* is false, import failures are suppressed and None is returned. + """ + if member_name is None: + if strict: + raise ImportError(f"no member to import from {module_name}") + return None + module = _import_module(module_name, strict=strict) + if strict: + return getattr(module, member_name) + return getattr(module, member_name, None) + + +class Implementation(enum.StrEnum): + # Indicate that the hash function is implemented by a built-in module. + builtin = enum.auto() + # Indicate that the hash function is implemented by OpenSSL. + openssl = enum.auto() + # Indicate that the hash function is provided through the public API. + hashlib = enum.auto() + + +class _HashId(enum.StrEnum): """Enumeration containing the canonical digest names. Those names should only be used by hashlib.new() or hmac.new(). @@ -57,62 +106,158 @@ def is_keyed(self): return self.startswith("blake2") -CANONICAL_DIGEST_NAMES = frozenset(map(str, HID.__members__)) +CANONICAL_DIGEST_NAMES = frozenset(map(str, _HashId.__members__)) NON_HMAC_DIGEST_NAMES = frozenset(( - HID.shake_128, HID.shake_256, - HID.blake2s, HID.blake2b, + _HashId.shake_128, _HashId.shake_256, + _HashId.blake2s, _HashId.blake2b, )) -class HashInfo: - """Dataclass storing explicit hash constructor names. +class _HashInfoItem: + """Interface for interacting with a named object. - - *builtin* is the fully-qualified name for the explicit HACL* - hash constructor function, e.g., "_md5.md5". - - - *openssl* is the name of the "_hashlib" module method for the explicit - OpenSSL hash constructor function, e.g., "openssl_md5". + The object is entirely described by its fully-qualified *fullname*. - - *hashlib* is the name of the "hashlib" module method for the explicit - hash constructor function, e.g., "md5". + *fullname* must be None or a string ".". + If *strict* is true, *fullname* cannot be None. """ - def __init__(self, builtin, openssl=None, hashlib=None): - assert isinstance(builtin, str), builtin - assert len(builtin.split(".")) == 2, builtin + def __init__(self, fullname=None, *, strict=False): + module_name, member_name = _parse_fullname(fullname, strict=strict) + self.fullname = fullname + self.module_name = module_name + self.member_name = member_name + + def import_module(self, *, strict=False): + """Import the described module. + + If *strict* is true, an ImportError may be raised if importing fails, + otherwise, None is returned on error. + """ + return _import_module(self.module_name, strict=strict) - self.builtin = builtin - self.builtin_module_name, self.builtin_method_name = ( - self.builtin.split(".", maxsplit=1) + def import_member(self, *, strict=False): + """Import the described member. + + If *strict* is true, an AttributeError or an ImportError may be + raised if importing fails; otherwise, None is returned on error. + """ + return _import_member( + self.module_name, self.member_name, strict=strict ) - assert openssl is None or openssl.startswith("openssl_") - self.openssl = self.openssl_method_name = openssl - self.openssl_module_name = "_hashlib" if openssl else None - assert hashlib is None or isinstance(hashlib, str) - self.hashlib = self.hashlib_method_name = hashlib - self.hashlib_module_name = "hashlib" if hashlib else None +class _HashInfoBase: + """Base dataclass containing "backend" information. + + Subclasses may define an attribute named after one of the known + implementations ("builtin", "openssl" or "hashlib") which stores + an _HashInfoItem object. + + Those attributes can be retrieved through __getitem__(), e.g., + ``info["builtin"]`` returns the _HashInfoItem corresponding to + the builtin implementation. + """ + + def __init__(self, canonical_name): + self.canonical_name = canonical_name + + def __getitem__(self, implementation): + try: + attrname = Implementation(implementation) + except ValueError: + raise self.invalid_implementation_error(implementation) from None + + try: + provider = getattr(self, attrname) + except AttributeError: + raise self.invalid_implementation_error(implementation) from None + + if not isinstance(provider, _HashInfoItem): + raise KeyError(implementation) + return provider + + def invalid_implementation_error(self, implementation): + msg = f"no implementation {implementation} for {self.canonical_name}" + return AssertionError(msg) + + +class _HashTypeInfo(_HashInfoBase): + """Dataclass containing information for hash functions types. + + - *builtin* is the fully-qualified name for the builtin HACL* type, + e.g., "_md5.MD5Type". + + - *openssl* is the fully-qualified name for the OpenSSL wrapper type, + e.g., "_hashlib.HASH". + """ + + def __init__(self, canonical_name, builtin, openssl): + super().__init__(canonical_name) + self.builtin = _HashInfoItem(builtin, strict=True) + self.openssl = _HashInfoItem(openssl, strict=True) + + def fullname(self, implementation): + """Get the fully qualified name of a given implementation. + + This returns a string of the form "MODULE_NAME.OBJECT_NAME" or None + if the hash function does not have a corresponding implementation. + + *implementation* must be "builtin" or "openssl". + """ + return self[implementation].fullname def module_name(self, implementation): - match implementation: - case "builtin": - return self.builtin_module_name - case "openssl": - return self.openssl_module_name - case "hashlib": - return self.hashlib_module_name - raise AssertionError(f"invalid implementation {implementation}") + """Get the name of the module containing the hash object type.""" + return self[implementation].module_name + + def object_type_name(self, implementation): + """Get the name of the hash object class name.""" + return self[implementation].member_name + + def import_module(self, implementation, *, allow_skip=False): + """Import the module containing the hash object type. + + On error, return None if *allow_skip* is false, or raise SkipNoHash. + """ + target = self[implementation] + module = target.import_module() + if allow_skip and module is None: + reason = f"cannot import module {target.module_name}" + raise SkipNoHash(self.canonical_name, implementation, reason) + return module + + def import_object_type(self, implementation, *, allow_skip=False): + """Get the runtime hash object type. + + On error, return None if *allow_skip* is false, or raise SkipNoHash. + """ + target = self[implementation] + member = target.import_member() + if allow_skip and member is None: + reason = f"cannot import class {target.fullname}" + raise SkipNoHash(self.canonical_name, implementation, reason) + return member - def method_name(self, implementation): - match implementation: - case "builtin": - return self.builtin_method_name - case "openssl": - return self.openssl_method_name - case "hashlib": - return self.hashlib_method_name - raise AssertionError(f"invalid implementation {implementation}") + +class _HashFuncInfo(_HashInfoBase): + """Dataclass containing information for hash functions constructors. + + - *builtin* is the fully-qualified name of the HACL* + hash constructor function, e.g., "_md5.md5". + + - *openssl* is the fully-qualified name of the "_hashlib" method + for the OpenSSL named constructor, e.g., "_hashlib.openssl_md5". + + - *hashlib* is the fully-qualified name of the "hashlib" method + for the explicit named hash constructor, e.g., "hashlib.md5". + """ + + def __init__(self, canonical_name, builtin, openssl=None, hashlib=None): + super().__init__(canonical_name) + self.builtin = _HashInfoItem(builtin, strict=True) + self.openssl = _HashInfoItem(openssl, strict=False) + self.hashlib = _HashInfoItem(hashlib, strict=False) def fullname(self, implementation): """Get the fully qualified name of a given implementation. @@ -122,63 +267,226 @@ def fullname(self, implementation): *implementation* must be "builtin", "openssl" or "hashlib". """ - module_name = self.module_name(implementation) - method_name = self.method_name(implementation) - if module_name is None or method_name is None: - return None - return f"{module_name}.{method_name}" - - -# Mapping from a "canonical" name to a pair (HACL*, _hashlib.*, hashlib.*) -# constructors. If the constructor name is None, then this means that the -# algorithm can only be used by the "agile" new() interfaces. -_EXPLICIT_CONSTRUCTORS = MappingProxyType({ # fmt: skip - HID.md5: HashInfo("_md5.md5", "openssl_md5", "md5"), - HID.sha1: HashInfo("_sha1.sha1", "openssl_sha1", "sha1"), - HID.sha224: HashInfo("_sha2.sha224", "openssl_sha224", "sha224"), - HID.sha256: HashInfo("_sha2.sha256", "openssl_sha256", "sha256"), - HID.sha384: HashInfo("_sha2.sha384", "openssl_sha384", "sha384"), - HID.sha512: HashInfo("_sha2.sha512", "openssl_sha512", "sha512"), - HID.sha3_224: HashInfo( - "_sha3.sha3_224", "openssl_sha3_224", "sha3_224" + return self[implementation].fullname + + def module_name(self, implementation): + """Get the name of the constructor function module. + + The *implementation* must be "builtin", "openssl" or "hashlib". + """ + return self[implementation].module_name + + def method_name(self, implementation): + """Get the name of the constructor function module method. + + Use fullname() to get the constructor function fully-qualified name. + + The *implementation* must be "builtin", "openssl" or "hashlib". + """ + return self[implementation].member_name + + +class _HashInfo: + """Dataclass containing information for supported hash functions. + + - *builtin_object_type_fullname* is the fully-qualified name + for the builtin HACL* type, e.g., "_md5.MD5Type". + + - *openssl_object_type_fullname* is the fully-qualified name + for the OpenSSL wrapper type, i.e. "_hashlib.HASH" or "_hashlib.HASHXOF". + + - *builtin_method_fullname* is the fully-qualified name + of the HACL* hash constructor function, e.g., "_md5.md5". + + - *openssl_method_fullname* is the fully-qualified name + of the "_hashlib" module method for the explicit OpenSSL + hash constructor function, e.g., "_hashlib.openssl_md5". + + - *hashlib_method_fullname* is the fully-qualified name + of the "hashlib" module method for the explicit hash + constructor function, e.g., "hashlib.md5". + """ + + def __init__( + self, + canonical_name, + builtin_object_type_fullname, + openssl_object_type_fullname, + builtin_method_fullname, + openssl_method_fullname=None, + hashlib_method_fullname=None, + ): + self.canonical_name = canonical_name + self.type = _HashTypeInfo( + canonical_name, + builtin_object_type_fullname, + openssl_object_type_fullname, + ) + self.func = _HashFuncInfo( + canonical_name, + builtin_method_fullname, + openssl_method_fullname, + hashlib_method_fullname, + ) + + +_HASHINFO_DATABASE = MappingProxyType({ + _HashId.md5: _HashInfo( + _HashId.md5, + "_md5.MD5Type", + "_hashlib.HASH", + "_md5.md5", + "_hashlib.openssl_md5", + "hashlib.md5", + ), + _HashId.sha1: _HashInfo( + _HashId.sha1, + "_sha1.SHA1Type", + "_hashlib.HASH", + "_sha1.sha1", + "_hashlib.openssl_sha1", + "hashlib.sha1", ), - HID.sha3_256: HashInfo( - "_sha3.sha3_256", "openssl_sha3_256", "sha3_256" + _HashId.sha224: _HashInfo( + _HashId.sha224, + "_sha2.SHA224Type", + "_hashlib.HASH", + "_sha2.sha224", + "_hashlib.openssl_sha224", + "hashlib.sha224", ), - HID.sha3_384: HashInfo( - "_sha3.sha3_384", "openssl_sha3_384", "sha3_384" + _HashId.sha256: _HashInfo( + _HashId.sha256, + "_sha2.SHA256Type", + "_hashlib.HASH", + "_sha2.sha256", + "_hashlib.openssl_sha256", + "hashlib.sha256", ), - HID.sha3_512: HashInfo( - "_sha3.sha3_512", "openssl_sha3_512", "sha3_512" + _HashId.sha384: _HashInfo( + _HashId.sha384, + "_sha2.SHA384Type", + "_hashlib.HASH", + "_sha2.sha384", + "_hashlib.openssl_sha384", + "hashlib.sha384", ), - HID.shake_128: HashInfo( - "_sha3.shake_128", "openssl_shake_128", "shake_128" + _HashId.sha512: _HashInfo( + _HashId.sha512, + "_sha2.SHA512Type", + "_hashlib.HASH", + "_sha2.sha512", + "_hashlib.openssl_sha512", + "hashlib.sha512", ), - HID.shake_256: HashInfo( - "_sha3.shake_256", "openssl_shake_256", "shake_256" + _HashId.sha3_224: _HashInfo( + _HashId.sha3_224, + "_sha3.sha3_224", + "_hashlib.HASH", + "_sha3.sha3_224", + "_hashlib.openssl_sha3_224", + "hashlib.sha3_224", + ), + _HashId.sha3_256: _HashInfo( + _HashId.sha3_256, + "_sha3.sha3_256", + "_hashlib.HASH", + "_sha3.sha3_256", + "_hashlib.openssl_sha3_256", + "hashlib.sha3_256", + ), + _HashId.sha3_384: _HashInfo( + _HashId.sha3_384, + "_sha3.sha3_384", + "_hashlib.HASH", + "_sha3.sha3_384", + "_hashlib.openssl_sha3_384", + "hashlib.sha3_384", + ), + _HashId.sha3_512: _HashInfo( + _HashId.sha3_512, + "_sha3.sha3_512", + "_hashlib.HASH", + "_sha3.sha3_512", + "_hashlib.openssl_sha3_512", + "hashlib.sha3_512", + ), + _HashId.shake_128: _HashInfo( + _HashId.shake_128, + "_sha3.shake_128", + "_hashlib.HASHXOF", + "_sha3.shake_128", + "_hashlib.openssl_shake_128", + "hashlib.shake_128", + ), + _HashId.shake_256: _HashInfo( + _HashId.shake_256, + "_sha3.shake_256", + "_hashlib.HASHXOF", + "_sha3.shake_256", + "_hashlib.openssl_shake_256", + "hashlib.shake_256", + ), + _HashId.blake2s: _HashInfo( + _HashId.blake2s, + "_blake2.blake2s", + "_hashlib.HASH", + "_blake2.blake2s", + None, + "hashlib.blake2s", + ), + _HashId.blake2b: _HashInfo( + _HashId.blake2b, + "_blake2.blake2b", + "_hashlib.HASH", + "_blake2.blake2b", + None, + "hashlib.blake2b", ), - HID.blake2s: HashInfo("_blake2.blake2s", None, "blake2s"), - HID.blake2b: HashInfo("_blake2.blake2b", None, "blake2b"), }) -assert _EXPLICIT_CONSTRUCTORS.keys() == CANONICAL_DIGEST_NAMES -get_hash_info = _EXPLICIT_CONSTRUCTORS.__getitem__ +assert _HASHINFO_DATABASE.keys() == CANONICAL_DIGEST_NAMES + + +def get_hash_type_info(name): + info = _HASHINFO_DATABASE[name] + assert isinstance(info, _HashInfo), info + return info.type + + +def get_hash_func_info(name): + info = _HASHINFO_DATABASE[name] + assert isinstance(info, _HashInfo), info + return info.func + + +def _iter_hash_func_info(excluded): + for name, info in _HASHINFO_DATABASE.items(): + if name not in excluded: + yield info.func + # Mapping from canonical hash names to their explicit HACL* HMAC constructor. # There is currently no OpenSSL one-shot named function and there will likely # be none in the future. -_EXPLICIT_HMAC_CONSTRUCTORS = { - HID(name): f"_hmac.compute_{name}" - for name in CANONICAL_DIGEST_NAMES +_HMACINFO_DATABASE = { + _HashId(canonical_name): _HashInfoItem(f"_hmac.compute_{canonical_name}") + for canonical_name in CANONICAL_DIGEST_NAMES } # Neither HACL* nor OpenSSL supports HMAC over XOFs. -_EXPLICIT_HMAC_CONSTRUCTORS[HID.shake_128] = None -_EXPLICIT_HMAC_CONSTRUCTORS[HID.shake_256] = None +_HMACINFO_DATABASE[_HashId.shake_128] = _HashInfoItem() +_HMACINFO_DATABASE[_HashId.shake_256] = _HashInfoItem() # Strictly speaking, HMAC-BLAKE is meaningless as BLAKE2 is already a # keyed hash function. However, as it's exposed by HACL*, we test it. -_EXPLICIT_HMAC_CONSTRUCTORS[HID.blake2s] = '_hmac.compute_blake2s_32' -_EXPLICIT_HMAC_CONSTRUCTORS[HID.blake2b] = '_hmac.compute_blake2b_32' -_EXPLICIT_HMAC_CONSTRUCTORS = MappingProxyType(_EXPLICIT_HMAC_CONSTRUCTORS) -assert _EXPLICIT_HMAC_CONSTRUCTORS.keys() == CANONICAL_DIGEST_NAMES +_HMACINFO_DATABASE[_HashId.blake2s] = _HashInfoItem('_hmac.compute_blake2s_32') +_HMACINFO_DATABASE[_HashId.blake2b] = _HashInfoItem('_hmac.compute_blake2b_32') +_HMACINFO_DATABASE = MappingProxyType(_HMACINFO_DATABASE) +assert _HMACINFO_DATABASE.keys() == CANONICAL_DIGEST_NAMES + + +def get_hmac_item_info(name): + info = _HMACINFO_DATABASE[name] + assert isinstance(info, _HashInfoItem), info + return info def _decorate_func_or_class(decorator_func, func_or_class): @@ -230,26 +538,42 @@ def _ensure_wrapper_signature(wrapper, wrapped): ) -def requires_hashlib(): - _hashlib = try_import_module("_hashlib") +def _make_conditional_decorator(test, /, *test_args, **test_kwargs): + def decorator_func(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + test(*test_args, **test_kwargs) + return func(*args, **kwargs) + return wrapper + return functools.partial(_decorate_func_or_class, decorator_func) + + +def requires_openssl_hashlib(): + _hashlib = _import_module("_hashlib") return unittest.skipIf(_hashlib is None, "requires _hashlib") def requires_builtin_hmac(): - _hmac = try_import_module("_hmac") + _hmac = _import_module("_hmac") return unittest.skipIf(_hmac is None, "requires _hmac") class SkipNoHash(unittest.SkipTest): """A SkipTest exception raised when a hash is not available.""" - def __init__(self, digestname, implementation=None, interface=None): + def __init__(self, digestname, implementation=None, reason=None): parts = ["missing", implementation, f"hash algorithm {digestname!r}"] - if interface is not None: - parts.append(f"for {interface}") + if reason is not None: + parts.insert(0, f"{reason}: ") super().__init__(" ".join(filter(None, parts))) +class SkipNoHashInCall(SkipNoHash): + + def __init__(self, func, digestname, implementation=None): + super().__init__(digestname, implementation, f"cannot use {func}") + + def _hashlib_new(digestname, openssl, /, **kwargs): """Check availability of [hashlib|_hashlib].new(digestname, **kwargs). @@ -264,13 +588,12 @@ def _hashlib_new(digestname, openssl, /, **kwargs): # exceptions as it should be unconditionally available. hashlib = importlib.import_module("hashlib") # re-import '_hashlib' in case it was mocked - _hashlib = try_import_module("_hashlib") + _hashlib = _import_module("_hashlib") module = _hashlib if openssl and _hashlib is not None else hashlib try: module.new(digestname, **kwargs) except ValueError as exc: - interface = f"{module.__name__}.new" - raise SkipNoHash(digestname, interface=interface) from exc + raise SkipNoHashInCall(f"{module.__name__}.new", digestname) from exc return functools.partial(module.new, digestname) @@ -315,7 +638,7 @@ def _openssl_new(digestname, /, **kwargs): try: _hashlib.new(digestname, **kwargs) except ValueError as exc: - raise SkipNoHash(digestname, interface="_hashlib.new") from exc + raise SkipNoHashInCall("_hashlib.new", digestname) from exc return functools.partial(_hashlib.new, digestname) @@ -326,14 +649,15 @@ def _openssl_hash(digestname, /, **kwargs): or SkipTest is raised if none exists. """ assert isinstance(digestname, str), digestname - fullname = f"_hashlib.openssl_{digestname}" + method_name = f"openssl_{digestname}" + fullname = f"_hashlib.{method_name}" try: # re-import '_hashlib' in case it was mocked _hashlib = importlib.import_module("_hashlib") except ImportError as exc: raise SkipNoHash(fullname, "openssl") from exc try: - constructor = getattr(_hashlib, f"openssl_{digestname}", None) + constructor = getattr(_hashlib, method_name) except AttributeError as exc: raise SkipNoHash(fullname, "openssl") from exc try: @@ -343,16 +667,6 @@ def _openssl_hash(digestname, /, **kwargs): return constructor -def _make_requires_hashdigest_decorator(test, /, *test_args, **test_kwargs): - def decorator_func(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - test(*test_args, **test_kwargs) - return func(*args, **kwargs) - return wrapper - return functools.partial(_decorate_func_or_class, decorator_func) - - def requires_hashdigest(digestname, openssl=None, *, usedforsecurity=True): """Decorator raising SkipTest if a hashing algorithm is not available. @@ -370,7 +684,7 @@ def requires_hashdigest(digestname, openssl=None, *, usedforsecurity=True): ValueError: [digital envelope routines: EVP_DigestInit_ex] disabled for FIPS ValueError: unsupported hash type md4 """ - return _make_requires_hashdigest_decorator( + return _make_conditional_decorator( _hashlib_new, digestname, openssl, usedforsecurity=usedforsecurity ) @@ -380,34 +694,35 @@ def requires_openssl_hashdigest(digestname, *, usedforsecurity=True): The hashing algorithm may be missing or blocked by a strict crypto policy. """ - return _make_requires_hashdigest_decorator( + return _make_conditional_decorator( _openssl_new, digestname, usedforsecurity=usedforsecurity ) -def requires_builtin_hashdigest( - module_name, digestname, *, usedforsecurity=True -): - """Decorator raising SkipTest if a HACL* hashing algorithm is missing. +def _make_requires_builtin_hashdigest_decorator(item, *, usedforsecurity=True): + assert isinstance(item, _HashInfoItem), item + return _make_conditional_decorator( + _builtin_hash, + item.module_name, + item.member_name, + usedforsecurity=usedforsecurity, + ) - - The *module_name* is the C extension module name based on HACL*. - - The *digestname* is one of its member, e.g., 'md5'. - """ - return _make_requires_hashdigest_decorator( - _builtin_hash, module_name, digestname, usedforsecurity=usedforsecurity + +def requires_builtin_hashdigest(canonical_name, *, usedforsecurity=True): + """Decorator raising SkipTest if a HACL* hashing algorithm is missing.""" + info = get_hash_func_info(canonical_name) + return _make_requires_builtin_hashdigest_decorator( + info.builtin, usedforsecurity=usedforsecurity ) -def requires_builtin_hashes(*ignored, usedforsecurity=True): +def requires_builtin_hashes(*, exclude=(), usedforsecurity=True): """Decorator raising SkipTest if one HACL* hashing algorithm is missing.""" return _chain_decorators(( - requires_builtin_hashdigest( - api.builtin_module_name, - api.builtin_method_name, - usedforsecurity=usedforsecurity, - ) - for name, api in _EXPLICIT_CONSTRUCTORS.items() - if name not in ignored + _make_requires_builtin_hashdigest_decorator( + info.builtin, usedforsecurity=usedforsecurity + ) for info in _iter_hash_func_info(exclude) )) @@ -424,69 +739,31 @@ class HashFunctionsTrait: implementation of HMAC). """ - DIGEST_NAMES = [ - 'md5', 'sha1', - 'sha224', 'sha256', 'sha384', 'sha512', - 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512', - ] - # Default 'usedforsecurity' to use when checking a hash function. # When the trait properties are callables (e.g., _md5.md5) and # not strings, they must be called with the same 'usedforsecurity'. usedforsecurity = True - @classmethod - def setUpClass(cls): - super().setUpClass() - assert CANONICAL_DIGEST_NAMES.issuperset(cls.DIGEST_NAMES) - def is_valid_digest_name(self, digestname): - self.assertIn(digestname, self.DIGEST_NAMES) + self.assertIn(digestname, _HashId) def _find_constructor(self, digestname): # By default, a missing algorithm skips the test that uses it. self.is_valid_digest_name(digestname) self.skipTest(f"missing hash function: {digestname}") - @property - def md5(self): - return self._find_constructor("md5") + md5 = property(lambda self: self._find_constructor("md5")) + sha1 = property(lambda self: self._find_constructor("sha1")) - @property - def sha1(self): - return self._find_constructor("sha1") + sha224 = property(lambda self: self._find_constructor("sha224")) + sha256 = property(lambda self: self._find_constructor("sha256")) + sha384 = property(lambda self: self._find_constructor("sha384")) + sha512 = property(lambda self: self._find_constructor("sha512")) - @property - def sha224(self): - return self._find_constructor("sha224") - - @property - def sha256(self): - return self._find_constructor("sha256") - - @property - def sha384(self): - return self._find_constructor("sha384") - - @property - def sha512(self): - return self._find_constructor("sha512") - - @property - def sha3_224(self): - return self._find_constructor("sha3_224") - - @property - def sha3_256(self): - return self._find_constructor("sha3_256") - - @property - def sha3_384(self): - return self._find_constructor("sha3_384") - - @property - def sha3_512(self): - return self._find_constructor("sha3_512") + sha3_224 = property(lambda self: self._find_constructor("sha3_224")) + sha3_256 = property(lambda self: self._find_constructor("sha3_256")) + sha3_384 = property(lambda self: self._find_constructor("sha3_384")) + sha3_512 = property(lambda self: self._find_constructor("sha3_512")) class NamedHashFunctionsTrait(HashFunctionsTrait): @@ -497,7 +774,7 @@ class NamedHashFunctionsTrait(HashFunctionsTrait): def _find_constructor(self, digestname): self.is_valid_digest_name(digestname) - return digestname + return str(digestname) # ensure that we are an exact string class OpenSSLHashFunctionsTrait(HashFunctionsTrait): @@ -523,10 +800,10 @@ class BuiltinHashFunctionsTrait(HashFunctionsTrait): def _find_constructor(self, digestname): self.is_valid_digest_name(digestname) - info = _EXPLICIT_CONSTRUCTORS[digestname] + info = get_hash_func_info(digestname) return _builtin_hash( - info.builtin_module_name, - info.builtin_method_name, + info.builtin.module_name, + info.builtin.member_name, usedforsecurity=self.usedforsecurity, ) @@ -542,7 +819,7 @@ def find_gil_minsize(modules_names, default=2048): """ sizes = [] for module_name in modules_names: - module = try_import_module(module_name) + module = _import_module(module_name) if module is not None: sizes.append(getattr(module, '_GIL_MINSIZE', default)) return max(sizes, default=default) @@ -553,7 +830,7 @@ def _block_openssl_hash_new(blocked_name): assert isinstance(blocked_name, str), blocked_name # re-import '_hashlib' in case it was mocked - if (_hashlib := try_import_module("_hashlib")) is None: + if (_hashlib := _import_module("_hashlib")) is None: return contextlib.nullcontext() @functools.wraps(wrapped := _hashlib.new) @@ -572,7 +849,7 @@ def _block_openssl_hmac_new(blocked_name): assert isinstance(blocked_name, str), blocked_name # re-import '_hashlib' in case it was mocked - if (_hashlib := try_import_module("_hashlib")) is None: + if (_hashlib := _import_module("_hashlib")) is None: return contextlib.nullcontext() @functools.wraps(wrapped := _hashlib.hmac_new) @@ -590,7 +867,7 @@ def _block_openssl_hmac_digest(blocked_name): assert isinstance(blocked_name, str), blocked_name # re-import '_hashlib' in case it was mocked - if (_hashlib := try_import_module("_hashlib")) is None: + if (_hashlib := _import_module("_hashlib")) is None: return contextlib.nullcontext() @functools.wraps(wrapped := _hashlib.hmac_digest) @@ -607,7 +884,7 @@ def _block_builtin_hash_new(name): """Block a buitin-in hash name from the hashlib.new() interface.""" assert isinstance(name, str), name assert name.lower() == name, f"invalid name: {name}" - assert name in HID, f"invalid hash: {name}" + assert name in _HashId, f"invalid hash: {name}" # Re-import 'hashlib' in case it was mocked hashlib = importlib.import_module('hashlib') @@ -620,7 +897,7 @@ def _block_builtin_hash_new(name): # so we need to block the possibility of importing it, but only # during the call to __get_builtin_constructor(). get_builtin_constructor = getattr(hashlib, '__get_builtin_constructor') - builtin_module_name = _EXPLICIT_CONSTRUCTORS[name].builtin_module_name + builtin_module_name = get_hash_func_info(name).builtin.module_name @functools.wraps(get_builtin_constructor) def get_builtin_constructor_mock(name): @@ -632,7 +909,7 @@ def get_builtin_constructor_mock(name): return unittest.mock.patch.multiple( hashlib, __get_builtin_constructor=get_builtin_constructor_mock, - __builtin_constructor_cache=builtin_constructor_cache_mock + __builtin_constructor_cache=builtin_constructor_cache_mock, ) @@ -640,7 +917,7 @@ def _block_builtin_hmac_new(blocked_name): assert isinstance(blocked_name, str), blocked_name # re-import '_hmac' in case it was mocked - if (_hmac := try_import_module("_hmac")) is None: + if (_hmac := _import_module("_hmac")) is None: return contextlib.nullcontext() @functools.wraps(wrapped := _hmac.new) @@ -657,7 +934,7 @@ def _block_builtin_hmac_digest(blocked_name): assert isinstance(blocked_name, str), blocked_name # re-import '_hmac' in case it was mocked - if (_hmac := try_import_module("_hmac")) is None: + if (_hmac := _import_module("_hmac")) is None: return contextlib.nullcontext() @functools.wraps(wrapped := _hmac.compute_digest) @@ -671,30 +948,19 @@ def _hmac_compute_digest(key, msg, digest): def _make_hash_constructor_blocker(name, dummy, implementation): - info = _EXPLICIT_CONSTRUCTORS[name] - module_name = info.module_name(implementation) - method_name = info.method_name(implementation) - if module_name is None or method_name is None: + info = get_hash_func_info(name)[implementation] + if (wrapped := info.import_member()) is None: # function shouldn't exist for this implementation return contextlib.nullcontext() - - try: - module = importlib.import_module(module_name) - except ImportError: - # module is already disabled - return contextlib.nullcontext() - - wrapped = getattr(module, method_name) wrapper = functools.wraps(wrapped)(dummy) _ensure_wrapper_signature(wrapper, wrapped) - return unittest.mock.patch(info.fullname(implementation), wrapper) + return unittest.mock.patch(info.fullname, wrapper) def _block_hashlib_hash_constructor(name): """Block explicit public constructors.""" def dummy(data=b'', *, usedforsecurity=True, string=None): raise ValueError(f"blocked explicit public hash name: {name}") - return _make_hash_constructor_blocker(name, dummy, 'hashlib') @@ -714,23 +980,18 @@ def dummy(data=b'', *, usedforsecurity=True, string=b''): def _block_builtin_hmac_constructor(name): """Block explicit HACL* HMAC constructors.""" - fullname = _EXPLICIT_HMAC_CONSTRUCTORS[name] - if fullname is None: + info = get_hmac_item_info(name) + assert info.module_name is None or info.module_name == "_hmac", info + if (wrapped := info.import_member()) is None: # function shouldn't exist for this implementation return contextlib.nullcontext() - assert fullname.count('.') == 1, fullname - module_name, method = fullname.split('.', maxsplit=1) - assert module_name == '_hmac', module_name - try: - module = importlib.import_module(module_name) - except ImportError: - # module is already disabled - return contextlib.nullcontext() - @functools.wraps(wrapped := getattr(module, method)) + + @functools.wraps(wrapped) def wrapper(key, obj): raise ValueError(f"blocked hash name: {name}") + _ensure_wrapper_signature(wrapper, wrapped) - return unittest.mock.patch(fullname, wrapper) + return unittest.mock.patch(info.fullname, wrapper) @contextlib.contextmanager @@ -760,14 +1021,14 @@ def block_algorithm(name, *, allow_openssl=False, allow_builtin=False): # the OpenSSL implementation, except with usedforsecurity=False. # However, blocking such functions also means blocking them # so we again need to block them if we want to. - (_hashlib := try_import_module("_hashlib")) + (_hashlib := _import_module("_hashlib")) and _hashlib.get_fips_mode() and not allow_openssl ) or ( # Without OpenSSL, hashlib.() functions are aliases # to built-in functions, so both of them must be blocked # as the module may have been imported before the HACL ones. - not (_hashlib := try_import_module("_hashlib")) + not (_hashlib := _import_module("_hashlib")) and not allow_builtin ): stack.enter_context(_block_hashlib_hash_constructor(name)) @@ -794,3 +1055,21 @@ def block_algorithm(name, *, allow_openssl=False, allow_builtin=False): # _hmac.compute_digest(..., name) stack.enter_context(_block_builtin_hmac_digest(name)) yield + + +@contextlib.contextmanager +def block_openssl_algorithms(*, exclude=()): + """Block OpenSSL implementations, except those given in *exclude*.""" + with contextlib.ExitStack() as stack: + for name in CANONICAL_DIGEST_NAMES.difference(exclude): + stack.enter_context(block_algorithm(name, allow_builtin=True)) + yield + + +@contextlib.contextmanager +def block_builtin_algorithms(*, exclude=()): + """Block HACL* implementations, except those given in *exclude*.""" + with contextlib.ExitStack() as stack: + for name in CANONICAL_DIGEST_NAMES.difference(exclude): + stack.enter_context(block_algorithm(name, allow_openssl=True)) + yield diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index 5c29369d10b143..7634deeb1d8eb9 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -161,7 +161,7 @@ def hmac_digest(self, key, msg=None, digestmod=DIGESTMOD_SENTINEL): return _call_digest_func(self.hmac.digest, key, msg, digestmod) -@hashlib_helper.requires_hashlib() +@hashlib_helper.requires_openssl_hashlib() class ThroughOpenSSLAPIMixin(CreatorMixin, DigestMixin): """Mixin delegating to _hashlib.hmac_new() and _hashlib.hmac_digest().""" @@ -1431,7 +1431,7 @@ def test_compare_digest_func(self): self.assertIs(self.compare_digest, operator_compare_digest) -@hashlib_helper.requires_hashlib() +@hashlib_helper.requires_openssl_hashlib() class OpenSSLCompareDigestTestCase(CompareDigestMixin, unittest.TestCase): compare_digest = openssl_compare_digest @@ -1509,7 +1509,7 @@ def test_hmac_digest_overflow_error_openssl_only(self, size): hmac = import_fresh_module("hmac", blocked=["_hmac"]) self.do_test_hmac_digest_overflow_error_switch_to_slow(hmac, size) - @hashlib_helper.requires_builtin_hashdigest("_md5", "md5") + @hashlib_helper.requires_builtin_hashdigest("md5") @bigmemtest(size=_4G + 5, memuse=2, dry_run=False) def test_hmac_digest_overflow_error_builtin_only(self, size): hmac = import_fresh_module("hmac", blocked=["_hashlib"]) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 9ec382afb65fe4..ef72e3e5b58ec3 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -865,22 +865,22 @@ def try_import_attribute(self, fullname, default=None): return default def fetch_hash_function(self, name, implementation): - info = hashlib_helper.get_hash_info(name) - match implementation: - case "hashlib": - assert info.hashlib is not None, info - return getattr(self.hashlib, info.hashlib) - case "openssl": - try: - return getattr(self._hashlib, info.openssl, None) - except TypeError: - return None - fullname = info.fullname(implementation) + info = hashlib_helper.get_hash_func_info(name) + match hashlib_helper.Implementation(implementation): + case hashlib_helper.Implementation.hashlib: + method_name = info.hashlib.member_name + assert isinstance(method_name, str), method_name + return getattr(self.hashlib, method_name) + case hashlib_helper.Implementation.openssl: + method_name = info.openssl.member_name + assert isinstance(method_name, str | None), method_name + return getattr(self._hashlib, method_name or "", None) + fullname = info[implementation].fullname return self.try_import_attribute(fullname) def fetch_hmac_function(self, name): - fullname = hashlib_helper._EXPLICIT_HMAC_CONSTRUCTORS[name] - return self.try_import_attribute(fullname) + target = hashlib_helper.get_hmac_item_info(name) + return target.import_member() def check_openssl_hash(self, name, *, disabled=True): """Check that OpenSSL HASH interface is enabled/disabled."""