bpo-32896: Fix error when subclassing a dataclass with a field that uses a default_factory (GH-6170)
Fix the way that new annotations in a class are detected.
This commit is contained in:
parent
10b134a07c
commit
8f6eccdc64
|
@ -574,17 +574,18 @@ def _get_field(cls, a_name, a_type):
|
|||
|
||||
def _find_fields(cls):
|
||||
# Return a list of Field objects, in order, for this class (and no
|
||||
# base classes). Fields are found from __annotations__ (which is
|
||||
# guaranteed to be ordered). Default values are from class
|
||||
# attributes, if a field has a default. If the default value is
|
||||
# a Field(), then it contains additional info beyond (and
|
||||
# possibly including) the actual default value. Pseudo-fields
|
||||
# ClassVars and InitVars are included, despite the fact that
|
||||
# they're not real fields. That's dealt with later.
|
||||
# base classes). Fields are found from the class dict's
|
||||
# __annotations__ (which is guaranteed to be ordered). Default
|
||||
# values are from class attributes, if a field has a default. If
|
||||
# the default value is a Field(), then it contains additional
|
||||
# info beyond (and possibly including) the actual default value.
|
||||
# Pseudo-fields ClassVars and InitVars are included, despite the
|
||||
# fact that they're not real fields. That's dealt with later.
|
||||
|
||||
annotations = getattr(cls, '__annotations__', {})
|
||||
return [_get_field(cls, a_name, a_type)
|
||||
for a_name, a_type in annotations.items()]
|
||||
# If __annotations__ isn't present, then this class adds no new
|
||||
# annotations.
|
||||
annotations = cls.__dict__.get('__annotations__', {})
|
||||
return [_get_field(cls, name, type) for name, type in annotations.items()]
|
||||
|
||||
|
||||
def _set_new_attribute(cls, name, value):
|
||||
|
|
|
@ -1147,6 +1147,55 @@ class TestCase(unittest.TestCase):
|
|||
C().x
|
||||
self.assertEqual(factory.call_count, 2)
|
||||
|
||||
def test_default_factory_derived(self):
|
||||
# See bpo-32896.
|
||||
@dataclass
|
||||
class Foo:
|
||||
x: dict = field(default_factory=dict)
|
||||
|
||||
@dataclass
|
||||
class Bar(Foo):
|
||||
y: int = 1
|
||||
|
||||
self.assertEqual(Foo().x, {})
|
||||
self.assertEqual(Bar().x, {})
|
||||
self.assertEqual(Bar().y, 1)
|
||||
|
||||
@dataclass
|
||||
class Baz(Foo):
|
||||
pass
|
||||
self.assertEqual(Baz().x, {})
|
||||
|
||||
def test_intermediate_non_dataclass(self):
|
||||
# Test that an intermediate class that defines
|
||||
# annotations does not define fields.
|
||||
|
||||
@dataclass
|
||||
class A:
|
||||
x: int
|
||||
|
||||
class B(A):
|
||||
y: int
|
||||
|
||||
@dataclass
|
||||
class C(B):
|
||||
z: int
|
||||
|
||||
c = C(1, 3)
|
||||
self.assertEqual((c.x, c.z), (1, 3))
|
||||
|
||||
# .y was not initialized.
|
||||
with self.assertRaisesRegex(AttributeError,
|
||||
'object has no attribute'):
|
||||
c.y
|
||||
|
||||
# And if we again derive a non-dataclass, no fields are added.
|
||||
class D(C):
|
||||
t: int
|
||||
d = D(4, 5)
|
||||
self.assertEqual((d.x, d.z), (4, 5))
|
||||
|
||||
|
||||
def x_test_classvar_default_factory(self):
|
||||
# XXX: it's an error for a ClassVar to have a factory function
|
||||
@dataclass
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Fix an error where subclassing a dataclass with a field that uses a
|
||||
default_factory would generate an incorrect class.
|
Loading…
Reference in New Issue