diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 9d1943b27e8..09a844ddb56 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -575,7 +575,11 @@ def get_annotate_function(obj): Returns the __annotate__ function or None. """ if isinstance(obj, type): - return _BASE_GET_ANNOTATE(obj) + try: + return _BASE_GET_ANNOTATE(obj) + except AttributeError: + # AttributeError is raised for static types. + return None return getattr(obj, "__annotate__", None) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index ce4f92624d9..309f6d21201 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -928,6 +928,27 @@ class MetaclassTests(unittest.TestCase): self.assertIs(annotate_func, None) +class TestGetAnnotateFunction(unittest.TestCase): + def test_static_class(self): + self.assertIsNone(get_annotate_function(object)) + self.assertIsNone(get_annotate_function(int)) + + def test_unannotated_class(self): + class C: + pass + + self.assertIsNone(get_annotate_function(C)) + + D = type("D", (), {}) + self.assertIsNone(get_annotate_function(D)) + + def test_annotated_class(self): + class C: + a: int + + self.assertEqual(get_annotate_function(C)(Format.VALUE), {"a": int}) + + class TestAnnotationLib(unittest.TestCase): def test__all__(self): support.check__all__(self, annotationlib) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 6e036b60033..3ac6b97383f 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -7043,6 +7043,25 @@ class GetTypeHintTests(BaseTestCase): self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]}) self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]}) + def test_get_type_hints_format(self): + class C: + x: undefined + + with self.assertRaises(NameError): + get_type_hints(C) + + with self.assertRaises(NameError): + get_type_hints(C, format=annotationlib.Format.VALUE) + + annos = get_type_hints(C, format=annotationlib.Format.FORWARDREF) + self.assertIsInstance(annos, dict) + self.assertEqual(list(annos), ['x']) + self.assertIsInstance(annos['x'], annotationlib.ForwardRef) + self.assertEqual(annos['x'].__arg__, 'undefined') + + self.assertEqual(get_type_hints(C, format=annotationlib.Format.SOURCE), + {'x': 'undefined'}) + class GetUtilitiesTestCase(TestCase): def test_get_origin(self):