bpo-32953: Dataclasses: frozen should not be inherited for non-dataclass derived classes (#6147)
If a non-dataclass derives from a frozen dataclass, allow attributes to be set. Require either all of the dataclasses in a class hierarchy to be frozen, or all non-frozen. Store `@dataclass` parameters on the class object under `__dataclass_params__`. This is needed to detect frozen base classes.
This commit is contained in:
parent
3fe33043ee
commit
f199bc655e
|
@ -171,7 +171,11 @@ _FIELD_INITVAR = object() # Not a field, but an InitVar.
|
|||
|
||||
# The name of an attribute on the class where we store the Field
|
||||
# objects. Also used to check if a class is a Data Class.
|
||||
_MARKER = '__dataclass_fields__'
|
||||
_FIELDS = '__dataclass_fields__'
|
||||
|
||||
# The name of an attribute on the class that stores the parameters to
|
||||
# @dataclass.
|
||||
_PARAMS = '__dataclass_params__'
|
||||
|
||||
# The name of the function, that if it exists, is called at the end of
|
||||
# __init__.
|
||||
|
@ -192,7 +196,7 @@ class InitVar(metaclass=_InitVarMeta):
|
|||
# name and type are filled in after the fact, not in __init__. They're
|
||||
# not known at the time this class is instantiated, but it's
|
||||
# convenient if they're available later.
|
||||
# When cls._MARKER is filled in with a list of Field objects, the name
|
||||
# When cls._FIELDS is filled in with a list of Field objects, the name
|
||||
# and type fields will have been populated.
|
||||
class Field:
|
||||
__slots__ = ('name',
|
||||
|
@ -236,6 +240,32 @@ class Field:
|
|||
')')
|
||||
|
||||
|
||||
class _DataclassParams:
|
||||
__slots__ = ('init',
|
||||
'repr',
|
||||
'eq',
|
||||
'order',
|
||||
'unsafe_hash',
|
||||
'frozen',
|
||||
)
|
||||
def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
|
||||
self.init = init
|
||||
self.repr = repr
|
||||
self.eq = eq
|
||||
self.order = order
|
||||
self.unsafe_hash = unsafe_hash
|
||||
self.frozen = frozen
|
||||
|
||||
def __repr__(self):
|
||||
return ('_DataclassParams('
|
||||
f'init={self.init},'
|
||||
f'repr={self.repr},'
|
||||
f'eq={self.eq},'
|
||||
f'order={self.order},'
|
||||
f'unsafe_hash={self.unsafe_hash},'
|
||||
f'frozen={self.frozen}'
|
||||
')')
|
||||
|
||||
# This function is used instead of exposing Field creation directly,
|
||||
# so that a type checker can be told (via overloads) that this is a
|
||||
# function whose type depends on its parameters.
|
||||
|
@ -285,6 +315,7 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
|
|||
args = ','.join(args)
|
||||
body = '\n'.join(f' {b}' for b in body)
|
||||
|
||||
# Compute the text of the entire function.
|
||||
txt = f'def {name}({args}){return_annotation}:\n{body}'
|
||||
|
||||
exec(txt, globals, locals)
|
||||
|
@ -432,12 +463,29 @@ def _repr_fn(fields):
|
|||
')"'])
|
||||
|
||||
|
||||
def _frozen_setattr(self, name, value):
|
||||
raise FrozenInstanceError(f'cannot assign to field {name!r}')
|
||||
|
||||
|
||||
def _frozen_delattr(self, name):
|
||||
raise FrozenInstanceError(f'cannot delete field {name!r}')
|
||||
def _frozen_get_del_attr(cls, fields):
|
||||
# XXX: globals is modified on the first call to _create_fn, then the
|
||||
# modified version is used in the second call. Is this okay?
|
||||
globals = {'cls': cls,
|
||||
'FrozenInstanceError': FrozenInstanceError}
|
||||
if fields:
|
||||
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
|
||||
else:
|
||||
# Special case for the zero-length tuple.
|
||||
fields_str = '()'
|
||||
return (_create_fn('__setattr__',
|
||||
('self', 'name', 'value'),
|
||||
(f'if type(self) is cls or name in {fields_str}:',
|
||||
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
|
||||
f'super(cls, self).__setattr__(name, value)'),
|
||||
globals=globals),
|
||||
_create_fn('__delattr__',
|
||||
('self', 'name'),
|
||||
(f'if type(self) is cls or name in {fields_str}:',
|
||||
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
|
||||
f'super(cls, self).__delattr__(name)'),
|
||||
globals=globals),
|
||||
)
|
||||
|
||||
|
||||
def _cmp_fn(name, op, self_tuple, other_tuple):
|
||||
|
@ -583,23 +631,32 @@ _hash_action = {(False, False, False, False): (''),
|
|||
# version of this table.
|
||||
|
||||
|
||||
def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
|
||||
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
|
||||
# Now that dicts retain insertion order, there's no reason to use
|
||||
# an ordered dict. I am leveraging that ordering here, because
|
||||
# derived class fields overwrite base class fields, but the order
|
||||
# is defined by the base class, which is found first.
|
||||
fields = {}
|
||||
|
||||
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
|
||||
unsafe_hash, frozen))
|
||||
|
||||
# Find our base classes in reverse MRO order, and exclude
|
||||
# ourselves. In reversed order so that more derived classes
|
||||
# override earlier field definitions in base classes.
|
||||
# As long as we're iterating over them, see if any are frozen.
|
||||
any_frozen_base = False
|
||||
has_dataclass_bases = False
|
||||
for b in cls.__mro__[-1:0:-1]:
|
||||
# Only process classes that have been processed by our
|
||||
# decorator. That is, they have a _MARKER attribute.
|
||||
base_fields = getattr(b, _MARKER, None)
|
||||
# decorator. That is, they have a _FIELDS attribute.
|
||||
base_fields = getattr(b, _FIELDS, None)
|
||||
if base_fields:
|
||||
has_dataclass_bases = True
|
||||
for f in base_fields.values():
|
||||
fields[f.name] = f
|
||||
if getattr(b, _PARAMS).frozen:
|
||||
any_frozen_base = True
|
||||
|
||||
# Now find fields in our class. While doing so, validate some
|
||||
# things, and set the default values (as class attributes)
|
||||
|
@ -623,20 +680,21 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
|
|||
else:
|
||||
setattr(cls, f.name, f.default)
|
||||
|
||||
# We're inheriting from a frozen dataclass, but we're not frozen.
|
||||
if cls.__setattr__ is _frozen_setattr and not frozen:
|
||||
raise TypeError('cannot inherit non-frozen dataclass from a '
|
||||
'frozen one')
|
||||
# Check rules that apply if we are derived from any dataclasses.
|
||||
if has_dataclass_bases:
|
||||
# Raise an exception if any of our bases are frozen, but we're not.
|
||||
if any_frozen_base and not frozen:
|
||||
raise TypeError('cannot inherit non-frozen dataclass from a '
|
||||
'frozen one')
|
||||
|
||||
# We're inheriting from a non-frozen dataclass, but we're frozen.
|
||||
if (hasattr(cls, _MARKER) and cls.__setattr__ is not _frozen_setattr
|
||||
and frozen):
|
||||
raise TypeError('cannot inherit frozen dataclass from a '
|
||||
'non-frozen one')
|
||||
# Raise an exception if we're frozen, but none of our bases are.
|
||||
if not any_frozen_base and frozen:
|
||||
raise TypeError('cannot inherit frozen dataclass from a '
|
||||
'non-frozen one')
|
||||
|
||||
# Remember all of the fields on our class (including bases). This
|
||||
# Remember all of the fields on our class (including bases). This also
|
||||
# marks this class as being a dataclass.
|
||||
setattr(cls, _MARKER, fields)
|
||||
setattr(cls, _FIELDS, fields)
|
||||
|
||||
# Was this class defined with an explicit __hash__? Note that if
|
||||
# __eq__ is defined in this class, then python will automatically
|
||||
|
@ -704,10 +762,10 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
|
|||
'functools.total_ordering')
|
||||
|
||||
if frozen:
|
||||
for name, fn in [('__setattr__', _frozen_setattr),
|
||||
('__delattr__', _frozen_delattr)]:
|
||||
if _set_new_attribute(cls, name, fn):
|
||||
raise TypeError(f'Cannot overwrite attribute {name} '
|
||||
# XXX: Which fields are frozen? InitVar? ClassVar? hashed-only?
|
||||
for fn in _frozen_get_del_attr(cls, field_list):
|
||||
if _set_new_attribute(cls, fn.__name__, fn):
|
||||
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
|
||||
f'in class {cls.__name__}')
|
||||
|
||||
# Decide if/how we're going to create a hash function.
|
||||
|
@ -759,7 +817,7 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
|
|||
"""
|
||||
|
||||
def wrap(cls):
|
||||
return _process_class(cls, repr, eq, order, unsafe_hash, init, frozen)
|
||||
return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
|
||||
|
||||
# See if we're being called as @dataclass or @dataclass().
|
||||
if _cls is None:
|
||||
|
@ -779,7 +837,7 @@ def fields(class_or_instance):
|
|||
|
||||
# Might it be worth caching this, per class?
|
||||
try:
|
||||
fields = getattr(class_or_instance, _MARKER)
|
||||
fields = getattr(class_or_instance, _FIELDS)
|
||||
except AttributeError:
|
||||
raise TypeError('must be called with a dataclass type or instance')
|
||||
|
||||
|
@ -790,13 +848,13 @@ def fields(class_or_instance):
|
|||
|
||||
def _is_dataclass_instance(obj):
|
||||
"""Returns True if obj is an instance of a dataclass."""
|
||||
return not isinstance(obj, type) and hasattr(obj, _MARKER)
|
||||
return not isinstance(obj, type) and hasattr(obj, _FIELDS)
|
||||
|
||||
|
||||
def is_dataclass(obj):
|
||||
"""Returns True if obj is a dataclass or an instance of a
|
||||
dataclass."""
|
||||
return hasattr(obj, _MARKER)
|
||||
return hasattr(obj, _FIELDS)
|
||||
|
||||
|
||||
def asdict(obj, *, dict_factory=dict):
|
||||
|
@ -953,7 +1011,7 @@ def replace(obj, **changes):
|
|||
# It's an error to have init=False fields in 'changes'.
|
||||
# If a field is not in 'changes', read its value from the provided obj.
|
||||
|
||||
for f in getattr(obj, _MARKER).values():
|
||||
for f in getattr(obj, _FIELDS).values():
|
||||
if not f.init:
|
||||
# Error if this field is specified in changes.
|
||||
if f.name in changes:
|
||||
|
|
|
@ -2476,41 +2476,92 @@ class TestFrozen(unittest.TestCase):
|
|||
d = D(0, 10)
|
||||
with self.assertRaises(FrozenInstanceError):
|
||||
d.i = 5
|
||||
with self.assertRaises(FrozenInstanceError):
|
||||
d.j = 6
|
||||
self.assertEqual(d.i, 0)
|
||||
self.assertEqual(d.j, 10)
|
||||
|
||||
def test_inherit_from_nonfrozen_from_frozen(self):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
i: int
|
||||
# Test both ways: with an intermediate normal (non-dataclass)
|
||||
# class and without an intermediate class.
|
||||
def test_inherit_nonfrozen_from_frozen(self):
|
||||
for intermediate_class in [True, False]:
|
||||
with self.subTest(intermediate_class=intermediate_class):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
i: int
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'cannot inherit non-frozen dataclass from a frozen one'):
|
||||
@dataclass
|
||||
class D(C):
|
||||
pass
|
||||
if intermediate_class:
|
||||
class I(C): pass
|
||||
else:
|
||||
I = C
|
||||
|
||||
def test_inherit_from_frozen_from_nonfrozen(self):
|
||||
@dataclass
|
||||
class C:
|
||||
i: int
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'cannot inherit non-frozen dataclass from a frozen one'):
|
||||
@dataclass
|
||||
class D(I):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'cannot inherit frozen dataclass from a non-frozen one'):
|
||||
@dataclass(frozen=True)
|
||||
class D(C):
|
||||
pass
|
||||
def test_inherit_frozen_from_nonfrozen(self):
|
||||
for intermediate_class in [True, False]:
|
||||
with self.subTest(intermediate_class=intermediate_class):
|
||||
@dataclass
|
||||
class C:
|
||||
i: int
|
||||
|
||||
if intermediate_class:
|
||||
class I(C): pass
|
||||
else:
|
||||
I = C
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'cannot inherit frozen dataclass from a non-frozen one'):
|
||||
@dataclass(frozen=True)
|
||||
class D(I):
|
||||
pass
|
||||
|
||||
def test_inherit_from_normal_class(self):
|
||||
class C:
|
||||
pass
|
||||
for intermediate_class in [True, False]:
|
||||
with self.subTest(intermediate_class=intermediate_class):
|
||||
class C:
|
||||
pass
|
||||
|
||||
if intermediate_class:
|
||||
class I(C): pass
|
||||
else:
|
||||
I = C
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class D(I):
|
||||
i: int
|
||||
|
||||
d = D(10)
|
||||
with self.assertRaises(FrozenInstanceError):
|
||||
d.i = 5
|
||||
|
||||
def test_non_frozen_normal_derived(self):
|
||||
# See bpo-32953.
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class D(C):
|
||||
i: int
|
||||
class D:
|
||||
x: int
|
||||
y: int = 10
|
||||
|
||||
d = D(10)
|
||||
class S(D):
|
||||
pass
|
||||
|
||||
s = S(3)
|
||||
self.assertEqual(s.x, 3)
|
||||
self.assertEqual(s.y, 10)
|
||||
s.cached = True
|
||||
|
||||
# But can't change the frozen attributes.
|
||||
with self.assertRaises(FrozenInstanceError):
|
||||
d.i = 5
|
||||
s.x = 5
|
||||
with self.assertRaises(FrozenInstanceError):
|
||||
s.y = 5
|
||||
self.assertEqual(s.x, 3)
|
||||
self.assertEqual(s.y, 10)
|
||||
self.assertEqual(s.cached, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
If a non-dataclass inherits from a frozen dataclass, allow attributes to be
|
||||
added to the derived class. Only attributes from from the frozen dataclass
|
||||
cannot be assigned to. Require all dataclasses in a hierarchy to be either
|
||||
all frozen or all non-frozen.
|
Loading…
Reference in New Issue