bpo-34776: Fix dataclasses to support __future__ "annotations" mode (#9518)

This commit is contained in:
Yury Selivanov 2019-12-09 09:54:20 -05:00 committed by Łukasz Langa
parent bba873e633
commit d219cc4180
4 changed files with 78 additions and 34 deletions

View File

@ -378,23 +378,24 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
# worries about external callers.
if locals is None:
locals = {}
# __builtins__ may be the "builtins" module or
# the value of its "__dict__",
# so make sure "__builtins__" is the module.
if globals is not None and '__builtins__' not in globals:
globals['__builtins__'] = builtins
if 'BUILTINS' not in locals:
locals['BUILTINS'] = builtins
return_annotation = ''
if return_type is not MISSING:
locals['_return_type'] = return_type
return_annotation = '->_return_type'
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)
body = '\n'.join(f' {b}' for b in body)
# Compute the text of the entire function.
txt = f'def {name}({args}){return_annotation}:\n{body}'
txt = f' def {name}({args}){return_annotation}:\n{body}'
exec(txt, globals, locals)
return locals[name]
local_vars = ', '.join(locals.keys())
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
ns = {}
exec(txt, globals, ns)
return ns['__create_fn__'](**locals)
def _field_assign(frozen, name, value, self_name):
@ -405,7 +406,7 @@ def _field_assign(frozen, name, value, self_name):
# self_name is what "self" is called in this function: don't
# hard-code "self", since that might be a field name.
if frozen:
return f'__builtins__.object.__setattr__({self_name},{name!r},{value})'
return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
return f'{self_name}.{name}={value}'
@ -482,7 +483,7 @@ def _init_param(f):
return f'{f.name}:_type_{f.name}{default}'
def _init_fn(fields, frozen, has_post_init, self_name):
def _init_fn(fields, frozen, has_post_init, self_name, globals):
# fields contains both real fields and InitVar pseudo-fields.
# Make sure we don't have fields without defaults following fields
@ -500,12 +501,15 @@ def _init_fn(fields, frozen, has_post_init, self_name):
raise TypeError(f'non-default argument {f.name!r} '
'follows default argument')
globals = {'MISSING': MISSING,
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
locals = {f'_type_{f.name}': f.type for f in fields}
locals.update({
'MISSING': MISSING,
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
})
body_lines = []
for f in fields:
line = _field_init(f, frozen, globals, self_name)
line = _field_init(f, frozen, locals, self_name)
# line is None means that this field doesn't require
# initialization (it's a pseudo-field). Just skip it.
if line:
@ -521,7 +525,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
if not body_lines:
body_lines = ['pass']
locals = {f'_type_{f.name}': f.type for f in fields}
return _create_fn('__init__',
[self_name] + [_init_param(f) for f in fields if f.init],
body_lines,
@ -530,20 +533,19 @@ def _init_fn(fields, frozen, has_post_init, self_name):
return_type=None)
def _repr_fn(fields):
def _repr_fn(fields, globals):
fn = _create_fn('__repr__',
('self',),
['return self.__class__.__qualname__ + f"(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
')"'])
')"'],
globals=globals)
return _recursive_repr(fn)
def _frozen_get_del_attr(cls, fields):
# XXX: globals is modified on the first call to _create_fn, then
# the modified version is used in the second call. Is this okay?
globals = {'cls': cls,
def _frozen_get_del_attr(cls, fields, globals):
locals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
if fields:
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
@ -555,17 +557,19 @@ def _frozen_get_del_attr(cls, fields):
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
locals=locals,
globals=globals),
_create_fn('__delattr__',
('self', 'name'),
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
locals=locals,
globals=globals),
)
def _cmp_fn(name, op, self_tuple, other_tuple):
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
# Create a comparison function. If the fields in the object are
# named 'x' and 'y', then self_tuple is the string
# '(self.x,self.y)' and other_tuple is the string
@ -575,14 +579,16 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
('self', 'other'),
[ 'if other.__class__ is self.__class__:',
f' return {self_tuple}{op}{other_tuple}',
'return NotImplemented'])
'return NotImplemented'],
globals=globals)
def _hash_fn(fields):
def _hash_fn(fields, globals):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
('self',),
[f'return hash({self_tuple})'])
[f'return hash({self_tuple})'],
globals=globals)
def _is_classvar(a_type, typing):
@ -755,14 +761,14 @@ def _set_new_attribute(cls, name, value):
# take. The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that.
def _hash_set_none(cls, fields):
def _hash_set_none(cls, fields, globals):
return None
def _hash_add(cls, fields):
def _hash_add(cls, fields, globals):
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
return _hash_fn(flds)
return _hash_fn(flds, globals)
def _hash_exception(cls, fields):
def _hash_exception(cls, fields, globals):
# Raise an exception.
raise TypeError(f'Cannot overwrite attribute __hash__ '
f'in class {cls.__name__}')
@ -804,6 +810,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# is defined by the base class, which is found first.
fields = {}
if cls.__module__ in sys.modules:
globals = sys.modules[cls.__module__].__dict__
else:
# Theoretically this can happen if someone writes
# a custom string to cls.__module__. In which case
# such dataclass won't be fully introspectable
# (w.r.t. typing.get_type_hints) but will still function
# correctly.
globals = {}
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
unsafe_hash, frozen))
@ -913,6 +929,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# if possible.
'__dataclass_self__' if 'self' in fields
else 'self',
globals,
))
# Get the fields as a list, and include only real fields. This is
@ -921,7 +938,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
if repr:
flds = [f for f in field_list if f.repr]
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
if eq:
# Create _eq__ method. There's no need for a __ne__ method,
@ -931,7 +948,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
other_tuple = _tuple_str('other', flds)
_set_new_attribute(cls, '__eq__',
_cmp_fn('__eq__', '==',
self_tuple, other_tuple))
self_tuple, other_tuple,
globals=globals))
if order:
# Create and set the ordering methods.
@ -944,13 +962,14 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
('__ge__', '>='),
]:
if _set_new_attribute(cls, name,
_cmp_fn(name, op, self_tuple, other_tuple)):
_cmp_fn(name, op, self_tuple, other_tuple,
globals=globals)):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in class {cls.__name__}. Consider using '
'functools.total_ordering')
if frozen:
for fn in _frozen_get_del_attr(cls, field_list):
for fn in _frozen_get_del_attr(cls, field_list, globals):
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
@ -963,7 +982,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
if hash_action:
# No need to call _set_new_attribute here, since by the time
# we're here the overwriting is unconditional.
cls.__hash__ = hash_action(cls, field_list)
cls.__hash__ = hash_action(cls, field_list, globals)
if not getattr(cls, '__doc__'):
# Create a class doc-string.

View File

@ -0,0 +1,12 @@
from __future__ import annotations
import dataclasses
class Foo:
pass
@dataclasses.dataclass
class Bar:
foo: Foo

View File

@ -10,6 +10,7 @@ import builtins
import unittest
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
from typing import get_type_hints
from collections import deque, OrderedDict, namedtuple
from functools import total_ordering
@ -2926,6 +2927,17 @@ class TestStringAnnotations(unittest.TestCase):
# won't exist on the instance.
self.assertNotIn('not_iv4', c.__dict__)
def test_text_annotations(self):
from test import dataclass_textanno
self.assertEqual(
get_type_hints(dataclass_textanno.Bar),
{'foo': dataclass_textanno.Foo})
self.assertEqual(
get_type_hints(dataclass_textanno.Bar.__init__),
{'foo': dataclass_textanno.Foo,
'return': type(None)})
class TestMakeDataclass(unittest.TestCase):
def test_simple(self):

View File

@ -0,0 +1 @@
Fix dataclasses to support forward references in type annotations