bpo-32896: Fix error when subclassing a dataclass with a field that uses a default_factory (GH-6170)

Fix the way that new annotations in a class are detected.
This commit is contained in:
Eric V. Smith 2018-03-20 22:00:23 -04:00 committed by GitHub
parent 10b134a07c
commit 8f6eccdc64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 10 deletions

View File

@ -574,17 +574,18 @@ def _get_field(cls, a_name, a_type):
def _find_fields(cls):
# Return a list of Field objects, in order, for this class (and no
# base classes). Fields are found from __annotations__ (which is
# guaranteed to be ordered). Default values are from class
# attributes, if a field has a default. If the default value is
# a Field(), then it contains additional info beyond (and
# possibly including) the actual default value. Pseudo-fields
# ClassVars and InitVars are included, despite the fact that
# they're not real fields. That's dealt with later.
# base classes). Fields are found from the class dict's
# __annotations__ (which is guaranteed to be ordered). Default
# values are from class attributes, if a field has a default. If
# the default value is a Field(), then it contains additional
# info beyond (and possibly including) the actual default value.
# Pseudo-fields ClassVars and InitVars are included, despite the
# fact that they're not real fields. That's dealt with later.
annotations = getattr(cls, '__annotations__', {})
return [_get_field(cls, a_name, a_type)
for a_name, a_type in annotations.items()]
# If __annotations__ isn't present, then this class adds no new
# annotations.
annotations = cls.__dict__.get('__annotations__', {})
return [_get_field(cls, name, type) for name, type in annotations.items()]
def _set_new_attribute(cls, name, value):

View File

@ -1147,6 +1147,55 @@ class TestCase(unittest.TestCase):
C().x
self.assertEqual(factory.call_count, 2)
def test_default_factory_derived(self):
# See bpo-32896.
@dataclass
class Foo:
x: dict = field(default_factory=dict)
@dataclass
class Bar(Foo):
y: int = 1
self.assertEqual(Foo().x, {})
self.assertEqual(Bar().x, {})
self.assertEqual(Bar().y, 1)
@dataclass
class Baz(Foo):
pass
self.assertEqual(Baz().x, {})
def test_intermediate_non_dataclass(self):
# Test that an intermediate class that defines
# annotations does not define fields.
@dataclass
class A:
x: int
class B(A):
y: int
@dataclass
class C(B):
z: int
c = C(1, 3)
self.assertEqual((c.x, c.z), (1, 3))
# .y was not initialized.
with self.assertRaisesRegex(AttributeError,
'object has no attribute'):
c.y
# And if we again derive a non-dataclass, no fields are added.
class D(C):
t: int
d = D(4, 5)
self.assertEqual((d.x, d.z), (4, 5))
def x_test_classvar_default_factory(self):
# XXX: it's an error for a ClassVar to have a factory function
@dataclass

View File

@ -0,0 +1,2 @@
Fix an error where subclassing a dataclass with a field that uses a
default_factory would generate an incorrect class.