gh-109870: Dataclasses: batch up exec calls (gh-110851)

Instead of calling `exec()` once for each function added to a dataclass, only call `exec()` once per dataclass. This can lead to speed improvements of up to 20%.
This commit is contained in:
Eric V. Smith 2024-03-25 19:59:14 -04:00 committed by GitHub
parent 7ebad77ad6
commit 8945b7ff55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 181 additions and 140 deletions

View File

@ -426,32 +426,95 @@ def _tuple_str(obj_name, fields):
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
def _create_fn(name, args, body, *, globals=None, locals=None, class _FuncBuilder:
return_type=MISSING): def __init__(self, globals):
# Note that we may mutate locals. Callers beware! self.names = []
# The only callers are internal to this module, so no self.src = []
# worries about external callers. self.globals = globals
if locals is None: self.locals = {}
locals = {} self.overwrite_errors = {}
return_annotation = '' self.unconditional_adds = {}
if return_type is not MISSING:
locals['__dataclass_return_type__'] = return_type
return_annotation = '->__dataclass_return_type__'
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)
# Compute the text of the entire function. def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
txt = f' def {name}({args}){return_annotation}:\n{body}' overwrite_error=False, unconditional_add=False, decorator=None):
if locals is not None:
self.locals.update(locals)
# Free variables in exec are resolved in the global namespace. # Keep track if this method is allowed to be overwritten if it already
# The global namespace we have is user-provided, so we can't modify it for # exists in the class. The error is method-specific, so keep it with
# our purposes. So we put the things we need into locals and introduce a # the name. We'll use this when we generate all of the functions in
# scope to allow the function we're creating to close over them. # the add_fns_to_class call. overwrite_error is either True, in which
local_vars = ', '.join(locals.keys()) # case we'll raise an error, or it's a string, in which case we'll
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}" # raise an error and append this string.
ns = {} if overwrite_error:
exec(txt, globals, ns) self.overwrite_errors[name] = overwrite_error
return ns['__create_fn__'](**locals)
# Should this function always overwrite anything that's already in the
# class? The default is to not overwrite a function that already
# exists.
if unconditional_add:
self.unconditional_adds[name] = True
self.names.append(name)
if return_type is not MISSING:
self.locals[f'__dataclass_{name}_return_type__'] = return_type
return_annotation = f'->__dataclass_{name}_return_type__'
else:
return_annotation = ''
args = ','.join(args)
body = '\n'.join(body)
# Compute the text of the entire function, add it to the text we're generating.
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
def add_fns_to_class(self, cls):
# The source to all of the functions we're generating.
fns_src = '\n'.join(self.src)
# The locals they use.
local_vars = ','.join(self.locals.keys())
# The names of all of the functions, used for the return value of the
# outer function. Need to handle the 0-tuple specially.
if len(self.names) == 0:
return_names = '()'
else:
return_names =f'({",".join(self.names)},)'
# txt is the entire function we're going to execute, including the
# bodies of the functions we're defining. Here's a greatly simplified
# version:
# def __create_fn__():
# def __init__(self, x, y):
# self.x = x
# self.y = y
# @recursive_repr
# def __repr__(self):
# return f"cls(x={self.x!r},y={self.y!r})"
# return __init__,__repr__
txt = f"def __create_fn__({local_vars}):\n{fns_src}\n return {return_names}"
ns = {}
exec(txt, self.globals, ns)
fns = ns['__create_fn__'](**self.locals)
# Now that we've generated the functions, assign them into cls.
for name, fn in zip(self.names, fns):
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
if self.unconditional_adds.get(name, False):
setattr(cls, name, fn)
else:
already_exists = _set_new_attribute(cls, name, fn)
# See if it's an error to overwrite this particular function.
if already_exists and (msg_extra := self.overwrite_errors.get(name)):
error_msg = (f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
if not msg_extra is True:
error_msg = f'{error_msg} {msg_extra}'
raise TypeError(error_msg)
def _field_assign(frozen, name, value, self_name): def _field_assign(frozen, name, value, self_name):
@ -462,8 +525,8 @@ def _field_assign(frozen, name, value, self_name):
# self_name is what "self" is called in this function: don't # self_name is what "self" is called in this function: don't
# hard-code "self", since that might be a field name. # hard-code "self", since that might be a field name.
if frozen: if frozen:
return f'__dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})' return f' __dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
return f'{self_name}.{name}={value}' return f' {self_name}.{name}={value}'
def _field_init(f, frozen, globals, self_name, slots): def _field_init(f, frozen, globals, self_name, slots):
@ -546,7 +609,7 @@ def _init_param(f):
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
self_name, globals, slots): self_name, func_builder, slots):
# fields contains both real fields and InitVar pseudo-fields. # fields contains both real fields and InitVar pseudo-fields.
# Make sure we don't have fields without defaults following fields # Make sure we don't have fields without defaults following fields
@ -565,11 +628,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
raise TypeError(f'non-default argument {f.name!r} ' raise TypeError(f'non-default argument {f.name!r} '
f'follows default argument {seen_default.name!r}') f'follows default argument {seen_default.name!r}')
locals = {f'__dataclass_type_{f.name}__': f.type for f in fields} locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
locals.update({ **{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY, '__dataclass_builtins_object__': object,
'__dataclass_builtins_object__': object, }
}) }
body_lines = [] body_lines = []
for f in fields: for f in fields:
@ -583,11 +646,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
if has_post_init: if has_post_init:
params_str = ','.join(f.name for f in fields params_str = ','.join(f.name for f in fields
if f._field_type is _FIELD_INITVAR) if f._field_type is _FIELD_INITVAR)
body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})') body_lines.append(f' {self_name}.{_POST_INIT_NAME}({params_str})')
# If no body lines, use 'pass'. # If no body lines, use 'pass'.
if not body_lines: if not body_lines:
body_lines = ['pass'] body_lines = [' pass']
_init_params = [_init_param(f) for f in std_fields] _init_params = [_init_param(f) for f in std_fields]
if kw_only_fields: if kw_only_fields:
@ -596,68 +659,34 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
# (instead of just concatenting the lists together). # (instead of just concatenting the lists together).
_init_params += ['*'] _init_params += ['*']
_init_params += [_init_param(f) for f in kw_only_fields] _init_params += [_init_param(f) for f in kw_only_fields]
return _create_fn('__init__', func_builder.add_fn('__init__',
[self_name] + _init_params, [self_name] + _init_params,
body_lines, body_lines,
locals=locals, locals=locals,
globals=globals, return_type=None)
return_type=None)
def _repr_fn(fields, globals): def _frozen_get_del_attr(cls, fields, func_builder):
fn = _create_fn('__repr__',
('self',),
['return f"{self.__class__.__qualname__}(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
')"'],
globals=globals)
return recursive_repr()(fn)
def _frozen_get_del_attr(cls, fields, globals):
locals = {'cls': cls, locals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError} 'FrozenInstanceError': FrozenInstanceError}
condition = 'type(self) is cls' condition = 'type(self) is cls'
if fields: if fields:
condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}' condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}'
return (_create_fn('__setattr__',
('self', 'name', 'value'),
(f'if {condition}:',
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
locals=locals,
globals=globals),
_create_fn('__delattr__',
('self', 'name'),
(f'if {condition}:',
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
locals=locals,
globals=globals),
)
func_builder.add_fn('__setattr__',
def _cmp_fn(name, op, self_tuple, other_tuple, globals): ('self', 'name', 'value'),
# Create a comparison function. If the fields in the object are (f' if {condition}:',
# named 'x' and 'y', then self_tuple is the string ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
# '(self.x,self.y)' and other_tuple is the string f' super(cls, self).__setattr__(name, value)'),
# '(other.x,other.y)'. locals=locals,
overwrite_error=True)
return _create_fn(name, func_builder.add_fn('__delattr__',
('self', 'other'), ('self', 'name'),
[ 'if other.__class__ is self.__class__:', (f' if {condition}:',
f' return {self_tuple}{op}{other_tuple}', ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
'return NotImplemented'], f' super(cls, self).__delattr__(name)'),
globals=globals) locals=locals,
overwrite_error=True)
def _hash_fn(fields, globals):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
('self',),
[f'return hash({self_tuple})'],
globals=globals)
def _is_classvar(a_type, typing): def _is_classvar(a_type, typing):
@ -834,19 +863,11 @@ def _get_field(cls, a_name, a_type, default_kw_only):
return f return f
def _set_qualname(cls, value):
# Ensure that the functions returned from _create_fn uses the proper
# __qualname__ (the class they belong to).
if isinstance(value, FunctionType):
value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
return value
def _set_new_attribute(cls, name, value): def _set_new_attribute(cls, name, value):
# Never overwrites an existing attribute. Returns True if the # Never overwrites an existing attribute. Returns True if the
# attribute already exists. # attribute already exists.
if name in cls.__dict__: if name in cls.__dict__:
return True return True
_set_qualname(cls, value)
setattr(cls, name, value) setattr(cls, name, value)
return False return False
@ -856,14 +877,22 @@ def _set_new_attribute(cls, name, value):
# take. The common case is to do nothing, so instead of providing a # take. The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that. # function that is a no-op, use None to signify that.
def _hash_set_none(cls, fields, globals): def _hash_set_none(cls, fields, func_builder):
return None # It's sort of a hack that I'm setting this here, instead of at
# func_builder.add_fns_to_class time, but since this is an exceptional case
# (it's not setting an attribute to a function, but to a scalar value),
# just do it directly here. I might come to regret this.
cls.__hash__ = None
def _hash_add(cls, fields, globals): def _hash_add(cls, fields, func_builder):
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)] flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
return _set_qualname(cls, _hash_fn(flds, globals)) self_tuple = _tuple_str('self', flds)
func_builder.add_fn('__hash__',
('self',),
[f' return hash({self_tuple})'],
unconditional_add=True)
def _hash_exception(cls, fields, globals): def _hash_exception(cls, fields, func_builder):
# Raise an exception. # Raise an exception.
raise TypeError(f'Cannot overwrite attribute __hash__ ' raise TypeError(f'Cannot overwrite attribute __hash__ '
f'in class {cls.__name__}') f'in class {cls.__name__}')
@ -1041,24 +1070,26 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
(std_init_fields, (std_init_fields,
kw_only_init_fields) = _fields_in_init_order(all_init_fields) kw_only_init_fields) = _fields_in_init_order(all_init_fields)
func_builder = _FuncBuilder(globals)
if init: if init:
# Does this class have a post-init function? # Does this class have a post-init function?
has_post_init = hasattr(cls, _POST_INIT_NAME) has_post_init = hasattr(cls, _POST_INIT_NAME)
_set_new_attribute(cls, '__init__', _init_fn(all_init_fields,
_init_fn(all_init_fields, std_init_fields,
std_init_fields, kw_only_init_fields,
kw_only_init_fields, frozen,
frozen, has_post_init,
has_post_init, # The name to use for the "self"
# The name to use for the "self" # param in __init__. Use "self"
# param in __init__. Use "self" # if possible.
# if possible. '__dataclass_self__' if 'self' in fields
'__dataclass_self__' if 'self' in fields else 'self',
else 'self', func_builder,
globals, slots,
slots, )
))
_set_new_attribute(cls, '__replace__', _replace) _set_new_attribute(cls, '__replace__', _replace)
# Get the fields as a list, and include only real fields. This is # Get the fields as a list, and include only real fields. This is
@ -1067,7 +1098,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
if repr: if repr:
flds = [f for f in field_list if f.repr] flds = [f for f in field_list if f.repr]
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals)) func_builder.add_fn('__repr__',
('self',),
[' return f"{self.__class__.__qualname__}(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in flds]) + ')"'],
locals={'__dataclasses_recursive_repr': recursive_repr},
decorator="@__dataclasses_recursive_repr()")
if eq: if eq:
# Create __eq__ method. There's no need for a __ne__ method, # Create __eq__ method. There's no need for a __ne__ method,
@ -1075,16 +1112,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
cmp_fields = (field for field in field_list if field.compare) cmp_fields = (field for field in field_list if field.compare)
terms = [f'self.{field.name}==other.{field.name}' for field in cmp_fields] terms = [f'self.{field.name}==other.{field.name}' for field in cmp_fields]
field_comparisons = ' and '.join(terms) or 'True' field_comparisons = ' and '.join(terms) or 'True'
body = [f'if self is other:', func_builder.add_fn('__eq__',
f' return True', ('self', 'other'),
f'if other.__class__ is self.__class__:', [ ' if self is other:',
f' return {field_comparisons}', ' return True',
f'return NotImplemented'] ' if other.__class__ is self.__class__:',
func = _create_fn('__eq__', f' return {field_comparisons}',
('self', 'other'), ' return NotImplemented'])
body,
globals=globals)
_set_new_attribute(cls, '__eq__', func)
if order: if order:
# Create and set the ordering methods. # Create and set the ordering methods.
@ -1096,18 +1130,19 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
('__gt__', '>'), ('__gt__', '>'),
('__ge__', '>='), ('__ge__', '>='),
]: ]:
if _set_new_attribute(cls, name, # Create a comparison function. If the fields in the object are
_cmp_fn(name, op, self_tuple, other_tuple, # named 'x' and 'y', then self_tuple is the string
globals=globals)): # '(self.x,self.y)' and other_tuple is the string
raise TypeError(f'Cannot overwrite attribute {name} ' # '(other.x,other.y)'.
f'in class {cls.__name__}. Consider using ' func_builder.add_fn(name,
'functools.total_ordering') ('self', 'other'),
[ ' if other.__class__ is self.__class__:',
f' return {self_tuple}{op}{other_tuple}',
' return NotImplemented'],
overwrite_error='Consider using functools.total_ordering')
if frozen: if frozen:
for fn in _frozen_get_del_attr(cls, field_list, globals): _frozen_get_del_attr(cls, field_list, func_builder)
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
# Decide if/how we're going to create a hash function. # Decide if/how we're going to create a hash function.
hash_action = _hash_action[bool(unsafe_hash), hash_action = _hash_action[bool(unsafe_hash),
@ -1115,9 +1150,12 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
bool(frozen), bool(frozen),
has_explicit_hash] has_explicit_hash]
if hash_action: if hash_action:
# No need to call _set_new_attribute here, since by the time cls.__hash__ = hash_action(cls, field_list, func_builder)
# we're here the overwriting is unconditional.
cls.__hash__ = hash_action(cls, field_list, globals) # Generate the methods and add them to the class. This needs to be done
# before the __doc__ logic below, since inspect will look at the __init__
# signature.
func_builder.add_fns_to_class(cls)
if not getattr(cls, '__doc__'): if not getattr(cls, '__doc__'):
# Create a class doc-string. # Create a class doc-string.
@ -1130,7 +1168,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
cls.__doc__ = (cls.__name__ + text_sig) cls.__doc__ = (cls.__name__ + text_sig)
if match_args: if match_args:
# I could probably compute this once # I could probably compute this once.
_set_new_attribute(cls, '__match_args__', _set_new_attribute(cls, '__match_args__',
tuple(f.name for f in std_init_fields)) tuple(f.name for f in std_init_fields))

View File

@ -0,0 +1,3 @@
Dataclasses now calls :func:`exec` once per dataclass, instead of once
per method being added. This can speed up dataclass creation by up to
20%.