diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index ac0a2691..4b303e92 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -37,6 +37,7 @@ from typing_extensions import clear_overloads, get_overloads, overload from typing_extensions import NamedTuple from typing_extensions import override, deprecated, Buffer, TypeAliasType, TypeVar +from typing_extensions import get_typing_objects_by_name_of, is_typing_name from _typed_dict_test_helper import Foo, FooGeneric # Flags used to mark tests that only apply after a specific @@ -4988,5 +4989,91 @@ class MyAlias(TypeAliasType): pass +class IntrospectionHelperTests(BaseTestCase): + def test_typing_objects_by_name_of(self): + for name in typing_extensions.__all__: + with self.subTest(name=name): + objs = get_typing_objects_by_name_of(name) + self.assertIsInstance(objs, tuple) + self.assertIn(len(objs), (1, 2)) + te_obj = getattr(typing_extensions, name) + if len(objs) == 1: + self.assertIs(te_obj, getattr(typing, name, te_obj)) + else: + self.assertTrue(hasattr(typing, name)) + self.assertIsNot(te_obj, getattr(typing, name)) + + with self.assertRaisesRegex( + ValueError, + "Neither typing nor typing_extensions has an object called 'foo'" + ): + get_typing_objects_by_name_of("foo") + + def test_typing_objects_by_name_not_in_typing_extensions(self): + objs = get_typing_objects_by_name_of("ByteString") + self.assertIsInstance(objs, tuple) + self.assertEqual(len(objs), 1) + bytestring = objs[0] + self.assertIs(bytestring, typing.ByteString) + self.assertEqual(bytestring.__module__, "typing") + + def test_typing_objects_by_name_of_2(self): + classvar_objs = get_typing_objects_by_name_of("ClassVar") + self.assertEqual(len(classvar_objs), 1) + classvar_obj = classvar_objs[0] + self.assertIs(classvar_obj, typing.ClassVar) + self.assertIs(classvar_obj, typing_extensions.ClassVar) + self.assertEqual(classvar_obj.__module__, "typing") + + @skipIf(TYPING_3_12_0, "We reexport TypeAliasType from typing on 3.12+") + def test_typing_objects_by_name_of_2(self): + name = "TypeAliasType" + # Sanity check; the test won't work correctly if this doesn't hold true: + self.assertFalse(hasattr(typing, name)) + typealiastype_objs = get_typing_objects_by_name_of(name) + self.assertEqual(len(typealiastype_objs), 1) + typealiastype_obj = typealiastype_objs[0] + self.assertIs(typealiastype_obj, typing_extensions.TypeAliasType) + self.assertEqual(typealiastype_obj.__module__, "typing_extensions") + + @skipUnless( + (3, 8) <= sys.version_info < (3, 12), + ( + "Needs a Python version where typing.Protocol " + "and typing_extensions.Protocol are different objects" + ) + ) + def test_typing_objects_by_name_of_3(self): + name = "Protocol" + # Sanity check; the test won't work correctly if this doesn't hold true: + self.assertTrue(hasattr(typing, name)) + protocol_objs = get_typing_objects_by_name_of(name) + self.assertEqual(len(protocol_objs), 2) + modules = {obj.__module__ for obj in protocol_objs} + self.assertEqual(modules, {"typing", "typing_extensions"}) + + def test_is_typing_name(self): + for name in typing_extensions.__all__: + te_obj = getattr(typing_extensions, name) + self.assertTrue(is_typing_name(te_obj, name)) + if hasattr(typing, name): + typing_obj = getattr(typing, name) + self.assertTrue(is_typing_name(typing_obj, name)) + + def test_is_typing_name_fails_appropriately(self): + self.assertFalse(is_typing_name(typing_extensions.NoReturn, "ClassVar")) + self.assertFalse(is_typing_name(typing.NoReturn, "ClassVar")) + error_msg = "Neither typing nor typing_extensions has an object called 'foo'" + with self.assertRaisesRegex(ValueError, error_msg): + is_typing_name(typing_extensions.NoReturn, "foo") + with self.assertRaisesRegex(ValueError, error_msg): + is_typing_name(typing_extensions.NoReturn, "foo") + + def test_is_typing_name_not_in_typing_extensions(self): + # Sanity check -- this is a useless test otherwise: + self.assertFalse(hasattr(typing_extensions, "ByteString")) + self.assertTrue(is_typing_name(typing.ByteString, "ByteString")) + + if __name__ == '__main__': main() diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 9aa84d7e..a66ca6bd 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -85,6 +85,11 @@ 'NoReturn', 'Required', 'NotRequired', + + # Introspection helpers unique to typing_extensions + # These will never be added to typing.py in CPython + 'get_typing_objects_by_name_of', + 'is_typing_name', ] # for backward compatibility @@ -2873,3 +2878,38 @@ def __ror__(self, left): if not _is_unionable(left): return NotImplemented return typing.Union[left, self] + + +############################################################# +# Introspection helpers for third-party libraries +# +# These are not part of the typing-module API, +# and nor will they ever become part of the typing-module API. +# +# They are specific to typing-extensions +############################################################## + + +@functools.lru_cache(maxsize=None) +def get_typing_objects_by_name_of(name: str) -> typing.Tuple[Any, ...]: + try: + te_obj = globals()[name] + except KeyError: + try: + typing_obj = getattr(typing, name) + except AttributeError: + raise ValueError( + f"Neither typing nor typing_extensions has an object called {name!r}!" + ) from None + else: + return (typing_obj,) + else: + if hasattr(typing, name): + typing_obj = getattr(typing, name) + if typing_obj is not te_obj: + return (te_obj, typing_obj) + return (te_obj,) + + +def is_typing_name(obj: object, name: str) -> bool: + return any(obj is thing for thing in get_typing_objects_by_name_of(name))