diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index 949b108c60c..487be8f28a7 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -3388,6 +3388,38 @@ Introspection helpers .. versionadded:: 3.8 +.. function:: get_protocol_members(tp) + + Return the set of members defined in a :class:`Protocol`. + + :: + + >>> from typing import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) + frozenset({'a', 'b'}) + + Raise :exc:`TypeError` for arguments that are not Protocols. + + .. versionadded:: 3.13 + +.. function:: is_protocol(tp) + + Determine if a type is a :class:`Protocol`. + + For example:: + + class P(Protocol): + def a(self) -> str: ... + b: int + + is_protocol(P) # => True + is_protocol(int) # => False + + .. versionadded:: 3.13 + .. function:: is_typeddict(tp) Check if a type is a :class:`TypedDict`. diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst index 78d2a7b6b29..fcd10e522c8 100644 --- a/Doc/whatsnew/3.13.rst +++ b/Doc/whatsnew/3.13.rst @@ -120,6 +120,14 @@ traceback to format the nested exceptions of a :exc:`BaseExceptionGroup` instance, recursively. (Contributed by Irit Katriel in :gh:`105292`.) +typing +------ + +* Add :func:`typing.get_protocol_members` to return the set of members + defining a :class:`typing.Protocol`. Add :func:`typing.is_protocol` to + check whether a class is a :class:`typing.Protocol`. (Contributed by Jelle Zijlstra in + :gh:`104873`.) + Optimizations ============= diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 432fc88b1c0..a36d801c525 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -24,9 +24,9 @@ from typing import Callable from typing import Generic, ClassVar, Final, final, Protocol from typing import assert_type, cast, runtime_checkable from typing import get_type_hints -from typing import get_origin, get_args +from typing import get_origin, get_args, get_protocol_members from typing import override -from typing import is_typeddict +from typing import is_typeddict, is_protocol from typing import reveal_type from typing import dataclass_transform from typing import no_type_check, no_type_check_decorator @@ -3363,6 +3363,18 @@ class ProtocolTests(BaseTestCase): self.assertNotIn("__callable_proto_members_only__", vars(NonP)) self.assertNotIn("__callable_proto_members_only__", vars(NonPR)) + self.assertEqual(get_protocol_members(P), {"x"}) + self.assertEqual(get_protocol_members(PR), {"meth"}) + + # the returned object should be immutable, + # and should be a different object to the original attribute + # to prevent users from (accidentally or deliberately) + # mutating the attribute on the original class + self.assertIsInstance(get_protocol_members(P), frozenset) + self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__) + self.assertIsInstance(get_protocol_members(PR), frozenset) + self.assertIsNot(get_protocol_members(PR), P.__protocol_attrs__) + acceptable_extra_attrs = { '_is_protocol', '_is_runtime_protocol', '__parameters__', '__init__', '__annotations__', '__subclasshook__', @@ -3778,6 +3790,59 @@ class ProtocolTests(BaseTestCase): Foo() # Previously triggered RecursionError + def test_get_protocol_members(self): + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(object) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(object()) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Protocol) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Generic) + + class P(Protocol): + a: int + def b(self) -> str: ... + @property + def c(self) -> int: ... + + self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'}) + self.assertIsInstance(get_protocol_members(P), frozenset) + self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__) + + class Concrete: + a: int + def b(self) -> str: return "capybara" + @property + def c(self) -> int: return 5 + + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Concrete) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(Concrete()) + + class ConcreteInherit(P): + a: int = 42 + def b(self) -> str: return "capybara" + @property + def c(self) -> int: return 5 + + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(ConcreteInherit) + with self.assertRaisesRegex(TypeError, "not a Protocol"): + get_protocol_members(ConcreteInherit()) + + def test_is_protocol(self): + self.assertTrue(is_protocol(Proto)) + self.assertTrue(is_protocol(Point)) + self.assertFalse(is_protocol(Concrete)) + self.assertFalse(is_protocol(Concrete())) + self.assertFalse(is_protocol(Generic)) + self.assertFalse(is_protocol(object)) + + # Protocol is not itself a protocol + self.assertFalse(is_protocol(Protocol)) + def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self): # Ensure the cache is empty, or this test won't work correctly collections.abc.Sized._abc_registry_clear() diff --git a/Lib/typing.py b/Lib/typing.py index a531e7d7abb..4e6dc447735 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -131,7 +131,9 @@ __all__ = [ 'get_args', 'get_origin', 'get_overloads', + 'get_protocol_members', 'get_type_hints', + 'is_protocol', 'is_typeddict', 'LiteralString', 'Never', @@ -3337,3 +3339,43 @@ def override[F: _Func](method: F, /) -> F: # read-only property, TypeError if it's a builtin class. pass return method + + +def is_protocol(tp: type, /) -> bool: + """Return True if the given type is a Protocol. + + Example:: + + >>> from typing import Protocol, is_protocol + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> is_protocol(P) + True + >>> is_protocol(int) + False + """ + return ( + isinstance(tp, type) + and getattr(tp, '_is_protocol', False) + and tp != Protocol + ) + + +def get_protocol_members(tp: type, /) -> frozenset[str]: + """Return the set of members defined in a Protocol. + + Example:: + + >>> from typing import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) + frozenset({'a', 'b'}) + + Raise a TypeError for arguments that are not Protocols. + """ + if not is_protocol(tp): + raise TypeError(f'{tp!r} is not a Protocol') + return frozenset(tp.__protocol_attrs__) diff --git a/Misc/NEWS.d/next/Library/2023-05-24-09-55-33.gh-issue-104873.BKQ54y.rst b/Misc/NEWS.d/next/Library/2023-05-24-09-55-33.gh-issue-104873.BKQ54y.rst new file mode 100644 index 00000000000..c901d83812f --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-05-24-09-55-33.gh-issue-104873.BKQ54y.rst @@ -0,0 +1,3 @@ +Add :func:`typing.get_protocol_members` to return the set of members +defining a :class:`typing.Protocol`. Add :func:`typing.is_protocol` to +check whether a class is a :class:`typing.Protocol`. Patch by Jelle Zijlstra.