diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index b4036ffb189..eea24232f9f 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -524,6 +524,27 @@ def call_annotate_function(annotate, format, owner=None): raise ValueError(f"Invalid format: {format!r}") +# We use the descriptors from builtins.type instead of accessing +# .__annotations__ and .__annotate__ directly on class objects, because +# otherwise we could get wrong results in some cases involving metaclasses. +# See PEP 749. +_BASE_GET_ANNOTATE = type.__dict__["__annotate__"].__get__ +_BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__ + + +def get_annotate_function(obj): + """Get the __annotate__ function for an object. + + obj may be a function, class, or module, or a user-defined type with + an `__annotate__` attribute. + + Returns the __annotate__ function or None. + """ + if isinstance(obj, type): + return _BASE_GET_ANNOTATE(obj) + return getattr(obj, "__annotate__", None) + + def get_annotations( obj, *, globals=None, locals=None, eval_str=False, format=Format.VALUE ): @@ -576,16 +597,23 @@ def get_annotations( # For VALUE format, we look at __annotations__ directly. if format != Format.VALUE: - annotate = getattr(obj, "__annotate__", None) + annotate = get_annotate_function(obj) if annotate is not None: ann = call_annotate_function(annotate, format, owner=obj) if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotate__ returned a non-dict") return dict(ann) - ann = getattr(obj, "__annotations__", None) - if ann is None: - return {} + if isinstance(obj, type): + try: + ann = _BASE_GET_ANNOTATIONS(obj) + except AttributeError: + # For static types, the descriptor raises AttributeError. + return {} + else: + ann = getattr(obj, "__annotations__", None) + if ann is None: + return {} if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None") diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index e68d63c91d1..e459d27d3c4 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -2,8 +2,10 @@ import annotationlib import functools +import itertools import pickle import unittest +from annotationlib import Format, get_annotations, get_annotate_function from typing import Unpack from test.test_inspect import inspect_stock_annotations @@ -767,5 +769,85 @@ class TestGetAnnotations(unittest.TestCase): self.assertEqual( set(results.generic_func_annotations.values()), - set(results.generic_func.__type_params__) + set(results.generic_func.__type_params__), ) + + +class MetaclassTests(unittest.TestCase): + def test_annotated_meta(self): + class Meta(type): + a: int + + class X(metaclass=Meta): + pass + + class Y(metaclass=Meta): + b: float + + self.assertEqual(get_annotations(Meta), {"a": int}) + self.assertEqual(get_annotate_function(Meta)(Format.VALUE), {"a": int}) + + self.assertEqual(get_annotations(X), {}) + self.assertIs(get_annotate_function(X), None) + + self.assertEqual(get_annotations(Y), {"b": float}) + self.assertEqual(get_annotate_function(Y)(Format.VALUE), {"b": float}) + + def test_unannotated_meta(self): + class Meta(type): pass + + class X(metaclass=Meta): + a: str + + class Y(X): pass + + self.assertEqual(get_annotations(Meta), {}) + self.assertIs(get_annotate_function(Meta), None) + + self.assertEqual(get_annotations(Y), {}) + self.assertIs(get_annotate_function(Y), None) + + self.assertEqual(get_annotations(X), {"a": str}) + self.assertEqual(get_annotate_function(X)(Format.VALUE), {"a": str}) + + def test_ordering(self): + # Based on a sample by David Ellis + # https://discuss.python.org/t/pep-749-implementing-pep-649/54974/38 + + def make_classes(): + class Meta(type): + a: int + expected_annotations = {"a": int} + + class A(type, metaclass=Meta): + b: float + expected_annotations = {"b": float} + + class B(metaclass=A): + c: str + expected_annotations = {"c": str} + + class C(B): + expected_annotations = {} + + class D(metaclass=Meta): + expected_annotations = {} + + return Meta, A, B, C, D + + classes = make_classes() + class_count = len(classes) + for order in itertools.permutations(range(class_count), class_count): + names = ", ".join(classes[i].__name__ for i in order) + with self.subTest(names=names): + classes = make_classes() # Regenerate classes + for i in order: + get_annotations(classes[i]) + for c in classes: + with self.subTest(c=c): + self.assertEqual(get_annotations(c), c.expected_annotations) + annotate_func = get_annotate_function(c) + if c.expected_annotations: + self.assertEqual(annotate_func(Format.VALUE), c.expected_annotations) + else: + self.assertIs(annotate_func, None) diff --git a/Misc/NEWS.d/next/Library/2024-07-23-17-13-10.gh-issue-119180.5PZELo.rst b/Misc/NEWS.d/next/Library/2024-07-23-17-13-10.gh-issue-119180.5PZELo.rst new file mode 100644 index 00000000000..d65e89f7523 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-07-23-17-13-10.gh-issue-119180.5PZELo.rst @@ -0,0 +1,2 @@ +Fix handling of classes with custom metaclasses in +``annotationlib.get_annotations``.