gh-119180: Use type descriptors to access annotations (PEP 749) (#122074)

This commit is contained in:
Jelle Zijlstra 2024-07-27 09:36:06 -07:00 committed by GitHub
parent 4e75509349
commit 45614ecb2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 117 additions and 5 deletions

View File

@ -524,6 +524,27 @@ def call_annotate_function(annotate, format, owner=None):
raise ValueError(f"Invalid format: {format!r}") raise ValueError(f"Invalid format: {format!r}")
# We use the descriptors from builtins.type instead of accessing
# .__annotations__ and .__annotate__ directly on class objects, because
# otherwise we could get wrong results in some cases involving metaclasses.
# See PEP 749.
_BASE_GET_ANNOTATE = type.__dict__["__annotate__"].__get__
_BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__
def get_annotate_function(obj):
"""Get the __annotate__ function for an object.
obj may be a function, class, or module, or a user-defined type with
an `__annotate__` attribute.
Returns the __annotate__ function or None.
"""
if isinstance(obj, type):
return _BASE_GET_ANNOTATE(obj)
return getattr(obj, "__annotate__", None)
def get_annotations( def get_annotations(
obj, *, globals=None, locals=None, eval_str=False, format=Format.VALUE obj, *, globals=None, locals=None, eval_str=False, format=Format.VALUE
): ):
@ -576,13 +597,20 @@ def get_annotations(
# For VALUE format, we look at __annotations__ directly. # For VALUE format, we look at __annotations__ directly.
if format != Format.VALUE: if format != Format.VALUE:
annotate = getattr(obj, "__annotate__", None) annotate = get_annotate_function(obj)
if annotate is not None: if annotate is not None:
ann = call_annotate_function(annotate, format, owner=obj) ann = call_annotate_function(annotate, format, owner=obj)
if not isinstance(ann, dict): if not isinstance(ann, dict):
raise ValueError(f"{obj!r}.__annotate__ returned a non-dict") raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
return dict(ann) return dict(ann)
if isinstance(obj, type):
try:
ann = _BASE_GET_ANNOTATIONS(obj)
except AttributeError:
# For static types, the descriptor raises AttributeError.
return {}
else:
ann = getattr(obj, "__annotations__", None) ann = getattr(obj, "__annotations__", None)
if ann is None: if ann is None:
return {} return {}

View File

@ -2,8 +2,10 @@
import annotationlib import annotationlib
import functools import functools
import itertools
import pickle import pickle
import unittest import unittest
from annotationlib import Format, get_annotations, get_annotate_function
from typing import Unpack from typing import Unpack
from test.test_inspect import inspect_stock_annotations from test.test_inspect import inspect_stock_annotations
@ -767,5 +769,85 @@ class TestGetAnnotations(unittest.TestCase):
self.assertEqual( self.assertEqual(
set(results.generic_func_annotations.values()), set(results.generic_func_annotations.values()),
set(results.generic_func.__type_params__) set(results.generic_func.__type_params__),
) )
class MetaclassTests(unittest.TestCase):
def test_annotated_meta(self):
class Meta(type):
a: int
class X(metaclass=Meta):
pass
class Y(metaclass=Meta):
b: float
self.assertEqual(get_annotations(Meta), {"a": int})
self.assertEqual(get_annotate_function(Meta)(Format.VALUE), {"a": int})
self.assertEqual(get_annotations(X), {})
self.assertIs(get_annotate_function(X), None)
self.assertEqual(get_annotations(Y), {"b": float})
self.assertEqual(get_annotate_function(Y)(Format.VALUE), {"b": float})
def test_unannotated_meta(self):
class Meta(type): pass
class X(metaclass=Meta):
a: str
class Y(X): pass
self.assertEqual(get_annotations(Meta), {})
self.assertIs(get_annotate_function(Meta), None)
self.assertEqual(get_annotations(Y), {})
self.assertIs(get_annotate_function(Y), None)
self.assertEqual(get_annotations(X), {"a": str})
self.assertEqual(get_annotate_function(X)(Format.VALUE), {"a": str})
def test_ordering(self):
# Based on a sample by David Ellis
# https://discuss.python.org/t/pep-749-implementing-pep-649/54974/38
def make_classes():
class Meta(type):
a: int
expected_annotations = {"a": int}
class A(type, metaclass=Meta):
b: float
expected_annotations = {"b": float}
class B(metaclass=A):
c: str
expected_annotations = {"c": str}
class C(B):
expected_annotations = {}
class D(metaclass=Meta):
expected_annotations = {}
return Meta, A, B, C, D
classes = make_classes()
class_count = len(classes)
for order in itertools.permutations(range(class_count), class_count):
names = ", ".join(classes[i].__name__ for i in order)
with self.subTest(names=names):
classes = make_classes() # Regenerate classes
for i in order:
get_annotations(classes[i])
for c in classes:
with self.subTest(c=c):
self.assertEqual(get_annotations(c), c.expected_annotations)
annotate_func = get_annotate_function(c)
if c.expected_annotations:
self.assertEqual(annotate_func(Format.VALUE), c.expected_annotations)
else:
self.assertIs(annotate_func, None)

View File

@ -0,0 +1,2 @@
Fix handling of classes with custom metaclasses in
``annotationlib.get_annotations``.