diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index a4afd50376b..d6164324914 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -574,17 +574,18 @@ def _get_field(cls, a_name, a_type): def _find_fields(cls): # Return a list of Field objects, in order, for this class (and no - # base classes). Fields are found from __annotations__ (which is - # guaranteed to be ordered). Default values are from class - # attributes, if a field has a default. If the default value is - # a Field(), then it contains additional info beyond (and - # possibly including) the actual default value. Pseudo-fields - # ClassVars and InitVars are included, despite the fact that - # they're not real fields. That's dealt with later. + # base classes). Fields are found from the class dict's + # __annotations__ (which is guaranteed to be ordered). Default + # values are from class attributes, if a field has a default. If + # the default value is a Field(), then it contains additional + # info beyond (and possibly including) the actual default value. + # Pseudo-fields ClassVars and InitVars are included, despite the + # fact that they're not real fields. That's dealt with later. - annotations = getattr(cls, '__annotations__', {}) - return [_get_field(cls, a_name, a_type) - for a_name, a_type in annotations.items()] + # If __annotations__ isn't present, then this class adds no new + # annotations. + annotations = cls.__dict__.get('__annotations__', {}) + return [_get_field(cls, name, type) for name, type in annotations.items()] def _set_new_attribute(cls, name, value): diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index db03ec1925f..9b5aad25745 100755 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -1147,6 +1147,55 @@ class TestCase(unittest.TestCase): C().x self.assertEqual(factory.call_count, 2) + def test_default_factory_derived(self): + # See bpo-32896. + @dataclass + class Foo: + x: dict = field(default_factory=dict) + + @dataclass + class Bar(Foo): + y: int = 1 + + self.assertEqual(Foo().x, {}) + self.assertEqual(Bar().x, {}) + self.assertEqual(Bar().y, 1) + + @dataclass + class Baz(Foo): + pass + self.assertEqual(Baz().x, {}) + + def test_intermediate_non_dataclass(self): + # Test that an intermediate class that defines + # annotations does not define fields. + + @dataclass + class A: + x: int + + class B(A): + y: int + + @dataclass + class C(B): + z: int + + c = C(1, 3) + self.assertEqual((c.x, c.z), (1, 3)) + + # .y was not initialized. + with self.assertRaisesRegex(AttributeError, + 'object has no attribute'): + c.y + + # And if we again derive a non-dataclass, no fields are added. + class D(C): + t: int + d = D(4, 5) + self.assertEqual((d.x, d.z), (4, 5)) + + def x_test_classvar_default_factory(self): # XXX: it's an error for a ClassVar to have a factory function @dataclass diff --git a/Misc/NEWS.d/next/Library/2018-03-20-20-53-21.bpo-32896.ewW3Ln.rst b/Misc/NEWS.d/next/Library/2018-03-20-20-53-21.bpo-32896.ewW3Ln.rst new file mode 100644 index 00000000000..8363da4667a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-03-20-20-53-21.bpo-32896.ewW3Ln.rst @@ -0,0 +1,2 @@ +Fix an error where subclassing a dataclass with a field that uses a +default_factory would generate an incorrect class.