mirror of https://github.com/python/cpython
gh-109870: Dataclasses: batch up exec calls (gh-110851)
Instead of calling `exec()` once for each function added to a dataclass, only call `exec()` once per dataclass. This can lead to speed improvements of up to 20%.
This commit is contained in:
parent
7ebad77ad6
commit
8945b7ff55
|
@ -426,32 +426,95 @@ def _tuple_str(obj_name, fields):
|
|||
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
|
||||
|
||||
|
||||
def _create_fn(name, args, body, *, globals=None, locals=None,
|
||||
return_type=MISSING):
|
||||
# Note that we may mutate locals. Callers beware!
|
||||
# The only callers are internal to this module, so no
|
||||
# worries about external callers.
|
||||
if locals is None:
|
||||
locals = {}
|
||||
return_annotation = ''
|
||||
if return_type is not MISSING:
|
||||
locals['__dataclass_return_type__'] = return_type
|
||||
return_annotation = '->__dataclass_return_type__'
|
||||
args = ','.join(args)
|
||||
body = '\n'.join(f' {b}' for b in body)
|
||||
class _FuncBuilder:
|
||||
def __init__(self, globals):
|
||||
self.names = []
|
||||
self.src = []
|
||||
self.globals = globals
|
||||
self.locals = {}
|
||||
self.overwrite_errors = {}
|
||||
self.unconditional_adds = {}
|
||||
|
||||
# Compute the text of the entire function.
|
||||
txt = f' def {name}({args}){return_annotation}:\n{body}'
|
||||
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
|
||||
overwrite_error=False, unconditional_add=False, decorator=None):
|
||||
if locals is not None:
|
||||
self.locals.update(locals)
|
||||
|
||||
# Free variables in exec are resolved in the global namespace.
|
||||
# The global namespace we have is user-provided, so we can't modify it for
|
||||
# our purposes. So we put the things we need into locals and introduce a
|
||||
# scope to allow the function we're creating to close over them.
|
||||
local_vars = ', '.join(locals.keys())
|
||||
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
|
||||
ns = {}
|
||||
exec(txt, globals, ns)
|
||||
return ns['__create_fn__'](**locals)
|
||||
# Keep track of whether this method is allowed to be overwritten if it already
|
||||
# exists in the class. The error is method-specific, so keep it with
|
||||
# the name. We'll use this when we generate all of the functions in
|
||||
# the add_fns_to_class call. overwrite_error is either True, in which
|
||||
# case we'll raise an error, or it's a string, in which case we'll
|
||||
# raise an error and append this string.
|
||||
if overwrite_error:
|
||||
self.overwrite_errors[name] = overwrite_error
|
||||
|
||||
# Should this function always overwrite anything that's already in the
|
||||
# class? The default is to not overwrite a function that already
|
||||
# exists.
|
||||
if unconditional_add:
|
||||
self.unconditional_adds[name] = True
|
||||
|
||||
self.names.append(name)
|
||||
|
||||
if return_type is not MISSING:
|
||||
self.locals[f'__dataclass_{name}_return_type__'] = return_type
|
||||
return_annotation = f'->__dataclass_{name}_return_type__'
|
||||
else:
|
||||
return_annotation = ''
|
||||
args = ','.join(args)
|
||||
body = '\n'.join(body)
|
||||
|
||||
# Compute the text of the entire function, add it to the text we're generating.
|
||||
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
|
||||
|
||||
def add_fns_to_class(self, cls):
|
||||
# The source to all of the functions we're generating.
|
||||
fns_src = '\n'.join(self.src)
|
||||
|
||||
# The locals they use.
|
||||
local_vars = ','.join(self.locals.keys())
|
||||
|
||||
# The names of all of the functions, used for the return value of the
|
||||
# outer function. Need to handle the 0-tuple specially.
|
||||
if len(self.names) == 0:
|
||||
return_names = '()'
|
||||
else:
|
||||
return_names =f'({",".join(self.names)},)'
|
||||
|
||||
# txt is the entire function we're going to execute, including the
|
||||
# bodies of the functions we're defining. Here's a greatly simplified
|
||||
# version:
|
||||
# def __create_fn__():
|
||||
# def __init__(self, x, y):
|
||||
# self.x = x
|
||||
# self.y = y
|
||||
# @recursive_repr
|
||||
# def __repr__(self):
|
||||
# return f"cls(x={self.x!r},y={self.y!r})"
|
||||
# return __init__,__repr__
|
||||
|
||||
txt = f"def __create_fn__({local_vars}):\n{fns_src}\n return {return_names}"
|
||||
ns = {}
|
||||
exec(txt, self.globals, ns)
|
||||
fns = ns['__create_fn__'](**self.locals)
|
||||
|
||||
# Now that we've generated the functions, assign them into cls.
|
||||
for name, fn in zip(self.names, fns):
|
||||
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
|
||||
if self.unconditional_adds.get(name, False):
|
||||
setattr(cls, name, fn)
|
||||
else:
|
||||
already_exists = _set_new_attribute(cls, name, fn)
|
||||
|
||||
# See if it's an error to overwrite this particular function.
|
||||
if already_exists and (msg_extra := self.overwrite_errors.get(name)):
|
||||
error_msg = (f'Cannot overwrite attribute {fn.__name__} '
|
||||
f'in class {cls.__name__}')
|
||||
if not msg_extra is True:
|
||||
error_msg = f'{error_msg} {msg_extra}'
|
||||
|
||||
raise TypeError(error_msg)
|
||||
|
||||
|
||||
def _field_assign(frozen, name, value, self_name):
|
||||
|
@ -462,8 +525,8 @@ def _field_assign(frozen, name, value, self_name):
|
|||
# self_name is what "self" is called in this function: don't
|
||||
# hard-code "self", since that might be a field name.
|
||||
if frozen:
|
||||
return f'__dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
|
||||
return f'{self_name}.{name}={value}'
|
||||
return f' __dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
|
||||
return f' {self_name}.{name}={value}'
|
||||
|
||||
|
||||
def _field_init(f, frozen, globals, self_name, slots):
|
||||
|
@ -546,7 +609,7 @@ def _init_param(f):
|
|||
|
||||
|
||||
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
||||
self_name, globals, slots):
|
||||
self_name, func_builder, slots):
|
||||
# fields contains both real fields and InitVar pseudo-fields.
|
||||
|
||||
# Make sure we don't have fields without defaults following fields
|
||||
|
@ -565,11 +628,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
|||
raise TypeError(f'non-default argument {f.name!r} '
|
||||
f'follows default argument {seen_default.name!r}')
|
||||
|
||||
locals = {f'__dataclass_type_{f.name}__': f.type for f in fields}
|
||||
locals.update({
|
||||
'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
||||
'__dataclass_builtins_object__': object,
|
||||
})
|
||||
locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
|
||||
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
||||
'__dataclass_builtins_object__': object,
|
||||
}
|
||||
}
|
||||
|
||||
body_lines = []
|
||||
for f in fields:
|
||||
|
@ -583,11 +646,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
|||
if has_post_init:
|
||||
params_str = ','.join(f.name for f in fields
|
||||
if f._field_type is _FIELD_INITVAR)
|
||||
body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')
|
||||
body_lines.append(f' {self_name}.{_POST_INIT_NAME}({params_str})')
|
||||
|
||||
# If no body lines, use 'pass'.
|
||||
if not body_lines:
|
||||
body_lines = ['pass']
|
||||
body_lines = [' pass']
|
||||
|
||||
_init_params = [_init_param(f) for f in std_fields]
|
||||
if kw_only_fields:
|
||||
|
@ -596,68 +659,34 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
|||
# (instead of just concatenating the lists together).
|
||||
_init_params += ['*']
|
||||
_init_params += [_init_param(f) for f in kw_only_fields]
|
||||
return _create_fn('__init__',
|
||||
[self_name] + _init_params,
|
||||
body_lines,
|
||||
locals=locals,
|
||||
globals=globals,
|
||||
return_type=None)
|
||||
func_builder.add_fn('__init__',
|
||||
[self_name] + _init_params,
|
||||
body_lines,
|
||||
locals=locals,
|
||||
return_type=None)
|
||||
|
||||
|
||||
def _repr_fn(fields, globals):
|
||||
fn = _create_fn('__repr__',
|
||||
('self',),
|
||||
['return f"{self.__class__.__qualname__}(' +
|
||||
', '.join([f"{f.name}={{self.{f.name}!r}}"
|
||||
for f in fields]) +
|
||||
')"'],
|
||||
globals=globals)
|
||||
return recursive_repr()(fn)
|
||||
|
||||
|
||||
def _frozen_get_del_attr(cls, fields, globals):
|
||||
def _frozen_get_del_attr(cls, fields, func_builder):
|
||||
locals = {'cls': cls,
|
||||
'FrozenInstanceError': FrozenInstanceError}
|
||||
condition = 'type(self) is cls'
|
||||
if fields:
|
||||
condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}'
|
||||
return (_create_fn('__setattr__',
|
||||
('self', 'name', 'value'),
|
||||
(f'if {condition}:',
|
||||
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
|
||||
f'super(cls, self).__setattr__(name, value)'),
|
||||
locals=locals,
|
||||
globals=globals),
|
||||
_create_fn('__delattr__',
|
||||
('self', 'name'),
|
||||
(f'if {condition}:',
|
||||
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
|
||||
f'super(cls, self).__delattr__(name)'),
|
||||
locals=locals,
|
||||
globals=globals),
|
||||
)
|
||||
|
||||
|
||||
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
|
||||
# Create a comparison function. If the fields in the object are
|
||||
# named 'x' and 'y', then self_tuple is the string
|
||||
# '(self.x,self.y)' and other_tuple is the string
|
||||
# '(other.x,other.y)'.
|
||||
|
||||
return _create_fn(name,
|
||||
('self', 'other'),
|
||||
[ 'if other.__class__ is self.__class__:',
|
||||
f' return {self_tuple}{op}{other_tuple}',
|
||||
'return NotImplemented'],
|
||||
globals=globals)
|
||||
|
||||
|
||||
def _hash_fn(fields, globals):
|
||||
self_tuple = _tuple_str('self', fields)
|
||||
return _create_fn('__hash__',
|
||||
('self',),
|
||||
[f'return hash({self_tuple})'],
|
||||
globals=globals)
|
||||
func_builder.add_fn('__setattr__',
|
||||
('self', 'name', 'value'),
|
||||
(f' if {condition}:',
|
||||
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
|
||||
f' super(cls, self).__setattr__(name, value)'),
|
||||
locals=locals,
|
||||
overwrite_error=True)
|
||||
func_builder.add_fn('__delattr__',
|
||||
('self', 'name'),
|
||||
(f' if {condition}:',
|
||||
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
|
||||
f' super(cls, self).__delattr__(name)'),
|
||||
locals=locals,
|
||||
overwrite_error=True)
|
||||
|
||||
|
||||
def _is_classvar(a_type, typing):
|
||||
|
@ -834,19 +863,11 @@ def _get_field(cls, a_name, a_type, default_kw_only):
|
|||
|
||||
return f
|
||||
|
||||
def _set_qualname(cls, value):
|
||||
# Ensure that the functions returned from _create_fn use the proper
|
||||
# __qualname__ (the class they belong to).
|
||||
if isinstance(value, FunctionType):
|
||||
value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
|
||||
return value
|
||||
|
||||
def _set_new_attribute(cls, name, value):
|
||||
# Never overwrites an existing attribute. Returns True if the
|
||||
# attribute already exists.
|
||||
if name in cls.__dict__:
|
||||
return True
|
||||
_set_qualname(cls, value)
|
||||
setattr(cls, name, value)
|
||||
return False
|
||||
|
||||
|
@ -856,14 +877,22 @@ def _set_new_attribute(cls, name, value):
|
|||
# take. The common case is to do nothing, so instead of providing a
|
||||
# function that is a no-op, use None to signify that.
|
||||
|
||||
def _hash_set_none(cls, fields, globals):
|
||||
return None
|
||||
def _hash_set_none(cls, fields, func_builder):
|
||||
# It's sort of a hack that I'm setting this here, instead of at
|
||||
# func_builder.add_fns_to_class time, but since this is an exceptional case
|
||||
# (it's not setting an attribute to a function, but to a scalar value),
|
||||
# just do it directly here. I might come to regret this.
|
||||
cls.__hash__ = None
|
||||
|
||||
def _hash_add(cls, fields, globals):
|
||||
def _hash_add(cls, fields, func_builder):
|
||||
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
|
||||
return _set_qualname(cls, _hash_fn(flds, globals))
|
||||
self_tuple = _tuple_str('self', flds)
|
||||
func_builder.add_fn('__hash__',
|
||||
('self',),
|
||||
[f' return hash({self_tuple})'],
|
||||
unconditional_add=True)
|
||||
|
||||
def _hash_exception(cls, fields, globals):
|
||||
def _hash_exception(cls, fields, func_builder):
|
||||
# Raise an exception.
|
||||
raise TypeError(f'Cannot overwrite attribute __hash__ '
|
||||
f'in class {cls.__name__}')
|
||||
|
@ -1041,24 +1070,26 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
|||
(std_init_fields,
|
||||
kw_only_init_fields) = _fields_in_init_order(all_init_fields)
|
||||
|
||||
func_builder = _FuncBuilder(globals)
|
||||
|
||||
if init:
|
||||
# Does this class have a post-init function?
|
||||
has_post_init = hasattr(cls, _POST_INIT_NAME)
|
||||
|
||||
_set_new_attribute(cls, '__init__',
|
||||
_init_fn(all_init_fields,
|
||||
std_init_fields,
|
||||
kw_only_init_fields,
|
||||
frozen,
|
||||
has_post_init,
|
||||
# The name to use for the "self"
|
||||
# param in __init__. Use "self"
|
||||
# if possible.
|
||||
'__dataclass_self__' if 'self' in fields
|
||||
else 'self',
|
||||
globals,
|
||||
slots,
|
||||
))
|
||||
_init_fn(all_init_fields,
|
||||
std_init_fields,
|
||||
kw_only_init_fields,
|
||||
frozen,
|
||||
has_post_init,
|
||||
# The name to use for the "self"
|
||||
# param in __init__. Use "self"
|
||||
# if possible.
|
||||
'__dataclass_self__' if 'self' in fields
|
||||
else 'self',
|
||||
func_builder,
|
||||
slots,
|
||||
)
|
||||
|
||||
_set_new_attribute(cls, '__replace__', _replace)
|
||||
|
||||
# Get the fields as a list, and include only real fields. This is
|
||||
|
@ -1067,7 +1098,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
|||
|
||||
if repr:
|
||||
flds = [f for f in field_list if f.repr]
|
||||
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
|
||||
func_builder.add_fn('__repr__',
|
||||
('self',),
|
||||
[' return f"{self.__class__.__qualname__}(' +
|
||||
', '.join([f"{f.name}={{self.{f.name}!r}}"
|
||||
for f in flds]) + ')"'],
|
||||
locals={'__dataclasses_recursive_repr': recursive_repr},
|
||||
decorator="@__dataclasses_recursive_repr()")
|
||||
|
||||
if eq:
|
||||
# Create __eq__ method. There's no need for a __ne__ method,
|
||||
|
@ -1075,16 +1112,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
|||
cmp_fields = (field for field in field_list if field.compare)
|
||||
terms = [f'self.{field.name}==other.{field.name}' for field in cmp_fields]
|
||||
field_comparisons = ' and '.join(terms) or 'True'
|
||||
body = [f'if self is other:',
|
||||
f' return True',
|
||||
f'if other.__class__ is self.__class__:',
|
||||
f' return {field_comparisons}',
|
||||
f'return NotImplemented']
|
||||
func = _create_fn('__eq__',
|
||||
('self', 'other'),
|
||||
body,
|
||||
globals=globals)
|
||||
_set_new_attribute(cls, '__eq__', func)
|
||||
func_builder.add_fn('__eq__',
|
||||
('self', 'other'),
|
||||
[ ' if self is other:',
|
||||
' return True',
|
||||
' if other.__class__ is self.__class__:',
|
||||
f' return {field_comparisons}',
|
||||
' return NotImplemented'])
|
||||
|
||||
if order:
|
||||
# Create and set the ordering methods.
|
||||
|
@ -1096,18 +1130,19 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
|||
('__gt__', '>'),
|
||||
('__ge__', '>='),
|
||||
]:
|
||||
if _set_new_attribute(cls, name,
|
||||
_cmp_fn(name, op, self_tuple, other_tuple,
|
||||
globals=globals)):
|
||||
raise TypeError(f'Cannot overwrite attribute {name} '
|
||||
f'in class {cls.__name__}. Consider using '
|
||||
'functools.total_ordering')
|
||||
# Create a comparison function. If the fields in the object are
|
||||
# named 'x' and 'y', then self_tuple is the string
|
||||
# '(self.x,self.y)' and other_tuple is the string
|
||||
# '(other.x,other.y)'.
|
||||
func_builder.add_fn(name,
|
||||
('self', 'other'),
|
||||
[ ' if other.__class__ is self.__class__:',
|
||||
f' return {self_tuple}{op}{other_tuple}',
|
||||
' return NotImplemented'],
|
||||
overwrite_error='Consider using functools.total_ordering')
|
||||
|
||||
if frozen:
|
||||
for fn in _frozen_get_del_attr(cls, field_list, globals):
|
||||
if _set_new_attribute(cls, fn.__name__, fn):
|
||||
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
|
||||
f'in class {cls.__name__}')
|
||||
_frozen_get_del_attr(cls, field_list, func_builder)
|
||||
|
||||
# Decide if/how we're going to create a hash function.
|
||||
hash_action = _hash_action[bool(unsafe_hash),
|
||||
|
@ -1115,9 +1150,12 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
|||
bool(frozen),
|
||||
has_explicit_hash]
|
||||
if hash_action:
|
||||
# No need to call _set_new_attribute here, since by the time
|
||||
# we're here the overwriting is unconditional.
|
||||
cls.__hash__ = hash_action(cls, field_list, globals)
|
||||
cls.__hash__ = hash_action(cls, field_list, func_builder)
|
||||
|
||||
# Generate the methods and add them to the class. This needs to be done
|
||||
# before the __doc__ logic below, since inspect will look at the __init__
|
||||
# signature.
|
||||
func_builder.add_fns_to_class(cls)
|
||||
|
||||
if not getattr(cls, '__doc__'):
|
||||
# Create a class doc-string.
|
||||
|
@ -1130,7 +1168,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
|||
cls.__doc__ = (cls.__name__ + text_sig)
|
||||
|
||||
if match_args:
|
||||
# I could probably compute this once
|
||||
# I could probably compute this once.
|
||||
_set_new_attribute(cls, '__match_args__',
|
||||
tuple(f.name for f in std_init_fields))
|
||||
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
Dataclasses now calls :func:`exec` once per dataclass, instead of once
|
||||
per method being added. This can speed up dataclass creation by up to
|
||||
20%.
|
Loading…
Reference in New Issue