Skip to content

gh-104873: Add typing.get_protocol_members and typing.is_protocol #104878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Doc/library/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3388,6 +3388,38 @@ Introspection helpers

.. versionadded:: 3.8

.. function:: get_protocol_members(tp)

Return the set of members defined in a :class:`Protocol`.

::

>>> from typing import Protocol, get_protocol_members
>>> class P(Protocol):
... def a(self) -> str: ...
... b: int
>>> get_protocol_members(P)
frozenset({'a', 'b'})

Raise :exc:`TypeError` for arguments that are not Protocols.

.. versionadded:: 3.13

.. function:: is_protocol(tp)

Determine if a type is a :class:`Protocol`.

For example::

class P(Protocol):
def a(self) -> str: ...
b: int

is_protocol(P) # => True
is_protocol(int) # => False

.. versionadded:: 3.13

.. function:: is_typeddict(tp)

Check if a type is a :class:`TypedDict`.
Expand Down
8 changes: 8 additions & 0 deletions Doc/whatsnew/3.13.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ traceback
to format the nested exceptions of a :exc:`BaseExceptionGroup` instance, recursively.
(Contributed by Irit Katriel in :gh:`105292`.)

typing
------

* Add :func:`typing.get_protocol_members` to return the set of members
defining a :class:`typing.Protocol`. Add :func:`typing.is_protocol` to
check whether a class is a :class:`typing.Protocol`. (Contributed by Jelle Zijlstra in
:gh:`104873`.)

Optimizations
=============

Expand Down
69 changes: 67 additions & 2 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from typing import Generic, ClassVar, Final, final, Protocol
from typing import assert_type, cast, runtime_checkable
from typing import get_type_hints
from typing import get_origin, get_args
from typing import get_origin, get_args, get_protocol_members
from typing import override
from typing import is_typeddict
from typing import is_typeddict, is_protocol
from typing import reveal_type
from typing import dataclass_transform
from typing import no_type_check, no_type_check_decorator
Expand Down Expand Up @@ -3363,6 +3363,18 @@ def meth(self): pass
self.assertNotIn("__callable_proto_members_only__", vars(NonP))
self.assertNotIn("__callable_proto_members_only__", vars(NonPR))

self.assertEqual(get_protocol_members(P), {"x"})
self.assertEqual(get_protocol_members(PR), {"meth"})

# the returned object should be immutable,
# and should be a different object to the original attribute
# to prevent users from (accidentally or deliberately)
# mutating the attribute on the original class
self.assertIsInstance(get_protocol_members(P), frozenset)
self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)
self.assertIsInstance(get_protocol_members(PR), frozenset)
self.assertIsNot(get_protocol_members(PR), P.__protocol_attrs__)

acceptable_extra_attrs = {
'_is_protocol', '_is_runtime_protocol', '__parameters__',
'__init__', '__annotations__', '__subclasshook__',
Expand Down Expand Up @@ -3778,6 +3790,59 @@ def __init__(self):

Foo() # Previously triggered RecursionError

def test_get_protocol_members(self):
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(object)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(object())
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Protocol)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Generic)

class P(Protocol):
a: int
def b(self) -> str: ...
@property
def c(self) -> int: ...

self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'})
self.assertIsInstance(get_protocol_members(P), frozenset)
self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)

class Concrete:
a: int
def b(self) -> str: return "capybara"
@property
def c(self) -> int: return 5

with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Concrete)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(Concrete())

class ConcreteInherit(P):
a: int = 42
def b(self) -> str: return "capybara"
@property
def c(self) -> int: return 5

with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(ConcreteInherit)
with self.assertRaisesRegex(TypeError, "not a Protocol"):
get_protocol_members(ConcreteInherit())

def test_is_protocol(self):
self.assertTrue(is_protocol(Proto))
self.assertTrue(is_protocol(Point))
self.assertFalse(is_protocol(Concrete))
self.assertFalse(is_protocol(Concrete()))
self.assertFalse(is_protocol(Generic))
self.assertFalse(is_protocol(object))

# Protocol is not itself a protocol
self.assertFalse(is_protocol(Protocol))

def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self):
# Ensure the cache is empty, or this test won't work correctly
collections.abc.Sized._abc_registry_clear()
Expand Down
42 changes: 42 additions & 0 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@
'get_args',
'get_origin',
'get_overloads',
'get_protocol_members',
'get_type_hints',
'is_protocol',
'is_typeddict',
'LiteralString',
'Never',
Expand Down Expand Up @@ -3337,3 +3339,43 @@ def method(self) -> None:
# read-only property, TypeError if it's a builtin class.
pass
return method


def is_protocol(tp: type, /) -> bool:
"""Return True if the given type is a Protocol.

Example::

>>> from typing import Protocol, is_protocol
>>> class P(Protocol):
... def a(self) -> str: ...
... b: int
>>> is_protocol(P)
True
>>> is_protocol(int)
False
"""
return (
isinstance(tp, type)
and getattr(tp, '_is_protocol', False)
and tp != Protocol
)


def get_protocol_members(tp: type, /) -> frozenset[str]:
"""Return the set of members defined in a Protocol.

Example::

>>> from typing import Protocol, get_protocol_members
>>> class P(Protocol):
... def a(self) -> str: ...
... b: int
>>> get_protocol_members(P)
frozenset({'a', 'b'})

Raise a TypeError for arguments that are not Protocols.
"""
if not is_protocol(tp):
raise TypeError(f'{tp!r} is not a Protocol')
return frozenset(tp.__protocol_attrs__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Add :func:`typing.get_protocol_members` to return the set of members
defining a :class:`typing.Protocol`. Add :func:`typing.is_protocol` to
check whether a class is a :class:`typing.Protocol`. Patch by Jelle Zijlstra.