bpo-34776: Fix dataclasses to support __future__ "annotations" mode (GH-9518) (#17532)
(cherry picked from commit d219cc4180
)
Co-authored-by: Yury Selivanov <yury@magic.io>
This commit is contained in:
parent
a0078d9a33
commit
66d7a5d58a
|
@ -368,23 +368,24 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
|
|||
# worries about external callers.
|
||||
if locals is None:
|
||||
locals = {}
|
||||
# __builtins__ may be the "builtins" module or
|
||||
# the value of its "__dict__",
|
||||
# so make sure "__builtins__" is the module.
|
||||
if globals is not None and '__builtins__' not in globals:
|
||||
globals['__builtins__'] = builtins
|
||||
if 'BUILTINS' not in locals:
|
||||
locals['BUILTINS'] = builtins
|
||||
return_annotation = ''
|
||||
if return_type is not MISSING:
|
||||
locals['_return_type'] = return_type
|
||||
return_annotation = '->_return_type'
|
||||
args = ','.join(args)
|
||||
body = '\n'.join(f' {b}' for b in body)
|
||||
body = '\n'.join(f' {b}' for b in body)
|
||||
|
||||
# Compute the text of the entire function.
|
||||
txt = f'def {name}({args}){return_annotation}:\n{body}'
|
||||
txt = f' def {name}({args}){return_annotation}:\n{body}'
|
||||
|
||||
exec(txt, globals, locals)
|
||||
return locals[name]
|
||||
local_vars = ', '.join(locals.keys())
|
||||
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
|
||||
|
||||
ns = {}
|
||||
exec(txt, globals, ns)
|
||||
return ns['__create_fn__'](**locals)
|
||||
|
||||
|
||||
def _field_assign(frozen, name, value, self_name):
|
||||
|
@ -395,7 +396,7 @@ def _field_assign(frozen, name, value, self_name):
|
|||
# self_name is what "self" is called in this function: don't
|
||||
# hard-code "self", since that might be a field name.
|
||||
if frozen:
|
||||
return f'__builtins__.object.__setattr__({self_name},{name!r},{value})'
|
||||
return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
|
||||
return f'{self_name}.{name}={value}'
|
||||
|
||||
|
||||
|
@ -472,7 +473,7 @@ def _init_param(f):
|
|||
return f'{f.name}:_type_{f.name}{default}'
|
||||
|
||||
|
||||
def _init_fn(fields, frozen, has_post_init, self_name):
|
||||
def _init_fn(fields, frozen, has_post_init, self_name, globals):
|
||||
# fields contains both real fields and InitVar pseudo-fields.
|
||||
|
||||
# Make sure we don't have fields without defaults following fields
|
||||
|
@ -490,12 +491,15 @@ def _init_fn(fields, frozen, has_post_init, self_name):
|
|||
raise TypeError(f'non-default argument {f.name!r} '
|
||||
'follows default argument')
|
||||
|
||||
globals = {'MISSING': MISSING,
|
||||
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
|
||||
locals = {f'_type_{f.name}': f.type for f in fields}
|
||||
locals.update({
|
||||
'MISSING': MISSING,
|
||||
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
|
||||
})
|
||||
|
||||
body_lines = []
|
||||
for f in fields:
|
||||
line = _field_init(f, frozen, globals, self_name)
|
||||
line = _field_init(f, frozen, locals, self_name)
|
||||
# line is None means that this field doesn't require
|
||||
# initialization (it's a pseudo-field). Just skip it.
|
||||
if line:
|
||||
|
@ -511,7 +515,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
|
|||
if not body_lines:
|
||||
body_lines = ['pass']
|
||||
|
||||
locals = {f'_type_{f.name}': f.type for f in fields}
|
||||
return _create_fn('__init__',
|
||||
[self_name] + [_init_param(f) for f in fields if f.init],
|
||||
body_lines,
|
||||
|
@ -520,20 +523,19 @@ def _init_fn(fields, frozen, has_post_init, self_name):
|
|||
return_type=None)
|
||||
|
||||
|
||||
def _repr_fn(fields):
|
||||
def _repr_fn(fields, globals):
|
||||
fn = _create_fn('__repr__',
|
||||
('self',),
|
||||
['return self.__class__.__qualname__ + f"(' +
|
||||
', '.join([f"{f.name}={{self.{f.name}!r}}"
|
||||
for f in fields]) +
|
||||
')"'])
|
||||
')"'],
|
||||
globals=globals)
|
||||
return _recursive_repr(fn)
|
||||
|
||||
|
||||
def _frozen_get_del_attr(cls, fields):
|
||||
# XXX: globals is modified on the first call to _create_fn, then
|
||||
# the modified version is used in the second call. Is this okay?
|
||||
globals = {'cls': cls,
|
||||
def _frozen_get_del_attr(cls, fields, globals):
|
||||
locals = {'cls': cls,
|
||||
'FrozenInstanceError': FrozenInstanceError}
|
||||
if fields:
|
||||
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
|
||||
|
@ -545,17 +547,19 @@ def _frozen_get_del_attr(cls, fields):
|
|||
(f'if type(self) is cls or name in {fields_str}:',
|
||||
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
|
||||
f'super(cls, self).__setattr__(name, value)'),
|
||||
locals=locals,
|
||||
globals=globals),
|
||||
_create_fn('__delattr__',
|
||||
('self', 'name'),
|
||||
(f'if type(self) is cls or name in {fields_str}:',
|
||||
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
|
||||
f'super(cls, self).__delattr__(name)'),
|
||||
locals=locals,
|
||||
globals=globals),
|
||||
)
|
||||
|
||||
|
||||
def _cmp_fn(name, op, self_tuple, other_tuple):
|
||||
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
|
||||
# Create a comparison function. If the fields in the object are
|
||||
# named 'x' and 'y', then self_tuple is the string
|
||||
# '(self.x,self.y)' and other_tuple is the string
|
||||
|
@ -565,14 +569,16 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
|
|||
('self', 'other'),
|
||||
[ 'if other.__class__ is self.__class__:',
|
||||
f' return {self_tuple}{op}{other_tuple}',
|
||||
'return NotImplemented'])
|
||||
'return NotImplemented'],
|
||||
globals=globals)
|
||||
|
||||
|
||||
def _hash_fn(fields):
|
||||
def _hash_fn(fields, globals):
|
||||
self_tuple = _tuple_str('self', fields)
|
||||
return _create_fn('__hash__',
|
||||
('self',),
|
||||
[f'return hash({self_tuple})'])
|
||||
[f'return hash({self_tuple})'],
|
||||
globals=globals)
|
||||
|
||||
|
||||
def _is_classvar(a_type, typing):
|
||||
|
@ -744,14 +750,14 @@ def _set_new_attribute(cls, name, value):
|
|||
# take. The common case is to do nothing, so instead of providing a
|
||||
# function that is a no-op, use None to signify that.
|
||||
|
||||
def _hash_set_none(cls, fields):
|
||||
def _hash_set_none(cls, fields, globals):
|
||||
return None
|
||||
|
||||
def _hash_add(cls, fields):
|
||||
def _hash_add(cls, fields, globals):
|
||||
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
|
||||
return _hash_fn(flds)
|
||||
return _hash_fn(flds, globals)
|
||||
|
||||
def _hash_exception(cls, fields):
|
||||
def _hash_exception(cls, fields, globals):
|
||||
# Raise an exception.
|
||||
raise TypeError(f'Cannot overwrite attribute __hash__ '
|
||||
f'in class {cls.__name__}')
|
||||
|
@ -793,6 +799,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
|
|||
# is defined by the base class, which is found first.
|
||||
fields = {}
|
||||
|
||||
if cls.__module__ in sys.modules:
|
||||
globals = sys.modules[cls.__module__].__dict__
|
||||
else:
|
||||
# Theoretically this can happen if someone writes
|
||||
# a custom string to cls.__module__. In which case
|
||||
# such dataclass won't be fully introspectable
|
||||
# (w.r.t. typing.get_type_hints) but will still function
|
||||
# correctly.
|
||||
globals = {}
|
||||
|
||||
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
|
||||
unsafe_hash, frozen))
|
||||
|
||||
|
@ -902,6 +918,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
|
|||
# if possible.
|
||||
'__dataclass_self__' if 'self' in fields
|
||||
else 'self',
|
||||
globals,
|
||||
))
|
||||
|
||||
# Get the fields as a list, and include only real fields. This is
|
||||
|
@ -910,7 +927,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
|
|||
|
||||
if repr:
|
||||
flds = [f for f in field_list if f.repr]
|
||||
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
|
||||
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
|
||||
|
||||
if eq:
|
||||
# Create _eq__ method. There's no need for a __ne__ method,
|
||||
|
@ -920,7 +937,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
|
|||
other_tuple = _tuple_str('other', flds)
|
||||
_set_new_attribute(cls, '__eq__',
|
||||
_cmp_fn('__eq__', '==',
|
||||
self_tuple, other_tuple))
|
||||
self_tuple, other_tuple,
|
||||
globals=globals))
|
||||
|
||||
if order:
|
||||
# Create and set the ordering methods.
|
||||
|
@ -933,13 +951,14 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
|
|||
('__ge__', '>='),
|
||||
]:
|
||||
if _set_new_attribute(cls, name,
|
||||
_cmp_fn(name, op, self_tuple, other_tuple)):
|
||||
_cmp_fn(name, op, self_tuple, other_tuple,
|
||||
globals=globals)):
|
||||
raise TypeError(f'Cannot overwrite attribute {name} '
|
||||
f'in class {cls.__name__}. Consider using '
|
||||
'functools.total_ordering')
|
||||
|
||||
if frozen:
|
||||
for fn in _frozen_get_del_attr(cls, field_list):
|
||||
for fn in _frozen_get_del_attr(cls, field_list, globals):
|
||||
if _set_new_attribute(cls, fn.__name__, fn):
|
||||
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
|
||||
f'in class {cls.__name__}')
|
||||
|
@ -952,7 +971,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
|
|||
if hash_action:
|
||||
# No need to call _set_new_attribute here, since by the time
|
||||
# we're here the overwriting is unconditional.
|
||||
cls.__hash__ = hash_action(cls, field_list)
|
||||
cls.__hash__ = hash_action(cls, field_list, globals)
|
||||
|
||||
if not getattr(cls, '__doc__'):
|
||||
# Create a class doc-string.
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Bar:
|
||||
foo: Foo
|
|
@ -10,6 +10,7 @@ import builtins
|
|||
import unittest
|
||||
from unittest.mock import Mock
|
||||
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
|
||||
from typing import get_type_hints
|
||||
from collections import deque, OrderedDict, namedtuple
|
||||
from functools import total_ordering
|
||||
|
||||
|
@ -2918,6 +2919,17 @@ class TestStringAnnotations(unittest.TestCase):
|
|||
# won't exist on the instance.
|
||||
self.assertNotIn('not_iv4', c.__dict__)
|
||||
|
||||
def test_text_annotations(self):
|
||||
from test import dataclass_textanno
|
||||
|
||||
self.assertEqual(
|
||||
get_type_hints(dataclass_textanno.Bar),
|
||||
{'foo': dataclass_textanno.Foo})
|
||||
self.assertEqual(
|
||||
get_type_hints(dataclass_textanno.Bar.__init__),
|
||||
{'foo': dataclass_textanno.Foo,
|
||||
'return': type(None)})
|
||||
|
||||
|
||||
class TestMakeDataclass(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Fix dataclasses to support forward references in type annotations
|
Loading…
Reference in New Issue