bpo-32953: Dataclasses: frozen should not be inherited for non-dataclass derived classes (#6147)

If a non-dataclass derives from a frozen dataclass, allow attributes to be set.
Require either all of the dataclasses in a class hierarchy to be frozen, or all non-frozen.
Store `@dataclass` parameters on the class object under `__dataclass_params__`. This is needed to detect frozen base classes.
This commit is contained in:
Eric V. Smith 2018-03-18 20:40:34 -04:00 committed by GitHub
parent 3fe33043ee
commit f199bc655e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 168 additions and 55 deletions

View File

@ -171,7 +171,11 @@ _FIELD_INITVAR = object() # Not a field, but an InitVar.
# The name of an attribute on the class where we store the Field
# objects. Also used to check if a class is a Data Class.
_MARKER = '__dataclass_fields__'
_FIELDS = '__dataclass_fields__'
# The name of an attribute on the class that stores the parameters to
# @dataclass.
_PARAMS = '__dataclass_params__'
# The name of the function, that if it exists, is called at the end of
# __init__.
@ -192,7 +196,7 @@ class InitVar(metaclass=_InitVarMeta):
# name and type are filled in after the fact, not in __init__. They're
# not known at the time this class is instantiated, but it's
# convenient if they're available later.
# When cls._MARKER is filled in with a list of Field objects, the name
# When cls._FIELDS is filled in with a list of Field objects, the name
# and type fields will have been populated.
class Field:
__slots__ = ('name',
@ -236,6 +240,32 @@ class Field:
')')
class _DataclassParams:
__slots__ = ('init',
'repr',
'eq',
'order',
'unsafe_hash',
'frozen',
)
def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
self.init = init
self.repr = repr
self.eq = eq
self.order = order
self.unsafe_hash = unsafe_hash
self.frozen = frozen
def __repr__(self):
return ('_DataclassParams('
f'init={self.init},'
f'repr={self.repr},'
f'eq={self.eq},'
f'order={self.order},'
f'unsafe_hash={self.unsafe_hash},'
f'frozen={self.frozen}'
')')
# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
@ -285,6 +315,7 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)
# Compute the text of the entire function.
txt = f'def {name}({args}){return_annotation}:\n{body}'
exec(txt, globals, locals)
@ -432,12 +463,29 @@ def _repr_fn(fields):
')"'])
def _frozen_setattr(self, name, value):
raise FrozenInstanceError(f'cannot assign to field {name!r}')
def _frozen_delattr(self, name):
raise FrozenInstanceError(f'cannot delete field {name!r}')
def _frozen_get_del_attr(cls, fields):
# XXX: globals is modified on the first call to _create_fn, then the
# modified version is used in the second call. Is this okay?
globals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
if fields:
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
else:
# Special case for the zero-length tuple.
fields_str = '()'
return (_create_fn('__setattr__',
('self', 'name', 'value'),
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
globals=globals),
_create_fn('__delattr__',
('self', 'name'),
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
globals=globals),
)
def _cmp_fn(name, op, self_tuple, other_tuple):
@ -583,23 +631,32 @@ _hash_action = {(False, False, False, False): (''),
# version of this table.
def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# Now that dicts retain insertion order, there's no reason to use
# an ordered dict. I am leveraging that ordering here, because
# derived class fields overwrite base class fields, but the order
# is defined by the base class, which is found first.
fields = {}
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
unsafe_hash, frozen))
# Find our base classes in reverse MRO order, and exclude
# ourselves. In reversed order so that more derived classes
# override earlier field definitions in base classes.
# As long as we're iterating over them, see if any are frozen.
any_frozen_base = False
has_dataclass_bases = False
for b in cls.__mro__[-1:0:-1]:
# Only process classes that have been processed by our
# decorator. That is, they have a _MARKER attribute.
base_fields = getattr(b, _MARKER, None)
# decorator. That is, they have a _FIELDS attribute.
base_fields = getattr(b, _FIELDS, None)
if base_fields:
has_dataclass_bases = True
for f in base_fields.values():
fields[f.name] = f
if getattr(b, _PARAMS).frozen:
any_frozen_base = True
# Now find fields in our class. While doing so, validate some
# things, and set the default values (as class attributes)
@ -623,20 +680,21 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
else:
setattr(cls, f.name, f.default)
# We're inheriting from a frozen dataclass, but we're not frozen.
if cls.__setattr__ is _frozen_setattr and not frozen:
raise TypeError('cannot inherit non-frozen dataclass from a '
'frozen one')
# Check rules that apply if we are derived from any dataclasses.
if has_dataclass_bases:
# Raise an exception if any of our bases are frozen, but we're not.
if any_frozen_base and not frozen:
raise TypeError('cannot inherit non-frozen dataclass from a '
'frozen one')
# We're inheriting from a non-frozen dataclass, but we're frozen.
if (hasattr(cls, _MARKER) and cls.__setattr__ is not _frozen_setattr
and frozen):
raise TypeError('cannot inherit frozen dataclass from a '
'non-frozen one')
# Raise an exception if we're frozen, but none of our bases are.
if not any_frozen_base and frozen:
raise TypeError('cannot inherit frozen dataclass from a '
'non-frozen one')
# Remember all of the fields on our class (including bases). This
# Remember all of the fields on our class (including bases). This also
# marks this class as being a dataclass.
setattr(cls, _MARKER, fields)
setattr(cls, _FIELDS, fields)
# Was this class defined with an explicit __hash__? Note that if
# __eq__ is defined in this class, then python will automatically
@ -704,10 +762,10 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
'functools.total_ordering')
if frozen:
for name, fn in [('__setattr__', _frozen_setattr),
('__delattr__', _frozen_delattr)]:
if _set_new_attribute(cls, name, fn):
raise TypeError(f'Cannot overwrite attribute {name} '
# XXX: Which fields are frozen? InitVar? ClassVar? hashed-only?
for fn in _frozen_get_del_attr(cls, field_list):
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
# Decide if/how we're going to create a hash function.
@ -759,7 +817,7 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
"""
def wrap(cls):
return _process_class(cls, repr, eq, order, unsafe_hash, init, frozen)
return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
# See if we're being called as @dataclass or @dataclass().
if _cls is None:
@ -779,7 +837,7 @@ def fields(class_or_instance):
# Might it be worth caching this, per class?
try:
fields = getattr(class_or_instance, _MARKER)
fields = getattr(class_or_instance, _FIELDS)
except AttributeError:
raise TypeError('must be called with a dataclass type or instance')
@ -790,13 +848,13 @@ def fields(class_or_instance):
def _is_dataclass_instance(obj):
"""Returns True if obj is an instance of a dataclass."""
return not isinstance(obj, type) and hasattr(obj, _MARKER)
return not isinstance(obj, type) and hasattr(obj, _FIELDS)
def is_dataclass(obj):
"""Returns True if obj is a dataclass or an instance of a
dataclass."""
return hasattr(obj, _MARKER)
return hasattr(obj, _FIELDS)
def asdict(obj, *, dict_factory=dict):
@ -953,7 +1011,7 @@ def replace(obj, **changes):
# It's an error to have init=False fields in 'changes'.
# If a field is not in 'changes', read its value from the provided obj.
for f in getattr(obj, _MARKER).values():
for f in getattr(obj, _FIELDS).values():
if not f.init:
# Error if this field is specified in changes.
if f.name in changes:

View File

@ -2476,41 +2476,92 @@ class TestFrozen(unittest.TestCase):
d = D(0, 10)
with self.assertRaises(FrozenInstanceError):
d.i = 5
with self.assertRaises(FrozenInstanceError):
d.j = 6
self.assertEqual(d.i, 0)
self.assertEqual(d.j, 10)
def test_inherit_from_nonfrozen_from_frozen(self):
@dataclass(frozen=True)
class C:
i: int
# Test both ways: with an intermediate normal (non-dataclass)
# class and without an intermediate class.
def test_inherit_nonfrozen_from_frozen(self):
for intermediate_class in [True, False]:
with self.subTest(intermediate_class=intermediate_class):
@dataclass(frozen=True)
class C:
i: int
with self.assertRaisesRegex(TypeError,
'cannot inherit non-frozen dataclass from a frozen one'):
@dataclass
class D(C):
pass
if intermediate_class:
class I(C): pass
else:
I = C
def test_inherit_from_frozen_from_nonfrozen(self):
@dataclass
class C:
i: int
with self.assertRaisesRegex(TypeError,
'cannot inherit non-frozen dataclass from a frozen one'):
@dataclass
class D(I):
pass
with self.assertRaisesRegex(TypeError,
'cannot inherit frozen dataclass from a non-frozen one'):
@dataclass(frozen=True)
class D(C):
pass
def test_inherit_frozen_from_nonfrozen(self):
for intermediate_class in [True, False]:
with self.subTest(intermediate_class=intermediate_class):
@dataclass
class C:
i: int
if intermediate_class:
class I(C): pass
else:
I = C
with self.assertRaisesRegex(TypeError,
'cannot inherit frozen dataclass from a non-frozen one'):
@dataclass(frozen=True)
class D(I):
pass
def test_inherit_from_normal_class(self):
class C:
pass
for intermediate_class in [True, False]:
with self.subTest(intermediate_class=intermediate_class):
class C:
pass
if intermediate_class:
class I(C): pass
else:
I = C
@dataclass(frozen=True)
class D(I):
i: int
d = D(10)
with self.assertRaises(FrozenInstanceError):
d.i = 5
def test_non_frozen_normal_derived(self):
# See bpo-32953.
@dataclass(frozen=True)
class D(C):
i: int
class D:
x: int
y: int = 10
d = D(10)
class S(D):
pass
s = S(3)
self.assertEqual(s.x, 3)
self.assertEqual(s.y, 10)
s.cached = True
# But can't change the frozen attributes.
with self.assertRaises(FrozenInstanceError):
d.i = 5
s.x = 5
with self.assertRaises(FrozenInstanceError):
s.y = 5
self.assertEqual(s.x, 3)
self.assertEqual(s.y, 10)
self.assertEqual(s.cached, True)
if __name__ == '__main__':

View File

@ -0,0 +1,4 @@
If a non-dataclass inherits from a frozen dataclass, allow attributes to be
added to the derived class. Only attributes from from the frozen dataclass
cannot be assigned to. Require all dataclasses in a hierarchy to be either
all frozen or all non-frozen.