diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index b55a497db30..8ab04dd5b97 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -171,7 +171,11 @@ _FIELD_INITVAR = object() # Not a field, but an InitVar. # The name of an attribute on the class where we store the Field # objects. Also used to check if a class is a Data Class. -_MARKER = '__dataclass_fields__' +_FIELDS = '__dataclass_fields__' + +# The name of an attribute on the class that stores the parameters to +# @dataclass. +_PARAMS = '__dataclass_params__' # The name of the function, that if it exists, is called at the end of # __init__. @@ -192,7 +196,7 @@ class InitVar(metaclass=_InitVarMeta): # name and type are filled in after the fact, not in __init__. They're # not known at the time this class is instantiated, but it's # convenient if they're available later. -# When cls._MARKER is filled in with a list of Field objects, the name +# When cls._FIELDS is filled in with a list of Field objects, the name # and type fields will have been populated. class Field: __slots__ = ('name', @@ -236,6 +240,32 @@ class Field: ')') +class _DataclassParams: + __slots__ = ('init', + 'repr', + 'eq', + 'order', + 'unsafe_hash', + 'frozen', + ) + def __init__(self, init, repr, eq, order, unsafe_hash, frozen): + self.init = init + self.repr = repr + self.eq = eq + self.order = order + self.unsafe_hash = unsafe_hash + self.frozen = frozen + + def __repr__(self): + return ('_DataclassParams(' + f'init={self.init},' + f'repr={self.repr},' + f'eq={self.eq},' + f'order={self.order},' + f'unsafe_hash={self.unsafe_hash},' + f'frozen={self.frozen}' + ')') + # This function is used instead of exposing Field creation directly, # so that a type checker can be told (via overloads) that this is a # function whose type depends on its parameters. @@ -285,6 +315,7 @@ def _create_fn(name, args, body, *, globals=None, locals=None, args = ','.join(args) body = '\n'.join(f' {b}' for b in body) + # Compute the text of the entire function. txt = f'def {name}({args}){return_annotation}:\n{body}' exec(txt, globals, locals) @@ -432,12 +463,29 @@ def _repr_fn(fields): ')"']) -def _frozen_setattr(self, name, value): - raise FrozenInstanceError(f'cannot assign to field {name!r}') - - -def _frozen_delattr(self, name): - raise FrozenInstanceError(f'cannot delete field {name!r}') +def _frozen_get_del_attr(cls, fields): + # XXX: globals is modified on the first call to _create_fn, then the + # modified version is used in the second call. Is this okay? + globals = {'cls': cls, + 'FrozenInstanceError': FrozenInstanceError} + if fields: + fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)' + else: + # Special case for the zero-length tuple. + fields_str = '()' + return (_create_fn('__setattr__', + ('self', 'name', 'value'), + (f'if type(self) is cls or name in {fields_str}:', + ' raise FrozenInstanceError(f"cannot assign to field {name!r}")', + f'super(cls, self).__setattr__(name, value)'), + globals=globals), + _create_fn('__delattr__', + ('self', 'name'), + (f'if type(self) is cls or name in {fields_str}:', + ' raise FrozenInstanceError(f"cannot delete field {name!r}")', + f'super(cls, self).__delattr__(name)'), + globals=globals), + ) def _cmp_fn(name, op, self_tuple, other_tuple): @@ -583,23 +631,32 @@ _hash_action = {(False, False, False, False): (''), # version of this table. -def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen): +def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): # Now that dicts retain insertion order, there's no reason to use # an ordered dict. I am leveraging that ordering here, because # derived class fields overwrite base class fields, but the order # is defined by the base class, which is found first. fields = {} + setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, + unsafe_hash, frozen)) + # Find our base classes in reverse MRO order, and exclude # ourselves. In reversed order so that more derived classes # override earlier field definitions in base classes. + # As long as we're iterating over them, see if any are frozen. + any_frozen_base = False + has_dataclass_bases = False for b in cls.__mro__[-1:0:-1]: # Only process classes that have been processed by our - # decorator. That is, they have a _MARKER attribute. - base_fields = getattr(b, _MARKER, None) + # decorator. That is, they have a _FIELDS attribute. + base_fields = getattr(b, _FIELDS, None) if base_fields: + has_dataclass_bases = True for f in base_fields.values(): fields[f.name] = f + if getattr(b, _PARAMS).frozen: + any_frozen_base = True # Now find fields in our class. While doing so, validate some # things, and set the default values (as class attributes) @@ -623,20 +680,21 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen): else: setattr(cls, f.name, f.default) - # We're inheriting from a frozen dataclass, but we're not frozen. - if cls.__setattr__ is _frozen_setattr and not frozen: - raise TypeError('cannot inherit non-frozen dataclass from a ' - 'frozen one') + # Check rules that apply if we are derived from any dataclasses. + if has_dataclass_bases: + # Raise an exception if any of our bases are frozen, but we're not. + if any_frozen_base and not frozen: + raise TypeError('cannot inherit non-frozen dataclass from a ' + 'frozen one') - # We're inheriting from a non-frozen dataclass, but we're frozen. - if (hasattr(cls, _MARKER) and cls.__setattr__ is not _frozen_setattr - and frozen): - raise TypeError('cannot inherit frozen dataclass from a ' - 'non-frozen one') + # Raise an exception if we're frozen, but none of our bases are. + if not any_frozen_base and frozen: + raise TypeError('cannot inherit frozen dataclass from a ' + 'non-frozen one') - # Remember all of the fields on our class (including bases). This + # Remember all of the fields on our class (including bases). This also # marks this class as being a dataclass. - setattr(cls, _MARKER, fields) + setattr(cls, _FIELDS, fields) # Was this class defined with an explicit __hash__? Note that if # __eq__ is defined in this class, then python will automatically @@ -704,10 +762,10 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen): 'functools.total_ordering') if frozen: - for name, fn in [('__setattr__', _frozen_setattr), - ('__delattr__', _frozen_delattr)]: - if _set_new_attribute(cls, name, fn): - raise TypeError(f'Cannot overwrite attribute {name} ' + # XXX: Which fields are frozen? InitVar? ClassVar? hashed-only? + for fn in _frozen_get_del_attr(cls, field_list): + if _set_new_attribute(cls, fn.__name__, fn): + raise TypeError(f'Cannot overwrite attribute {fn.__name__} ' f'in class {cls.__name__}') # Decide if/how we're going to create a hash function. @@ -759,7 +817,7 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False, """ def wrap(cls): - return _process_class(cls, repr, eq, order, unsafe_hash, init, frozen) + return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen) # See if we're being called as @dataclass or @dataclass(). if _cls is None: @@ -779,7 +837,7 @@ def fields(class_or_instance): # Might it be worth caching this, per class? try: - fields = getattr(class_or_instance, _MARKER) + fields = getattr(class_or_instance, _FIELDS) except AttributeError: raise TypeError('must be called with a dataclass type or instance') @@ -790,13 +848,13 @@ def fields(class_or_instance): def _is_dataclass_instance(obj): """Returns True if obj is an instance of a dataclass.""" - return not isinstance(obj, type) and hasattr(obj, _MARKER) + return not isinstance(obj, type) and hasattr(obj, _FIELDS) def is_dataclass(obj): """Returns True if obj is a dataclass or an instance of a dataclass.""" - return hasattr(obj, _MARKER) + return hasattr(obj, _FIELDS) def asdict(obj, *, dict_factory=dict): @@ -953,7 +1011,7 @@ def replace(obj, **changes): # It's an error to have init=False fields in 'changes'. # If a field is not in 'changes', read its value from the provided obj. - for f in getattr(obj, _MARKER).values(): + for f in getattr(obj, _FIELDS).values(): if not f.init: # Error if this field is specified in changes. if f.name in changes: diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 46d485c0157..3e672636094 100755 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -2476,41 +2476,92 @@ class TestFrozen(unittest.TestCase): d = D(0, 10) with self.assertRaises(FrozenInstanceError): d.i = 5 + with self.assertRaises(FrozenInstanceError): + d.j = 6 self.assertEqual(d.i, 0) + self.assertEqual(d.j, 10) - def test_inherit_from_nonfrozen_from_frozen(self): - @dataclass(frozen=True) - class C: - i: int + # Test both ways: with an intermediate normal (non-dataclass) + # class and without an intermediate class. + def test_inherit_nonfrozen_from_frozen(self): + for intermediate_class in [True, False]: + with self.subTest(intermediate_class=intermediate_class): + @dataclass(frozen=True) + class C: + i: int - with self.assertRaisesRegex(TypeError, - 'cannot inherit non-frozen dataclass from a frozen one'): - @dataclass - class D(C): - pass + if intermediate_class: + class I(C): pass + else: + I = C - def test_inherit_from_frozen_from_nonfrozen(self): - @dataclass - class C: - i: int + with self.assertRaisesRegex(TypeError, + 'cannot inherit non-frozen dataclass from a frozen one'): + @dataclass + class D(I): + pass - with self.assertRaisesRegex(TypeError, - 'cannot inherit frozen dataclass from a non-frozen one'): - @dataclass(frozen=True) - class D(C): - pass + def test_inherit_frozen_from_nonfrozen(self): + for intermediate_class in [True, False]: + with self.subTest(intermediate_class=intermediate_class): + @dataclass + class C: + i: int + + if intermediate_class: + class I(C): pass + else: + I = C + + with self.assertRaisesRegex(TypeError, + 'cannot inherit frozen dataclass from a non-frozen one'): + @dataclass(frozen=True) + class D(I): + pass def test_inherit_from_normal_class(self): - class C: - pass + for intermediate_class in [True, False]: + with self.subTest(intermediate_class=intermediate_class): + class C: + pass + + if intermediate_class: + class I(C): pass + else: + I = C + + @dataclass(frozen=True) + class D(I): + i: int + + d = D(10) + with self.assertRaises(FrozenInstanceError): + d.i = 5 + + def test_non_frozen_normal_derived(self): + # See bpo-32953. @dataclass(frozen=True) - class D(C): - i: int + class D: + x: int + y: int = 10 - d = D(10) + class S(D): + pass + + s = S(3) + self.assertEqual(s.x, 3) + self.assertEqual(s.y, 10) + s.cached = True + + # But can't change the frozen attributes. with self.assertRaises(FrozenInstanceError): - d.i = 5 + s.x = 5 + with self.assertRaises(FrozenInstanceError): + s.y = 5 + self.assertEqual(s.x, 3) + self.assertEqual(s.y, 10) + self.assertEqual(s.cached, True) if __name__ == '__main__': diff --git a/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst b/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst new file mode 100644 index 00000000000..fbea34aa9a2 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst @@ -0,0 +1,4 @@ +If a non-dataclass inherits from a frozen dataclass, allow attributes to be +added to the derived class. Only attributes from from the frozen dataclass +cannot be assigned to. Require all dataclasses in a hierarchy to be either +all frozen or all non-frozen.