bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)
Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same thing applies for the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder methods is defined in the class, it will not overwrite it.
This commit is contained in:
parent
2a2247ce5e
commit
ea8fc52e75
|
@ -18,6 +18,142 @@ __all__ = ['dataclass',
|
|||
'is_dataclass',
|
||||
]
|
||||
|
||||
# Conditions for adding methods. The boxes indicate what action the
|
||||
# dataclass decorator takes. For all of these tables, when I talk
|
||||
# about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring
|
||||
# to the arguments to the @dataclass decorator. When checking if a
|
||||
# dunder method already exists, I mean check for an entry in the
|
||||
# class's __dict__. I never check to see if an attribute is defined
|
||||
# in a base class.
|
||||
|
||||
# Key:
|
||||
# +=========+=========================================+
|
||||
# + Value | Meaning |
|
||||
# +=========+=========================================+
|
||||
# | <blank> | No action: no method is added. |
|
||||
# +---------+-----------------------------------------+
|
||||
# | add | Generated method is added. |
|
||||
# +---------+-----------------------------------------+
|
||||
# | add* | Generated method is added only if the |
|
||||
# | | existing attribute is None and if the |
|
||||
# | | user supplied a __eq__ method in the |
|
||||
# | | class definition. |
|
||||
# +---------+-----------------------------------------+
|
||||
# | raise | TypeError is raised. |
|
||||
# +---------+-----------------------------------------+
|
||||
# | None | Attribute is set to None. |
|
||||
# +=========+=========================================+
|
||||
|
||||
# __init__
|
||||
#
|
||||
# +--- init= parameter
|
||||
# |
|
||||
# v | | |
|
||||
# | no | yes | <--- class has __init__ in __dict__?
|
||||
# +=======+=======+=======+
|
||||
# | False | | |
|
||||
# +-------+-------+-------+
|
||||
# | True | add | | <- the default
|
||||
# +=======+=======+=======+
|
||||
|
||||
# __repr__
|
||||
#
|
||||
# +--- repr= parameter
|
||||
# |
|
||||
# v | | |
|
||||
# | no | yes | <--- class has __repr__ in __dict__?
|
||||
# +=======+=======+=======+
|
||||
# | False | | |
|
||||
# +-------+-------+-------+
|
||||
# | True | add | | <- the default
|
||||
# +=======+=======+=======+
|
||||
|
||||
|
||||
# __setattr__
|
||||
# __delattr__
|
||||
#
|
||||
# +--- frozen= parameter
|
||||
# |
|
||||
# v | | |
|
||||
# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__?
|
||||
# +=======+=======+=======+
|
||||
# | False | | | <- the default
|
||||
# +-------+-------+-------+
|
||||
# | True | add | raise |
|
||||
# +=======+=======+=======+
|
||||
# Raise because not adding these methods would break the "frozen-ness"
|
||||
# of the class.
|
||||
|
||||
# __eq__
|
||||
#
|
||||
# +--- eq= parameter
|
||||
# |
|
||||
# v | | |
|
||||
# | no | yes | <--- class has __eq__ in __dict__?
|
||||
# +=======+=======+=======+
|
||||
# | False | | |
|
||||
# +-------+-------+-------+
|
||||
# | True | add | | <- the default
|
||||
# +=======+=======+=======+
|
||||
|
||||
# __lt__
|
||||
# __le__
|
||||
# __gt__
|
||||
# __ge__
|
||||
#
|
||||
# +--- order= parameter
|
||||
# |
|
||||
# v | | |
|
||||
# | no | yes | <--- class has any comparison method in __dict__?
|
||||
# +=======+=======+=======+
|
||||
# | False | | | <- the default
|
||||
# +-------+-------+-------+
|
||||
# | True | add | raise |
|
||||
# +=======+=======+=======+
|
||||
# Raise because to allow this case would interfere with using
|
||||
# functools.total_ordering.
|
||||
|
||||
# __hash__
|
||||
|
||||
# +------------------- hash= parameter
|
||||
# | +----------- eq= parameter
|
||||
# | | +--- frozen= parameter
|
||||
# | | |
|
||||
# v v v | | |
|
||||
# | no | yes | <--- class has __hash__ in __dict__?
|
||||
# +=========+=======+=======+========+========+
|
||||
# | 1 None | False | False | | | No __eq__, use the base class __hash__
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 2 None | False | True | | | No __eq__, use the base class __hash__
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 3 None | True | False | None | | <-- the default, not hashable
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 4 None | True | True | add | add* | Frozen, so hashable
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 5 False | False | False | | |
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 6 False | False | True | | |
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 7 False | True | False | | |
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 8 False | True | True | | |
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# | 9 True | False | False | add | add* | Has no __eq__, but hashable
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# |10 True | False | True | add | add* | Has no __eq__, but hashable
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# |11 True | True | False | add | add* | Not frozen, but hashable
|
||||
# +---------+-------+-------+--------+--------+
|
||||
# |12 True | True | True | add | add* | Frozen, so hashable
|
||||
# +=========+=======+=======+========+========+
|
||||
# For boxes that are blank, __hash__ is untouched and therefore
|
||||
# inherited from the base class. If the base is object, then
|
||||
# id-based hashing is used.
|
||||
# Note that a class may have already __hash__=None if it specified an
|
||||
# __eq__ method in the class body (not one that was created by
|
||||
# @dataclass).
|
||||
|
||||
|
||||
# Raised when an attempt is made to modify a frozen class.
|
||||
class FrozenInstanceError(AttributeError): pass
|
||||
|
||||
|
@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields):
|
|||
# return "(self.x,self.y)".
|
||||
|
||||
# Special case for the 0-tuple.
|
||||
if len(fields) == 0:
|
||||
if not fields:
|
||||
return '()'
|
||||
# Note the trailing comma, needed if this turns out to be a 1-tuple.
|
||||
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
|
||||
|
||||
|
||||
def _create_fn(name, args, body, globals=None, locals=None,
|
||||
def _create_fn(name, args, body, *, globals=None, locals=None,
|
||||
return_type=MISSING):
|
||||
# Note that we mutate locals when exec() is called. Caller beware!
|
||||
if locals is None:
|
||||
|
@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name):
|
|||
body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})']
|
||||
|
||||
# If no body lines, use 'pass'.
|
||||
if len(body_lines) == 0:
|
||||
if not body_lines:
|
||||
body_lines = ['pass']
|
||||
|
||||
locals = {f'_type_{f.name}': f.type for f in fields}
|
||||
|
@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
|
|||
'return NotImplemented'])
|
||||
|
||||
|
||||
def _set_eq_fns(cls, fields):
|
||||
# Create and set the equality comparison methods on cls.
|
||||
# Pre-compute self_tuple and other_tuple, then re-use them for
|
||||
# each function.
|
||||
self_tuple = _tuple_str('self', fields)
|
||||
other_tuple = _tuple_str('other', fields)
|
||||
for name, op in [('__eq__', '=='),
|
||||
('__ne__', '!='),
|
||||
]:
|
||||
_set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
|
||||
|
||||
|
||||
def _set_order_fns(cls, fields):
|
||||
# Create and set the ordering methods on cls.
|
||||
# Pre-compute self_tuple and other_tuple, then re-use them for
|
||||
# each function.
|
||||
self_tuple = _tuple_str('self', fields)
|
||||
other_tuple = _tuple_str('other', fields)
|
||||
for name, op in [('__lt__', '<'),
|
||||
('__le__', '<='),
|
||||
('__gt__', '>'),
|
||||
('__ge__', '>='),
|
||||
]:
|
||||
_set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
|
||||
|
||||
|
||||
def _hash_fn(fields):
|
||||
self_tuple = _tuple_str('self', fields)
|
||||
return _create_fn('__hash__',
|
||||
|
@ -431,20 +541,20 @@ def _find_fields(cls):
|
|||
# a Field(), then it contains additional info beyond (and
|
||||
# possibly including) the actual default value. Pseudo-fields
|
||||
# ClassVars and InitVars are included, despite the fact that
|
||||
# they're not real fields. That's deal with later.
|
||||
# they're not real fields. That's dealt with later.
|
||||
|
||||
annotations = getattr(cls, '__annotations__', {})
|
||||
|
||||
return [_get_field(cls, a_name, a_type)
|
||||
for a_name, a_type in annotations.items()]
|
||||
|
||||
|
||||
def _set_attribute(cls, name, value):
|
||||
# Raise TypeError if an attribute by this name already exists.
|
||||
def _set_new_attribute(cls, name, value):
|
||||
# Never overwrites an existing attribute. Returns True if the
|
||||
# attribute already exists.
|
||||
if name in cls.__dict__:
|
||||
raise TypeError(f'Cannot overwrite attribute {name} '
|
||||
f'in {cls.__name__}')
|
||||
return True
|
||||
setattr(cls, name, value)
|
||||
return False
|
||||
|
||||
|
||||
def _process_class(cls, repr, eq, order, hash, init, frozen):
|
||||
|
@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
|
|||
# be inherited down.
|
||||
is_frozen = frozen or cls.__setattr__ is _frozen_setattr
|
||||
|
||||
# Was this class defined with an __eq__? Used in __hash__ logic.
|
||||
auto_hash_test= '__eq__' in cls.__dict__ and getattr(cls.__dict__, '__hash__', MISSING) is None
|
||||
|
||||
# If we're generating ordering methods, we must be generating
|
||||
# the eq methods.
|
||||
if order and not eq:
|
||||
|
@ -505,10 +618,10 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
|
|||
has_post_init = hasattr(cls, _POST_INIT_NAME)
|
||||
|
||||
# Include InitVars and regular fields (so, not ClassVars).
|
||||
_set_attribute(cls, '__init__',
|
||||
_init_fn(list(filter(lambda f: f._field_type
|
||||
in (_FIELD, _FIELD_INITVAR),
|
||||
fields.values())),
|
||||
flds = [f for f in fields.values()
|
||||
if f._field_type in (_FIELD, _FIELD_INITVAR)]
|
||||
_set_new_attribute(cls, '__init__',
|
||||
_init_fn(flds,
|
||||
is_frozen,
|
||||
has_post_init,
|
||||
# The name to use for the "self" param
|
||||
|
@ -519,48 +632,77 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
|
|||
|
||||
# Get the fields as a list, and include only real fields. This is
|
||||
# used in all of the following methods.
|
||||
field_list = list(filter(lambda f: f._field_type is _FIELD,
|
||||
fields.values()))
|
||||
field_list = [f for f in fields.values() if f._field_type is _FIELD]
|
||||
|
||||
if repr:
|
||||
_set_attribute(cls, '__repr__',
|
||||
_repr_fn(list(filter(lambda f: f.repr, field_list))))
|
||||
|
||||
if is_frozen:
|
||||
_set_attribute(cls, '__setattr__', _frozen_setattr)
|
||||
_set_attribute(cls, '__delattr__', _frozen_delattr)
|
||||
|
||||
generate_hash = False
|
||||
if hash is None:
|
||||
if eq and frozen:
|
||||
# Generate a hash function.
|
||||
generate_hash = True
|
||||
elif eq and not frozen:
|
||||
# Not hashable.
|
||||
_set_attribute(cls, '__hash__', None)
|
||||
elif not eq:
|
||||
# Otherwise, use the base class definition of hash(). That is,
|
||||
# don't set anything on this class.
|
||||
pass
|
||||
else:
|
||||
assert "can't get here"
|
||||
else:
|
||||
generate_hash = hash
|
||||
if generate_hash:
|
||||
_set_attribute(cls, '__hash__',
|
||||
_hash_fn(list(filter(lambda f: f.compare
|
||||
if f.hash is None
|
||||
else f.hash,
|
||||
field_list))))
|
||||
flds = [f for f in field_list if f.repr]
|
||||
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
|
||||
|
||||
if eq:
|
||||
# Create and __eq__ and __ne__ methods.
|
||||
_set_eq_fns(cls, list(filter(lambda f: f.compare, field_list)))
|
||||
# Create _eq__ method. There's no need for a __ne__ method,
|
||||
# since python will call __eq__ and negate it.
|
||||
flds = [f for f in field_list if f.compare]
|
||||
self_tuple = _tuple_str('self', flds)
|
||||
other_tuple = _tuple_str('other', flds)
|
||||
_set_new_attribute(cls, '__eq__',
|
||||
_cmp_fn('__eq__', '==',
|
||||
self_tuple, other_tuple))
|
||||
|
||||
if order:
|
||||
# Create and __lt__, __le__, __gt__, and __ge__ methods.
|
||||
# Create and set the comparison functions.
|
||||
_set_order_fns(cls, list(filter(lambda f: f.compare, field_list)))
|
||||
# Create and set the ordering methods.
|
||||
flds = [f for f in field_list if f.compare]
|
||||
self_tuple = _tuple_str('self', flds)
|
||||
other_tuple = _tuple_str('other', flds)
|
||||
for name, op in [('__lt__', '<'),
|
||||
('__le__', '<='),
|
||||
('__gt__', '>'),
|
||||
('__ge__', '>='),
|
||||
]:
|
||||
if _set_new_attribute(cls, name,
|
||||
_cmp_fn(name, op, self_tuple, other_tuple)):
|
||||
raise TypeError(f'Cannot overwrite attribute {name} '
|
||||
f'in {cls.__name__}. Consider using '
|
||||
'functools.total_ordering')
|
||||
|
||||
if is_frozen:
|
||||
for name, fn in [('__setattr__', _frozen_setattr),
|
||||
('__delattr__', _frozen_delattr)]:
|
||||
if _set_new_attribute(cls, name, fn):
|
||||
raise TypeError(f'Cannot overwrite attribute {name} '
|
||||
f'in {cls.__name__}')
|
||||
|
||||
# Decide if/how we're going to create a hash function.
|
||||
# TODO: Move this table to module scope, so it's not recreated
|
||||
# all the time.
|
||||
generate_hash = {(None, False, False): ('', ''),
|
||||
(None, False, True): ('', ''),
|
||||
(None, True, False): ('none', ''),
|
||||
(None, True, True): ('fn', 'fn-x'),
|
||||
(False, False, False): ('', ''),
|
||||
(False, False, True): ('', ''),
|
||||
(False, True, False): ('', ''),
|
||||
(False, True, True): ('', ''),
|
||||
(True, False, False): ('fn', 'fn-x'),
|
||||
(True, False, True): ('fn', 'fn-x'),
|
||||
(True, True, False): ('fn', 'fn-x'),
|
||||
(True, True, True): ('fn', 'fn-x'),
|
||||
}[None if hash is None else bool(hash), # Force bool() if not None.
|
||||
bool(eq),
|
||||
bool(frozen)]['__hash__' in cls.__dict__]
|
||||
# No need to call _set_new_attribute here, since we already know if
|
||||
# we're overwriting a __hash__ or not.
|
||||
if generate_hash == '':
|
||||
# Do nothing.
|
||||
pass
|
||||
elif generate_hash == 'none':
|
||||
cls.__hash__ = None
|
||||
elif generate_hash in ('fn', 'fn-x'):
|
||||
if generate_hash == 'fn' or auto_hash_test:
|
||||
flds = [f for f in field_list
|
||||
if (f.compare if f.hash is None else f.hash)]
|
||||
cls.__hash__ = _hash_fn(flds)
|
||||
else:
|
||||
assert False, f"can't get here: {generate_hash}"
|
||||
|
||||
if not getattr(cls, '__doc__'):
|
||||
# Create a class doc-string.
|
||||
|
|
|
@ -9,6 +9,7 @@ import unittest
|
|||
from unittest.mock import Mock
|
||||
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
|
||||
from collections import deque, OrderedDict, namedtuple
|
||||
from functools import total_ordering
|
||||
|
||||
# Just any custom exception we can catch.
|
||||
class CustomError(Exception): pass
|
||||
|
@ -82,63 +83,7 @@ class TestCase(unittest.TestCase):
|
|||
class C(B):
|
||||
x: int = 0
|
||||
|
||||
def test_overwriting_init(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __init__ '
|
||||
'in C'):
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
def __init__(self, x):
|
||||
self.x = 2 * x
|
||||
|
||||
@dataclass(init=False)
|
||||
class C:
|
||||
x: int
|
||||
def __init__(self, x):
|
||||
self.x = 2 * x
|
||||
self.assertEqual(C(5).x, 10)
|
||||
|
||||
def test_overwriting_repr(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __repr__ '
|
||||
'in C'):
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
def __repr__(self):
|
||||
pass
|
||||
|
||||
@dataclass(repr=False)
|
||||
class C:
|
||||
x: int
|
||||
def __repr__(self):
|
||||
return 'x'
|
||||
self.assertEqual(repr(C(0)), 'x')
|
||||
|
||||
def test_overwriting_cmp(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __eq__ '
|
||||
'in C'):
|
||||
# This will generate the comparison functions, make sure we can't
|
||||
# overwrite them.
|
||||
@dataclass(hash=False, frozen=False)
|
||||
class C:
|
||||
x: int
|
||||
def __eq__(self):
|
||||
pass
|
||||
|
||||
@dataclass(order=False, eq=False)
|
||||
class C:
|
||||
x: int
|
||||
def __eq__(self, other):
|
||||
return True
|
||||
self.assertEqual(C(0), 'x')
|
||||
|
||||
def test_overwriting_hash(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __hash__ '
|
||||
'in C'):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
x: int
|
||||
|
@ -152,9 +97,6 @@ class TestCase(unittest.TestCase):
|
|||
return 600
|
||||
self.assertEqual(hash(C(0)), 600)
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __hash__ '
|
||||
'in C'):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
x: int
|
||||
|
@ -168,33 +110,6 @@ class TestCase(unittest.TestCase):
|
|||
return 600
|
||||
self.assertEqual(hash(C(0)), 600)
|
||||
|
||||
def test_overwriting_frozen(self):
|
||||
# frozen uses __setattr__ and __delattr__
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __setattr__ '
|
||||
'in C'):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
x: int
|
||||
def __setattr__(self):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __delattr__ '
|
||||
'in C'):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
x: int
|
||||
def __delattr__(self):
|
||||
pass
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class C:
|
||||
x: int
|
||||
def __setattr__(self, name, value):
|
||||
self.__dict__['x'] = value * 2
|
||||
self.assertEqual(C(10).x, 20)
|
||||
|
||||
def test_overwrite_fields_in_derived_class(self):
|
||||
# Note that x from C1 replaces x in Base, but the order remains
|
||||
# the same as defined in Base.
|
||||
|
@ -239,34 +154,6 @@ class TestCase(unittest.TestCase):
|
|||
first = next(iter(sig.parameters))
|
||||
self.assertEqual('self', first)
|
||||
|
||||
def test_repr(self):
|
||||
@dataclass
|
||||
class B:
|
||||
x: int
|
||||
|
||||
@dataclass
|
||||
class C(B):
|
||||
y: int = 10
|
||||
|
||||
o = C(4)
|
||||
self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
|
||||
|
||||
@dataclass
|
||||
class D(C):
|
||||
x: int = 20
|
||||
self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
@dataclass
|
||||
class D:
|
||||
i: int
|
||||
@dataclass
|
||||
class E:
|
||||
pass
|
||||
self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
|
||||
self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
|
||||
|
||||
def test_0_field_compare(self):
|
||||
# Ensure that order=False is the default.
|
||||
@dataclass
|
||||
|
@ -420,80 +307,8 @@ class TestCase(unittest.TestCase):
|
|||
self.assertEqual(hash(C(4)), hash((4,)))
|
||||
self.assertEqual(hash(C(42)), hash((42,)))
|
||||
|
||||
def test_hash(self):
|
||||
@dataclass(hash=True)
|
||||
class C:
|
||||
x: int
|
||||
y: str
|
||||
self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
|
||||
|
||||
def test_no_hash(self):
|
||||
@dataclass(hash=None)
|
||||
class C:
|
||||
x: int
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"unhashable type: 'C'"):
|
||||
hash(C(1))
|
||||
|
||||
def test_hash_rules(self):
|
||||
# There are 24 cases of:
|
||||
# hash=True/False/None
|
||||
# eq=True/False
|
||||
# order=True/False
|
||||
# frozen=True/False
|
||||
for (hash, eq, order, frozen, result ) in [
|
||||
(False, False, False, False, 'absent'),
|
||||
(False, False, False, True, 'absent'),
|
||||
(False, False, True, False, 'exception'),
|
||||
(False, False, True, True, 'exception'),
|
||||
(False, True, False, False, 'absent'),
|
||||
(False, True, False, True, 'absent'),
|
||||
(False, True, True, False, 'absent'),
|
||||
(False, True, True, True, 'absent'),
|
||||
(True, False, False, False, 'fn'),
|
||||
(True, False, False, True, 'fn'),
|
||||
(True, False, True, False, 'exception'),
|
||||
(True, False, True, True, 'exception'),
|
||||
(True, True, False, False, 'fn'),
|
||||
(True, True, False, True, 'fn'),
|
||||
(True, True, True, False, 'fn'),
|
||||
(True, True, True, True, 'fn'),
|
||||
(None, False, False, False, 'absent'),
|
||||
(None, False, False, True, 'absent'),
|
||||
(None, False, True, False, 'exception'),
|
||||
(None, False, True, True, 'exception'),
|
||||
(None, True, False, False, 'none'),
|
||||
(None, True, False, True, 'fn'),
|
||||
(None, True, True, False, 'none'),
|
||||
(None, True, True, True, 'fn'),
|
||||
]:
|
||||
with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
|
||||
if result == 'exception':
|
||||
with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
|
||||
@dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
|
||||
class C:
|
||||
pass
|
||||
else:
|
||||
@dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
|
||||
class C:
|
||||
pass
|
||||
|
||||
# See if the result matches what's expected.
|
||||
if result == 'fn':
|
||||
# __hash__ contains the function we generated.
|
||||
self.assertIn('__hash__', C.__dict__)
|
||||
self.assertIsNotNone(C.__dict__['__hash__'])
|
||||
elif result == 'absent':
|
||||
# __hash__ is not present in our class.
|
||||
self.assertNotIn('__hash__', C.__dict__)
|
||||
elif result == 'none':
|
||||
# __hash__ is set to None.
|
||||
self.assertIn('__hash__', C.__dict__)
|
||||
self.assertIsNone(C.__dict__['__hash__'])
|
||||
else:
|
||||
assert False, f'unknown result {result!r}'
|
||||
|
||||
def test_eq_order(self):
|
||||
# Test combining eq and order.
|
||||
for (eq, order, result ) in [
|
||||
(False, False, 'neither'),
|
||||
(False, True, 'exception'),
|
||||
|
@ -513,21 +328,18 @@ class TestCase(unittest.TestCase):
|
|||
|
||||
if result == 'neither':
|
||||
self.assertNotIn('__eq__', C.__dict__)
|
||||
self.assertNotIn('__ne__', C.__dict__)
|
||||
self.assertNotIn('__lt__', C.__dict__)
|
||||
self.assertNotIn('__le__', C.__dict__)
|
||||
self.assertNotIn('__gt__', C.__dict__)
|
||||
self.assertNotIn('__ge__', C.__dict__)
|
||||
elif result == 'both':
|
||||
self.assertIn('__eq__', C.__dict__)
|
||||
self.assertIn('__ne__', C.__dict__)
|
||||
self.assertIn('__lt__', C.__dict__)
|
||||
self.assertIn('__le__', C.__dict__)
|
||||
self.assertIn('__gt__', C.__dict__)
|
||||
self.assertIn('__ge__', C.__dict__)
|
||||
elif result == 'eq_only':
|
||||
self.assertIn('__eq__', C.__dict__)
|
||||
self.assertIn('__ne__', C.__dict__)
|
||||
self.assertNotIn('__lt__', C.__dict__)
|
||||
self.assertNotIn('__le__', C.__dict__)
|
||||
self.assertNotIn('__gt__', C.__dict__)
|
||||
|
@ -811,19 +623,6 @@ class TestCase(unittest.TestCase):
|
|||
y: int
|
||||
self.assertNotEqual(Point(1, 3), C(1, 3))
|
||||
|
||||
def test_base_has_init(self):
|
||||
class B:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
# Make sure that declaring this class doesn't raise an error.
|
||||
# The issue is that we can't override __init__ in our class,
|
||||
# but it should be okay to add __init__ to us if our base has
|
||||
# an __init__.
|
||||
@dataclass
|
||||
class C(B):
|
||||
x: int = 0
|
||||
|
||||
def test_frozen(self):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
|
@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase):
|
|||
'y': int,
|
||||
'z': 'typing.Any'})
|
||||
|
||||
|
||||
class TestDocString(unittest.TestCase):
|
||||
def assertDocStrEqual(self, a, b):
|
||||
# Because 3.6 and 3.7 differ in how inspect.signature work
|
||||
|
@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase):
|
|||
self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
|
||||
|
||||
|
||||
class TestInit(unittest.TestCase):
|
||||
def test_base_has_init(self):
|
||||
class B:
|
||||
def __init__(self):
|
||||
self.z = 100
|
||||
pass
|
||||
|
||||
# Make sure that declaring this class doesn't raise an error.
|
||||
# The issue is that we can't override __init__ in our class,
|
||||
# but it should be okay to add __init__ to us if our base has
|
||||
# an __init__.
|
||||
@dataclass
|
||||
class C(B):
|
||||
x: int = 0
|
||||
c = C(10)
|
||||
self.assertEqual(c.x, 10)
|
||||
self.assertNotIn('z', vars(c))
|
||||
|
||||
# Make sure that if we don't add an init, the base __init__
|
||||
# gets called.
|
||||
@dataclass(init=False)
|
||||
class C(B):
|
||||
x: int = 10
|
||||
c = C()
|
||||
self.assertEqual(c.x, 10)
|
||||
self.assertEqual(c.z, 100)
|
||||
|
||||
def test_no_init(self):
|
||||
dataclass(init=False)
|
||||
class C:
|
||||
i: int = 0
|
||||
self.assertEqual(C().i, 0)
|
||||
|
||||
dataclass(init=False)
|
||||
class C:
|
||||
i: int = 2
|
||||
def __init__(self):
|
||||
self.i = 3
|
||||
self.assertEqual(C().i, 3)
|
||||
|
||||
def test_overwriting_init(self):
|
||||
# If the class has __init__, use it no matter the value of
|
||||
# init=.
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
def __init__(self, x):
|
||||
self.x = 2 * x
|
||||
self.assertEqual(C(3).x, 6)
|
||||
|
||||
@dataclass(init=True)
|
||||
class C:
|
||||
x: int
|
||||
def __init__(self, x):
|
||||
self.x = 2 * x
|
||||
self.assertEqual(C(4).x, 8)
|
||||
|
||||
@dataclass(init=False)
|
||||
class C:
|
||||
x: int
|
||||
def __init__(self, x):
|
||||
self.x = 2 * x
|
||||
self.assertEqual(C(5).x, 10)
|
||||
|
||||
|
||||
class TestRepr(unittest.TestCase):
|
||||
def test_repr(self):
|
||||
@dataclass
|
||||
class B:
|
||||
x: int
|
||||
|
||||
@dataclass
|
||||
class C(B):
|
||||
y: int = 10
|
||||
|
||||
o = C(4)
|
||||
self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
|
||||
|
||||
@dataclass
|
||||
class D(C):
|
||||
x: int = 20
|
||||
self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
@dataclass
|
||||
class D:
|
||||
i: int
|
||||
@dataclass
|
||||
class E:
|
||||
pass
|
||||
self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
|
||||
self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
|
||||
|
||||
def test_no_repr(self):
|
||||
# Test a class with no __repr__ and repr=False.
|
||||
@dataclass(repr=False)
|
||||
class C:
|
||||
x: int
|
||||
self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
|
||||
repr(C(3)))
|
||||
|
||||
# Test a class with a __repr__ and repr=False.
|
||||
@dataclass(repr=False)
|
||||
class C:
|
||||
x: int
|
||||
def __repr__(self):
|
||||
return 'C-class'
|
||||
self.assertEqual(repr(C(3)), 'C-class')
|
||||
|
||||
def test_overwriting_repr(self):
|
||||
# If the class has __repr__, use it no matter the value of
|
||||
# repr=.
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
def __repr__(self):
|
||||
return 'x'
|
||||
self.assertEqual(repr(C(0)), 'x')
|
||||
|
||||
@dataclass(repr=True)
|
||||
class C:
|
||||
x: int
|
||||
def __repr__(self):
|
||||
return 'x'
|
||||
self.assertEqual(repr(C(0)), 'x')
|
||||
|
||||
@dataclass(repr=False)
|
||||
class C:
|
||||
x: int
|
||||
def __repr__(self):
|
||||
return 'x'
|
||||
self.assertEqual(repr(C(0)), 'x')
|
||||
|
||||
|
||||
class TestFrozen(unittest.TestCase):
|
||||
def test_overwriting_frozen(self):
|
||||
# frozen uses __setattr__ and __delattr__
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __setattr__'):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
x: int
|
||||
def __setattr__(self):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __delattr__'):
|
||||
@dataclass(frozen=True)
|
||||
class C:
|
||||
x: int
|
||||
def __delattr__(self):
|
||||
pass
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class C:
|
||||
x: int
|
||||
def __setattr__(self, name, value):
|
||||
self.__dict__['x'] = value * 2
|
||||
self.assertEqual(C(10).x, 20)
|
||||
|
||||
|
||||
class TestEq(unittest.TestCase):
|
||||
def test_no_eq(self):
|
||||
# Test a class with no __eq__ and eq=False.
|
||||
@dataclass(eq=False)
|
||||
class C:
|
||||
x: int
|
||||
self.assertNotEqual(C(0), C(0))
|
||||
c = C(3)
|
||||
self.assertEqual(c, c)
|
||||
|
||||
# Test a class with an __eq__ and eq=False.
|
||||
@dataclass(eq=False)
|
||||
class C:
|
||||
x: int
|
||||
def __eq__(self, other):
|
||||
return other == 10
|
||||
self.assertEqual(C(3), 10)
|
||||
|
||||
def test_overwriting_eq(self):
|
||||
# If the class has __eq__, use it no matter the value of
|
||||
# eq=.
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
def __eq__(self, other):
|
||||
return other == 3
|
||||
self.assertEqual(C(1), 3)
|
||||
self.assertNotEqual(C(1), 1)
|
||||
|
||||
@dataclass(eq=True)
|
||||
class C:
|
||||
x: int
|
||||
def __eq__(self, other):
|
||||
return other == 4
|
||||
self.assertEqual(C(1), 4)
|
||||
self.assertNotEqual(C(1), 1)
|
||||
|
||||
@dataclass(eq=False)
|
||||
class C:
|
||||
x: int
|
||||
def __eq__(self, other):
|
||||
return other == 5
|
||||
self.assertEqual(C(1), 5)
|
||||
self.assertNotEqual(C(1), 1)
|
||||
|
||||
|
||||
class TestOrdering(unittest.TestCase):
|
||||
def test_functools_total_ordering(self):
|
||||
# Test that functools.total_ordering works with this class.
|
||||
@total_ordering
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
def __lt__(self, other):
|
||||
# Perform the test "backward", just to make
|
||||
# sure this is being called.
|
||||
return self.x >= other
|
||||
|
||||
self.assertLess(C(0), -1)
|
||||
self.assertLessEqual(C(0), -1)
|
||||
self.assertGreater(C(0), 1)
|
||||
self.assertGreaterEqual(C(0), 1)
|
||||
|
||||
def test_no_order(self):
|
||||
# Test that no ordering functions are added by default.
|
||||
@dataclass(order=False)
|
||||
class C:
|
||||
x: int
|
||||
# Make sure no order methods are added.
|
||||
self.assertNotIn('__le__', C.__dict__)
|
||||
self.assertNotIn('__lt__', C.__dict__)
|
||||
self.assertNotIn('__ge__', C.__dict__)
|
||||
self.assertNotIn('__gt__', C.__dict__)
|
||||
|
||||
# Test that __lt__ is still called
|
||||
@dataclass(order=False)
|
||||
class C:
|
||||
x: int
|
||||
def __lt__(self, other):
|
||||
return False
|
||||
# Make sure other methods aren't added.
|
||||
self.assertNotIn('__le__', C.__dict__)
|
||||
self.assertNotIn('__ge__', C.__dict__)
|
||||
self.assertNotIn('__gt__', C.__dict__)
|
||||
|
||||
def test_overwriting_order(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __lt__'
|
||||
'.*using functools.total_ordering'):
|
||||
@dataclass(order=True)
|
||||
class C:
|
||||
x: int
|
||||
def __lt__(self):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __le__'
|
||||
'.*using functools.total_ordering'):
|
||||
@dataclass(order=True)
|
||||
class C:
|
||||
x: int
|
||||
def __le__(self):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __gt__'
|
||||
'.*using functools.total_ordering'):
|
||||
@dataclass(order=True)
|
||||
class C:
|
||||
x: int
|
||||
def __gt__(self):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Cannot overwrite attribute __ge__'
|
||||
'.*using functools.total_ordering'):
|
||||
@dataclass(order=True)
|
||||
class C:
|
||||
x: int
|
||||
def __ge__(self):
|
||||
pass
|
||||
|
||||
class TestHash(unittest.TestCase):
|
||||
def test_hash(self):
|
||||
@dataclass(hash=True)
|
||||
class C:
|
||||
x: int
|
||||
y: str
|
||||
self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
|
||||
|
||||
def test_hash_false(self):
|
||||
@dataclass(hash=False)
|
||||
class C:
|
||||
x: int
|
||||
y: str
|
||||
self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))
|
||||
|
||||
def test_hash_none(self):
|
||||
@dataclass(hash=None)
|
||||
class C:
|
||||
x: int
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"unhashable type: 'C'"):
|
||||
hash(C(1))
|
||||
|
||||
def test_hash_rules(self):
|
||||
def non_bool(value):
|
||||
# Map to something else that's True, but not a bool.
|
||||
if value is None:
|
||||
return None
|
||||
if value:
|
||||
return (3,)
|
||||
return 0
|
||||
|
||||
def test(case, hash, eq, frozen, with_hash, result):
|
||||
with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
|
||||
if with_hash:
|
||||
@dataclass(hash=hash, eq=eq, frozen=frozen)
|
||||
class C:
|
||||
def __hash__(self):
|
||||
return 0
|
||||
else:
|
||||
@dataclass(hash=hash, eq=eq, frozen=frozen)
|
||||
class C:
|
||||
pass
|
||||
|
||||
# See if the result matches what's expected.
|
||||
if result in ('fn', 'fn-x'):
|
||||
# __hash__ contains the function we generated.
|
||||
self.assertIn('__hash__', C.__dict__)
|
||||
self.assertIsNotNone(C.__dict__['__hash__'])
|
||||
|
||||
if result == 'fn-x':
|
||||
# This is the "auto-hash test" case. We
|
||||
# should overwrite __hash__ iff there's an
|
||||
# __eq__ and if __hash__=None.
|
||||
|
||||
# There are two ways of getting __hash__=None:
|
||||
# explicitely, and by defining __eq__. If
|
||||
# __eq__ is defined, python will add __hash__
|
||||
# when the class is created.
|
||||
@dataclass(hash=hash, eq=eq, frozen=frozen)
|
||||
class C:
|
||||
def __eq__(self, other): pass
|
||||
__hash__ = None
|
||||
|
||||
# Hash should be overwritten (non-None).
|
||||
self.assertIsNotNone(C.__dict__['__hash__'])
|
||||
|
||||
# Same test as above, but we don't provide
|
||||
# __hash__, it will implicitely set to None.
|
||||
@dataclass(hash=hash, eq=eq, frozen=frozen)
|
||||
class C:
|
||||
def __eq__(self, other): pass
|
||||
|
||||
# Hash should be overwritten (non-None).
|
||||
self.assertIsNotNone(C.__dict__['__hash__'])
|
||||
|
||||
elif result == '':
|
||||
# __hash__ is not present in our class.
|
||||
if not with_hash:
|
||||
self.assertNotIn('__hash__', C.__dict__)
|
||||
elif result == 'none':
|
||||
# __hash__ is set to None.
|
||||
self.assertIn('__hash__', C.__dict__)
|
||||
self.assertIsNone(C.__dict__['__hash__'])
|
||||
else:
|
||||
assert False, f'unknown result {result!r}'
|
||||
|
||||
# There are 12 cases of:
|
||||
# hash=True/False/None
|
||||
# eq=True/False
|
||||
# frozen=True/False
|
||||
# And for each of these, a different result if
|
||||
# __hash__ is defined or not.
|
||||
for case, (hash, eq, frozen, result_no, result_yes) in enumerate([
|
||||
(None, False, False, '', ''),
|
||||
(None, False, True, '', ''),
|
||||
(None, True, False, 'none', ''),
|
||||
(None, True, True, 'fn', 'fn-x'),
|
||||
(False, False, False, '', ''),
|
||||
(False, False, True, '', ''),
|
||||
(False, True, False, '', ''),
|
||||
(False, True, True, '', ''),
|
||||
(True, False, False, 'fn', 'fn-x'),
|
||||
(True, False, True, 'fn', 'fn-x'),
|
||||
(True, True, False, 'fn', 'fn-x'),
|
||||
(True, True, True, 'fn', 'fn-x'),
|
||||
], 1):
|
||||
test(case, hash, eq, frozen, False, result_no)
|
||||
test(case, hash, eq, frozen, True, result_yes)
|
||||
|
||||
# Test non-bool truth values, too. This is just to
|
||||
# make sure the data-driven table in the decorator
|
||||
# handles non-bool values.
|
||||
test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
|
||||
test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes)
|
||||
|
||||
|
||||
def test_eq_only(self):
|
||||
# If a class defines __eq__, __hash__ is automatically added
|
||||
# and set to None. This is normal Python behavior, not
|
||||
# related to dataclasses. Make sure we don't interfere with
|
||||
# that (see bpo=32546).
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
i: int
|
||||
def __eq__(self, other):
|
||||
return self.i == other.i
|
||||
self.assertEqual(C(1), C(1))
|
||||
self.assertNotEqual(C(1), C(4))
|
||||
|
||||
# And make sure things work in this case if we specify
|
||||
# hash=True.
|
||||
@dataclass(hash=True)
|
||||
class C:
|
||||
i: int
|
||||
def __eq__(self, other):
|
||||
return self.i == other.i
|
||||
self.assertEqual(C(1), C(1.0))
|
||||
self.assertEqual(hash(C(1)), hash(C(1.0)))
|
||||
|
||||
# And check that the classes __eq__ is being used, despite
|
||||
# specifying eq=True.
|
||||
@dataclass(hash=True, eq=True)
|
||||
class C:
|
||||
i: int
|
||||
def __eq__(self, other):
|
||||
return self.i == 3 and self.i == other.i
|
||||
self.assertEqual(C(3), C(3))
|
||||
self.assertNotEqual(C(1), C(1))
|
||||
self.assertEqual(hash(C(1)), hash(C(1.0)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
In dataclasses, allow easier overriding of dunder methods without specifying
|
||||
decorator parameters.
|
Loading…
Reference in New Issue