From 96f619faa74a8a32c2c297833cdeb0393c0b6b13 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 18 Sep 2024 08:39:22 -0700 Subject: [PATCH] gh-124206: Fix calling get_annotate_function() on static types (#124208) Fixes #124206. No news entry because the bug this fixes was never released. --- Lib/annotationlib.py | 6 +++++- Lib/test/test_annotationlib.py | 21 +++++++++++++++++++++ Lib/test/test_typing.py | 19 +++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 9d1943b27e8..09a844ddb56 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -575,7 +575,11 @@ def get_annotate_function(obj): Returns the __annotate__ function or None. """ if isinstance(obj, type): - return _BASE_GET_ANNOTATE(obj) + try: + return _BASE_GET_ANNOTATE(obj) + except AttributeError: + # AttributeError is raised for static types. + return None return getattr(obj, "__annotate__", None) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index ce4f92624d9..309f6d21201 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -928,6 +928,27 @@ class MetaclassTests(unittest.TestCase): self.assertIs(annotate_func, None) +class TestGetAnnotateFunction(unittest.TestCase): + def test_static_class(self): + self.assertIsNone(get_annotate_function(object)) + self.assertIsNone(get_annotate_function(int)) + + def test_unannotated_class(self): + class C: + pass + + self.assertIsNone(get_annotate_function(C)) + + D = type("D", (), {}) + self.assertIsNone(get_annotate_function(D)) + + def test_annotated_class(self): + class C: + a: int + + self.assertEqual(get_annotate_function(C)(Format.VALUE), {"a": int}) + + class TestAnnotationLib(unittest.TestCase): def test__all__(self): support.check__all__(self, annotationlib) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 6e036b60033..3ac6b97383f 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -7043,6 +7043,25 @@ class GetTypeHintTests(BaseTestCase): self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]}) self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]}) + def test_get_type_hints_format(self): + class C: + x: undefined + + with self.assertRaises(NameError): + get_type_hints(C) + + with self.assertRaises(NameError): + get_type_hints(C, format=annotationlib.Format.VALUE) + + annos = get_type_hints(C, format=annotationlib.Format.FORWARDREF) + self.assertIsInstance(annos, dict) + self.assertEqual(list(annos), ['x']) + self.assertIsInstance(annos['x'], annotationlib.ForwardRef) + self.assertEqual(annos['x'].__arg__, 'undefined') + + self.assertEqual(get_type_hints(C, format=annotationlib.Format.SOURCE), + {'x': 'undefined'}) + class GetUtilitiesTestCase(TestCase): def test_get_origin(self):