gh-108751: Add copy.replace() function (GH-108752)

It creates a modified copy of an object by calling the object's
__replace__() method.

It is a generalization of dataclasses.replace(), named tuple's _replace()
method and replace() methods in various classes, and supports all these
stdlib classes.
This commit is contained in:
Serhiy Storchaka 2023-09-06 23:55:42 +03:00 committed by GitHub
parent 9f0c0a46f0
commit 6f3c138dfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 311 additions and 68 deletions

View File

@ -979,6 +979,8 @@ field names, the method and attribute names start with an underscore.
>>> for partnum, record in inventory.items(): >>> for partnum, record in inventory.items():
... inventory[partnum] = record._replace(price=newprices[partnum], timestamp=time.now()) ... inventory[partnum] = record._replace(price=newprices[partnum], timestamp=time.now())
Named tuples are also supported by generic function :func:`copy.replace`.
.. attribute:: somenamedtuple._fields .. attribute:: somenamedtuple._fields
Tuple of strings listing the field names. Useful for introspection Tuple of strings listing the field names. Useful for introspection

View File

@ -17,14 +17,22 @@ operations (explained below).
Interface summary: Interface summary:
.. function:: copy(x) .. function:: copy(obj)
Return a shallow copy of *x*. Return a shallow copy of *obj*.
.. function:: deepcopy(x[, memo]) .. function:: deepcopy(obj[, memo])
Return a deep copy of *x*. Return a deep copy of *obj*.
.. function:: replace(obj, /, **changes)
Creates a new object of the same type as *obj*, replacing fields with values
from *changes*.
.. versionadded:: 3.13
.. exception:: Error .. exception:: Error
@ -89,6 +97,20 @@ with the component as first argument and the memo dictionary as second argument.
The memo dictionary should be treated as an opaque object. The memo dictionary should be treated as an opaque object.
.. index::
single: __replace__() (replace protocol)
Function :func:`replace` is more limited than :func:`copy` and :func:`deepcopy`,
and only supports named tuples created by :func:`~collections.namedtuple`,
:mod:`dataclasses`, and other classes which define method :meth:`!__replace__`.
.. method:: __replace__(self, /, **changes)
:noindex:
:meth:`!__replace__` should create a new object of the same type,
replacing fields with values from *changes*.
.. seealso:: .. seealso::
Module :mod:`pickle` Module :mod:`pickle`

View File

@ -456,6 +456,8 @@ Module contents
``replace()`` (or similarly named) method which handles instance ``replace()`` (or similarly named) method which handles instance
copying. copying.
Dataclass instances are also supported by generic function :func:`copy.replace`.
.. function:: is_dataclass(obj) .. function:: is_dataclass(obj)
Return ``True`` if its parameter is a dataclass or an instance of one, Return ``True`` if its parameter is a dataclass or an instance of one,

View File

@ -652,6 +652,9 @@ Instance methods:
>>> d.replace(day=26) >>> d.replace(day=26)
datetime.date(2002, 12, 26) datetime.date(2002, 12, 26)
:class:`date` objects are also supported by generic function
:func:`copy.replace`.
.. method:: date.timetuple() .. method:: date.timetuple()
@ -1251,6 +1254,9 @@ Instance methods:
``tzinfo=None`` can be specified to create a naive datetime from an aware ``tzinfo=None`` can be specified to create a naive datetime from an aware
datetime with no conversion of date and time data. datetime with no conversion of date and time data.
:class:`datetime` objects are also supported by generic function
:func:`copy.replace`.
.. versionadded:: 3.6 .. versionadded:: 3.6
Added the ``fold`` argument. Added the ``fold`` argument.
@ -1827,6 +1833,9 @@ Instance methods:
``tzinfo=None`` can be specified to create a naive :class:`.time` from an ``tzinfo=None`` can be specified to create a naive :class:`.time` from an
aware :class:`.time`, without conversion of the time data. aware :class:`.time`, without conversion of the time data.
:class:`time` objects are also supported by generic function
:func:`copy.replace`.
.. versionadded:: 3.6 .. versionadded:: 3.6
Added the ``fold`` argument. Added the ``fold`` argument.

View File

@ -689,8 +689,8 @@ function.
The optional *return_annotation* argument, can be an arbitrary Python object, The optional *return_annotation* argument, can be an arbitrary Python object,
is the "return" annotation of the callable. is the "return" annotation of the callable.
Signature objects are *immutable*. Use :meth:`Signature.replace` to make a Signature objects are *immutable*. Use :meth:`Signature.replace` or
modified copy. :func:`copy.replace` to make a modified copy.
.. versionchanged:: 3.5 .. versionchanged:: 3.5
Signature objects are picklable and :term:`hashable`. Signature objects are picklable and :term:`hashable`.
@ -746,6 +746,9 @@ function.
>>> str(new_sig) >>> str(new_sig)
"(a, b) -> 'new return anno'" "(a, b) -> 'new return anno'"
Signature objects are also supported by generic function
:func:`copy.replace`.
.. classmethod:: Signature.from_callable(obj, *, follow_wrapped=True, globalns=None, localns=None) .. classmethod:: Signature.from_callable(obj, *, follow_wrapped=True, globalns=None, localns=None)
Return a :class:`Signature` (or its subclass) object for a given callable Return a :class:`Signature` (or its subclass) object for a given callable
@ -769,7 +772,7 @@ function.
.. class:: Parameter(name, kind, *, default=Parameter.empty, annotation=Parameter.empty) .. class:: Parameter(name, kind, *, default=Parameter.empty, annotation=Parameter.empty)
Parameter objects are *immutable*. Instead of modifying a Parameter object, Parameter objects are *immutable*. Instead of modifying a Parameter object,
you can use :meth:`Parameter.replace` to create a modified copy. you can use :meth:`Parameter.replace` or :func:`copy.replace` to create a modified copy.
.. versionchanged:: 3.5 .. versionchanged:: 3.5
Parameter objects are picklable and :term:`hashable`. Parameter objects are picklable and :term:`hashable`.
@ -892,6 +895,8 @@ function.
>>> str(param.replace(default=Parameter.empty, annotation='spam')) >>> str(param.replace(default=Parameter.empty, annotation='spam'))
"foo:'spam'" "foo:'spam'"
Parameter objects are also supported by generic function :func:`copy.replace`.
.. versionchanged:: 3.4 .. versionchanged:: 3.4
In Python 3.3 Parameter objects were allowed to have ``name`` set In Python 3.3 Parameter objects were allowed to have ``name`` set
to ``None`` if their ``kind`` was set to ``POSITIONAL_ONLY``. to ``None`` if their ``kind`` was set to ``POSITIONAL_ONLY``.

View File

@ -200,6 +200,8 @@ Standard names are defined for the following types:
Return a copy of the code object with new values for the specified fields. Return a copy of the code object with new values for the specified fields.
Code objects are also supported by generic function :func:`copy.replace`.
.. versionadded:: 3.8 .. versionadded:: 3.8
.. data:: CellType .. data:: CellType

View File

@ -115,6 +115,18 @@ array
It can be used instead of ``'u'`` type code, which is deprecated. It can be used instead of ``'u'`` type code, which is deprecated.
(Contributed by Inada Naoki in :gh:`80480`.) (Contributed by Inada Naoki in :gh:`80480`.)
copy
----
* Add :func:`copy.replace` function which allows to create a modified copy of
an object, which is especially usefule for immutable objects.
It supports named tuples created with the factory function
:func:`collections.namedtuple`, :class:`~dataclasses.dataclass` instances,
various :mod:`datetime` objects, :class:`~inspect.Signature` objects,
:class:`~inspect.Parameter` objects, :ref:`code object <code-objects>`, and
any user classes which define the :meth:`!__replace__` method.
(Contributed by Serhiy Storchaka in :gh:`108751`.)
dbm dbm
--- ---

View File

@ -1112,6 +1112,8 @@ class date:
day = self._day day = self._day
return type(self)(year, month, day) return type(self)(year, month, day)
__replace__ = replace
# Comparisons of date objects with other. # Comparisons of date objects with other.
def __eq__(self, other): def __eq__(self, other):
@ -1637,6 +1639,8 @@ class time:
fold = self._fold fold = self._fold
return type(self)(hour, minute, second, microsecond, tzinfo, fold=fold) return type(self)(hour, minute, second, microsecond, tzinfo, fold=fold)
__replace__ = replace
# Pickle support. # Pickle support.
def _getstate(self, protocol=3): def _getstate(self, protocol=3):
@ -1983,6 +1987,8 @@ class datetime(date):
return type(self)(year, month, day, hour, minute, second, return type(self)(year, month, day, hour, minute, second,
microsecond, tzinfo, fold=fold) microsecond, tzinfo, fold=fold)
__replace__ = replace
def _local_timezone(self): def _local_timezone(self):
if self.tzinfo is None: if self.tzinfo is None:
ts = self._mktime() ts = self._mktime()

View File

@ -495,6 +495,7 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
'_field_defaults': field_defaults, '_field_defaults': field_defaults,
'__new__': __new__, '__new__': __new__,
'_make': _make, '_make': _make,
'__replace__': _replace,
'_replace': _replace, '_replace': _replace,
'__repr__': __repr__, '__repr__': __repr__,
'_asdict': _asdict, '_asdict': _asdict,

View File

@ -290,3 +290,16 @@ def _reconstruct(x, memo, func, args,
return y return y
del types, weakref del types, weakref
def replace(obj, /, **changes):
"""Return a new object replacing specified fields with new values.
This is especially useful for immutable objects, like named tuples or
frozen dataclasses.
"""
cls = obj.__class__
func = getattr(cls, '__replace__', None)
if func is None:
raise TypeError(f"replace() does not support {cls.__name__} objects")
return func(obj, **changes)

View File

@ -1073,6 +1073,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
globals, globals,
slots, slots,
)) ))
_set_new_attribute(cls, '__replace__', _replace)
# Get the fields as a list, and include only real fields. This is # Get the fields as a list, and include only real fields. This is
# used in all of the following methods. # used in all of the following methods.
@ -1546,12 +1547,14 @@ def replace(obj, /, **changes):
c1 = replace(c, x=3) c1 = replace(c, x=3)
assert c1.x == 3 and c1.y == 2 assert c1.x == 3 and c1.y == 2
""" """
# We're going to mutate 'changes', but that's okay because it's a
# new dict, even if called with 'replace(obj, **my_changes)'.
if not _is_dataclass_instance(obj): if not _is_dataclass_instance(obj):
raise TypeError("replace() should be called on dataclass instances") raise TypeError("replace() should be called on dataclass instances")
return _replace(obj, **changes)
def _replace(obj, /, **changes):
# We're going to mutate 'changes', but that's okay because it's a
# new dict, even if called with 'replace(obj, **my_changes)'.
# It's an error to have init=False fields in 'changes'. # It's an error to have init=False fields in 'changes'.
# If a field is not in 'changes', read its value from the provided obj. # If a field is not in 'changes', read its value from the provided obj.

View File

@ -2870,6 +2870,8 @@ class Parameter:
return formatted return formatted
__replace__ = replace
def __repr__(self): def __repr__(self):
return '<{} "{}">'.format(self.__class__.__name__, self) return '<{} "{}">'.format(self.__class__.__name__, self)
@ -3130,6 +3132,8 @@ class Signature:
return type(self)(parameters, return type(self)(parameters,
return_annotation=return_annotation) return_annotation=return_annotation)
__replace__ = replace
def _hash_basis(self): def _hash_basis(self):
params = tuple(param for param in self.parameters.values() params = tuple(param for param in self.parameters.values()
if param.kind != _KEYWORD_ONLY) if param.kind != _KEYWORD_ONLY)

View File

@ -1699,22 +1699,23 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
cls = self.theclass cls = self.theclass
args = [1, 2, 3] args = [1, 2, 3]
base = cls(*args) base = cls(*args)
self.assertEqual(base, base.replace()) self.assertEqual(base.replace(), base)
self.assertEqual(copy.replace(base), base)
i = 0 changes = (("year", 2),
for name, newval in (("year", 2), ("month", 3),
("month", 3), ("day", 4))
("day", 4)): for i, (name, newval) in enumerate(changes):
newargs = args[:] newargs = args[:]
newargs[i] = newval newargs[i] = newval
expected = cls(*newargs) expected = cls(*newargs)
got = base.replace(**{name: newval}) self.assertEqual(base.replace(**{name: newval}), expected)
self.assertEqual(expected, got) self.assertEqual(copy.replace(base, **{name: newval}), expected)
i += 1
# Out of bounds. # Out of bounds.
base = cls(2000, 2, 29) base = cls(2000, 2, 29)
self.assertRaises(ValueError, base.replace, year=2001) self.assertRaises(ValueError, base.replace, year=2001)
self.assertRaises(ValueError, copy.replace, base, year=2001)
def test_subclass_replace(self): def test_subclass_replace(self):
class DateSubclass(self.theclass): class DateSubclass(self.theclass):
@ -1722,6 +1723,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
dt = DateSubclass(2012, 1, 1) dt = DateSubclass(2012, 1, 1)
self.assertIs(type(dt.replace(year=2013)), DateSubclass) self.assertIs(type(dt.replace(year=2013)), DateSubclass)
self.assertIs(type(copy.replace(dt, year=2013)), DateSubclass)
def test_subclass_date(self): def test_subclass_date(self):
@ -2856,26 +2858,27 @@ class TestDateTime(TestDate):
cls = self.theclass cls = self.theclass
args = [1, 2, 3, 4, 5, 6, 7] args = [1, 2, 3, 4, 5, 6, 7]
base = cls(*args) base = cls(*args)
self.assertEqual(base, base.replace()) self.assertEqual(base.replace(), base)
self.assertEqual(copy.replace(base), base)
i = 0 changes = (("year", 2),
for name, newval in (("year", 2), ("month", 3),
("month", 3), ("day", 4),
("day", 4), ("hour", 5),
("hour", 5), ("minute", 6),
("minute", 6), ("second", 7),
("second", 7), ("microsecond", 8))
("microsecond", 8)): for i, (name, newval) in enumerate(changes):
newargs = args[:] newargs = args[:]
newargs[i] = newval newargs[i] = newval
expected = cls(*newargs) expected = cls(*newargs)
got = base.replace(**{name: newval}) self.assertEqual(base.replace(**{name: newval}), expected)
self.assertEqual(expected, got) self.assertEqual(copy.replace(base, **{name: newval}), expected)
i += 1
# Out of bounds. # Out of bounds.
base = cls(2000, 2, 29) base = cls(2000, 2, 29)
self.assertRaises(ValueError, base.replace, year=2001) self.assertRaises(ValueError, base.replace, year=2001)
self.assertRaises(ValueError, copy.replace, base, year=2001)
@support.run_with_tz('EDT4') @support.run_with_tz('EDT4')
def test_astimezone(self): def test_astimezone(self):
@ -3671,19 +3674,19 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
cls = self.theclass cls = self.theclass
args = [1, 2, 3, 4] args = [1, 2, 3, 4]
base = cls(*args) base = cls(*args)
self.assertEqual(base, base.replace()) self.assertEqual(base.replace(), base)
self.assertEqual(copy.replace(base), base)
i = 0 changes = (("hour", 5),
for name, newval in (("hour", 5), ("minute", 6),
("minute", 6), ("second", 7),
("second", 7), ("microsecond", 8))
("microsecond", 8)): for i, (name, newval) in enumerate(changes):
newargs = args[:] newargs = args[:]
newargs[i] = newval newargs[i] = newval
expected = cls(*newargs) expected = cls(*newargs)
got = base.replace(**{name: newval}) self.assertEqual(base.replace(**{name: newval}), expected)
self.assertEqual(expected, got) self.assertEqual(copy.replace(base, **{name: newval}), expected)
i += 1
# Out of bounds. # Out of bounds.
base = cls(1) base = cls(1)
@ -3691,6 +3694,10 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
self.assertRaises(ValueError, base.replace, minute=-1) self.assertRaises(ValueError, base.replace, minute=-1)
self.assertRaises(ValueError, base.replace, second=100) self.assertRaises(ValueError, base.replace, second=100)
self.assertRaises(ValueError, base.replace, microsecond=1000000) self.assertRaises(ValueError, base.replace, microsecond=1000000)
self.assertRaises(ValueError, copy.replace, base, hour=24)
self.assertRaises(ValueError, copy.replace, base, minute=-1)
self.assertRaises(ValueError, copy.replace, base, second=100)
self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
def test_subclass_replace(self): def test_subclass_replace(self):
class TimeSubclass(self.theclass): class TimeSubclass(self.theclass):
@ -3698,6 +3705,7 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
ctime = TimeSubclass(12, 30) ctime = TimeSubclass(12, 30)
self.assertIs(type(ctime.replace(hour=10)), TimeSubclass) self.assertIs(type(ctime.replace(hour=10)), TimeSubclass)
self.assertIs(type(copy.replace(ctime, hour=10)), TimeSubclass)
def test_subclass_time(self): def test_subclass_time(self):
@ -4085,31 +4093,37 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
zm200 = FixedOffset(timedelta(minutes=-200), "-200") zm200 = FixedOffset(timedelta(minutes=-200), "-200")
args = [1, 2, 3, 4, z100] args = [1, 2, 3, 4, z100]
base = cls(*args) base = cls(*args)
self.assertEqual(base, base.replace()) self.assertEqual(base.replace(), base)
self.assertEqual(copy.replace(base), base)
i = 0 changes = (("hour", 5),
for name, newval in (("hour", 5), ("minute", 6),
("minute", 6), ("second", 7),
("second", 7), ("microsecond", 8),
("microsecond", 8), ("tzinfo", zm200))
("tzinfo", zm200)): for i, (name, newval) in enumerate(changes):
newargs = args[:] newargs = args[:]
newargs[i] = newval newargs[i] = newval
expected = cls(*newargs) expected = cls(*newargs)
got = base.replace(**{name: newval}) self.assertEqual(base.replace(**{name: newval}), expected)
self.assertEqual(expected, got) self.assertEqual(copy.replace(base, **{name: newval}), expected)
i += 1
# Ensure we can get rid of a tzinfo. # Ensure we can get rid of a tzinfo.
self.assertEqual(base.tzname(), "+100") self.assertEqual(base.tzname(), "+100")
base2 = base.replace(tzinfo=None) base2 = base.replace(tzinfo=None)
self.assertIsNone(base2.tzinfo) self.assertIsNone(base2.tzinfo)
self.assertIsNone(base2.tzname()) self.assertIsNone(base2.tzname())
base22 = copy.replace(base, tzinfo=None)
self.assertIsNone(base22.tzinfo)
self.assertIsNone(base22.tzname())
# Ensure we can add one. # Ensure we can add one.
base3 = base2.replace(tzinfo=z100) base3 = base2.replace(tzinfo=z100)
self.assertEqual(base, base3) self.assertEqual(base, base3)
self.assertIs(base.tzinfo, base3.tzinfo) self.assertIs(base.tzinfo, base3.tzinfo)
base32 = copy.replace(base22, tzinfo=z100)
self.assertEqual(base, base32)
self.assertIs(base.tzinfo, base32.tzinfo)
# Out of bounds. # Out of bounds.
base = cls(1) base = cls(1)
@ -4117,6 +4131,10 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
self.assertRaises(ValueError, base.replace, minute=-1) self.assertRaises(ValueError, base.replace, minute=-1)
self.assertRaises(ValueError, base.replace, second=100) self.assertRaises(ValueError, base.replace, second=100)
self.assertRaises(ValueError, base.replace, microsecond=1000000) self.assertRaises(ValueError, base.replace, microsecond=1000000)
self.assertRaises(ValueError, copy.replace, base, hour=24)
self.assertRaises(ValueError, copy.replace, base, minute=-1)
self.assertRaises(ValueError, copy.replace, base, second=100)
self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
def test_mixed_compare(self): def test_mixed_compare(self):
t1 = self.theclass(1, 2, 3) t1 = self.theclass(1, 2, 3)
@ -4885,38 +4903,45 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
zm200 = FixedOffset(timedelta(minutes=-200), "-200") zm200 = FixedOffset(timedelta(minutes=-200), "-200")
args = [1, 2, 3, 4, 5, 6, 7, z100] args = [1, 2, 3, 4, 5, 6, 7, z100]
base = cls(*args) base = cls(*args)
self.assertEqual(base, base.replace()) self.assertEqual(base.replace(), base)
self.assertEqual(copy.replace(base), base)
i = 0 changes = (("year", 2),
for name, newval in (("year", 2), ("month", 3),
("month", 3), ("day", 4),
("day", 4), ("hour", 5),
("hour", 5), ("minute", 6),
("minute", 6), ("second", 7),
("second", 7), ("microsecond", 8),
("microsecond", 8), ("tzinfo", zm200))
("tzinfo", zm200)): for i, (name, newval) in enumerate(changes):
newargs = args[:] newargs = args[:]
newargs[i] = newval newargs[i] = newval
expected = cls(*newargs) expected = cls(*newargs)
got = base.replace(**{name: newval}) self.assertEqual(base.replace(**{name: newval}), expected)
self.assertEqual(expected, got) self.assertEqual(copy.replace(base, **{name: newval}), expected)
i += 1
# Ensure we can get rid of a tzinfo. # Ensure we can get rid of a tzinfo.
self.assertEqual(base.tzname(), "+100") self.assertEqual(base.tzname(), "+100")
base2 = base.replace(tzinfo=None) base2 = base.replace(tzinfo=None)
self.assertIsNone(base2.tzinfo) self.assertIsNone(base2.tzinfo)
self.assertIsNone(base2.tzname()) self.assertIsNone(base2.tzname())
base22 = copy.replace(base, tzinfo=None)
self.assertIsNone(base22.tzinfo)
self.assertIsNone(base22.tzname())
# Ensure we can add one. # Ensure we can add one.
base3 = base2.replace(tzinfo=z100) base3 = base2.replace(tzinfo=z100)
self.assertEqual(base, base3) self.assertEqual(base, base3)
self.assertIs(base.tzinfo, base3.tzinfo) self.assertIs(base.tzinfo, base3.tzinfo)
base32 = copy.replace(base22, tzinfo=z100)
self.assertEqual(base, base32)
self.assertIs(base.tzinfo, base32.tzinfo)
# Out of bounds. # Out of bounds.
base = cls(2000, 2, 29) base = cls(2000, 2, 29)
self.assertRaises(ValueError, base.replace, year=2001) self.assertRaises(ValueError, base.replace, year=2001)
self.assertRaises(ValueError, copy.replace, base, year=2001)
def test_more_astimezone(self): def test_more_astimezone(self):
# The inherited test_astimezone covered some trivial and error cases. # The inherited test_astimezone covered some trivial and error cases.

View File

@ -125,6 +125,7 @@ consts: ('None',)
""" """
import copy
import inspect import inspect
import sys import sys
import threading import threading
@ -280,11 +281,17 @@ class CodeTest(unittest.TestCase):
with self.subTest(attr=attr, value=value): with self.subTest(attr=attr, value=value):
new_code = code.replace(**{attr: value}) new_code = code.replace(**{attr: value})
self.assertEqual(getattr(new_code, attr), value) self.assertEqual(getattr(new_code, attr), value)
new_code = copy.replace(code, **{attr: value})
self.assertEqual(getattr(new_code, attr), value)
new_code = code.replace(co_varnames=code2.co_varnames, new_code = code.replace(co_varnames=code2.co_varnames,
co_nlocals=code2.co_nlocals) co_nlocals=code2.co_nlocals)
self.assertEqual(new_code.co_varnames, code2.co_varnames) self.assertEqual(new_code.co_varnames, code2.co_varnames)
self.assertEqual(new_code.co_nlocals, code2.co_nlocals) self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
new_code = copy.replace(code, co_varnames=code2.co_varnames,
co_nlocals=code2.co_nlocals)
self.assertEqual(new_code.co_varnames, code2.co_varnames)
self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
def test_nlocals_mismatch(self): def test_nlocals_mismatch(self):
def func(): def func():

View File

@ -4,7 +4,7 @@ import copy
import copyreg import copyreg
import weakref import weakref
import abc import abc
from operator import le, lt, ge, gt, eq, ne from operator import le, lt, ge, gt, eq, ne, attrgetter
import unittest import unittest
from test import support from test import support
@ -899,7 +899,71 @@ class TestCopy(unittest.TestCase):
g.b() g.b()
class TestReplace(unittest.TestCase):
def test_unsupported(self):
self.assertRaises(TypeError, copy.replace, 1)
self.assertRaises(TypeError, copy.replace, [])
self.assertRaises(TypeError, copy.replace, {})
def f(): pass
self.assertRaises(TypeError, copy.replace, f)
class A: pass
self.assertRaises(TypeError, copy.replace, A)
self.assertRaises(TypeError, copy.replace, A())
def test_replace_method(self):
class A:
def __new__(cls, x, y=0):
self = object.__new__(cls)
self.x = x
self.y = y
return self
def __init__(self, *args, **kwargs):
self.z = self.x + self.y
def __replace__(self, **changes):
x = changes.get('x', self.x)
y = changes.get('y', self.y)
return type(self)(x, y)
attrs = attrgetter('x', 'y', 'z')
a = A(11, 22)
self.assertEqual(attrs(copy.replace(a)), (11, 22, 33))
self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23))
self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13))
self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3))
def test_namedtuple(self):
from collections import namedtuple
Point = namedtuple('Point', 'x y', defaults=(0,))
p = Point(11, 22)
self.assertEqual(copy.replace(p), (11, 22))
self.assertEqual(copy.replace(p, x=1), (1, 22))
self.assertEqual(copy.replace(p, y=2), (11, 2))
self.assertEqual(copy.replace(p, x=1, y=2), (1, 2))
with self.assertRaisesRegex(ValueError, 'unexpected field name'):
copy.replace(p, x=1, error=2)
def test_dataclass(self):
from dataclasses import dataclass
@dataclass
class C:
x: int
y: int = 0
attrs = attrgetter('x', 'y')
c = C(11, 22)
self.assertEqual(attrs(copy.replace(c)), (11, 22))
self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22))
self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2))
self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2))
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
copy.replace(c, x=1, error=2)
def global_foo(x, y): return x+y def global_foo(x, y): return x+y
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import builtins import builtins
import collections import collections
import copy
import datetime import datetime
import functools import functools
import importlib import importlib
@ -3830,6 +3831,28 @@ class TestSignatureObject(unittest.TestCase):
P('bar', P.VAR_POSITIONAL)])), P('bar', P.VAR_POSITIONAL)])),
'(foo, /, *bar)') '(foo, /, *bar)')
def test_signature_replace_parameters(self):
def test(a, b) -> 42:
pass
sig = inspect.signature(test)
parameters = sig.parameters
sig = sig.replace(parameters=list(parameters.values())[1:])
self.assertEqual(list(sig.parameters), ['b'])
self.assertEqual(sig.parameters['b'], parameters['b'])
self.assertEqual(sig.return_annotation, 42)
sig = sig.replace(parameters=())
self.assertEqual(dict(sig.parameters), {})
sig = inspect.signature(test)
parameters = sig.parameters
sig = copy.replace(sig, parameters=list(parameters.values())[1:])
self.assertEqual(list(sig.parameters), ['b'])
self.assertEqual(sig.parameters['b'], parameters['b'])
self.assertEqual(sig.return_annotation, 42)
sig = copy.replace(sig, parameters=())
self.assertEqual(dict(sig.parameters), {})
def test_signature_replace_anno(self): def test_signature_replace_anno(self):
def test() -> 42: def test() -> 42:
pass pass
@ -3843,6 +3866,15 @@ class TestSignatureObject(unittest.TestCase):
self.assertEqual(sig.return_annotation, 42) self.assertEqual(sig.return_annotation, 42)
self.assertEqual(sig, inspect.signature(test)) self.assertEqual(sig, inspect.signature(test))
sig = inspect.signature(test)
sig = copy.replace(sig, return_annotation=None)
self.assertIs(sig.return_annotation, None)
sig = copy.replace(sig, return_annotation=sig.empty)
self.assertIs(sig.return_annotation, sig.empty)
sig = copy.replace(sig, return_annotation=42)
self.assertEqual(sig.return_annotation, 42)
self.assertEqual(sig, inspect.signature(test))
def test_signature_replaced(self): def test_signature_replaced(self):
def test(): def test():
pass pass
@ -4187,41 +4219,66 @@ class TestParameterObject(unittest.TestCase):
p = inspect.Parameter('foo', default=42, p = inspect.Parameter('foo', default=42,
kind=inspect.Parameter.KEYWORD_ONLY) kind=inspect.Parameter.KEYWORD_ONLY)
self.assertIsNot(p, p.replace()) self.assertIsNot(p.replace(), p)
self.assertEqual(p, p.replace()) self.assertEqual(p.replace(), p)
self.assertIsNot(copy.replace(p), p)
self.assertEqual(copy.replace(p), p)
p2 = p.replace(annotation=1) p2 = p.replace(annotation=1)
self.assertEqual(p2.annotation, 1) self.assertEqual(p2.annotation, 1)
p2 = p2.replace(annotation=p2.empty) p2 = p2.replace(annotation=p2.empty)
self.assertEqual(p, p2) self.assertEqual(p2, p)
p3 = copy.replace(p, annotation=1)
self.assertEqual(p3.annotation, 1)
p3 = copy.replace(p3, annotation=p3.empty)
self.assertEqual(p3, p)
p2 = p2.replace(name='bar') p2 = p2.replace(name='bar')
self.assertEqual(p2.name, 'bar') self.assertEqual(p2.name, 'bar')
self.assertNotEqual(p2, p) self.assertNotEqual(p2, p)
p3 = copy.replace(p3, name='bar')
self.assertEqual(p3.name, 'bar')
self.assertNotEqual(p3, p)
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
'name is a required attribute'): 'name is a required attribute'):
p2 = p2.replace(name=p2.empty) p2 = p2.replace(name=p2.empty)
with self.assertRaisesRegex(ValueError,
'name is a required attribute'):
p3 = copy.replace(p3, name=p3.empty)
p2 = p2.replace(name='foo', default=None) p2 = p2.replace(name='foo', default=None)
self.assertIs(p2.default, None) self.assertIs(p2.default, None)
self.assertNotEqual(p2, p) self.assertNotEqual(p2, p)
p3 = copy.replace(p3, name='foo', default=None)
self.assertIs(p3.default, None)
self.assertNotEqual(p3, p)
p2 = p2.replace(name='foo', default=p2.empty) p2 = p2.replace(name='foo', default=p2.empty)
self.assertIs(p2.default, p2.empty) self.assertIs(p2.default, p2.empty)
p3 = copy.replace(p3, name='foo', default=p3.empty)
self.assertIs(p3.default, p3.empty)
p2 = p2.replace(default=42, kind=p2.POSITIONAL_OR_KEYWORD) p2 = p2.replace(default=42, kind=p2.POSITIONAL_OR_KEYWORD)
self.assertEqual(p2.kind, p2.POSITIONAL_OR_KEYWORD) self.assertEqual(p2.kind, p2.POSITIONAL_OR_KEYWORD)
self.assertNotEqual(p2, p) self.assertNotEqual(p2, p)
p3 = copy.replace(p3, default=42, kind=p3.POSITIONAL_OR_KEYWORD)
self.assertEqual(p3.kind, p3.POSITIONAL_OR_KEYWORD)
self.assertNotEqual(p3, p)
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
"value <class 'inspect._empty'> " "value <class 'inspect._empty'> "
"is not a valid Parameter.kind"): "is not a valid Parameter.kind"):
p2 = p2.replace(kind=p2.empty) p2 = p2.replace(kind=p2.empty)
with self.assertRaisesRegex(ValueError,
"value <class 'inspect._empty'> "
"is not a valid Parameter.kind"):
p3 = copy.replace(p3, kind=p3.empty)
p2 = p2.replace(kind=p2.KEYWORD_ONLY) p2 = p2.replace(kind=p2.KEYWORD_ONLY)
self.assertEqual(p2, p) self.assertEqual(p2, p)
p3 = copy.replace(p3, kind=p3.KEYWORD_ONLY)
self.assertEqual(p3, p)
def test_signature_parameter_positional_only(self): def test_signature_parameter_positional_only(self):
with self.assertRaisesRegex(TypeError, 'name must be a str'): with self.assertRaisesRegex(TypeError, 'name must be a str'):

View File

@ -0,0 +1,2 @@
Add :func:`copy.replace` function which allows to create a modified copy of
an object. It supports named tuples, dataclasses, and many other objects.

View File

@ -3590,6 +3590,8 @@ static PyMethodDef date_methods[] = {
{"replace", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS, {"replace", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Return date with new specified fields.")}, PyDoc_STR("Return date with new specified fields.")},
{"__replace__", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS},
{"__reduce__", (PyCFunction)date_reduce, METH_NOARGS, {"__reduce__", (PyCFunction)date_reduce, METH_NOARGS,
PyDoc_STR("__reduce__() -> (cls, state)")}, PyDoc_STR("__reduce__() -> (cls, state)")},
@ -4719,6 +4721,8 @@ static PyMethodDef time_methods[] = {
{"replace", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS, {"replace", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Return time with new specified fields.")}, PyDoc_STR("Return time with new specified fields.")},
{"__replace__", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS},
{"fromisoformat", (PyCFunction)time_fromisoformat, METH_O | METH_CLASS, {"fromisoformat", (PyCFunction)time_fromisoformat, METH_O | METH_CLASS,
PyDoc_STR("string -> time from a string in ISO 8601 format")}, PyDoc_STR("string -> time from a string in ISO 8601 format")},
@ -6579,6 +6583,8 @@ static PyMethodDef datetime_methods[] = {
{"replace", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS, {"replace", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Return datetime with new specified fields.")}, PyDoc_STR("Return datetime with new specified fields.")},
{"__replace__", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS},
{"astimezone", _PyCFunction_CAST(datetime_astimezone), METH_VARARGS | METH_KEYWORDS, {"astimezone", _PyCFunction_CAST(datetime_astimezone), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("tz -> convert to local time in new timezone tz\n")}, PyDoc_STR("tz -> convert to local time in new timezone tz\n")},

View File

@ -2145,6 +2145,7 @@ static struct PyMethodDef code_methods[] = {
{"co_positions", (PyCFunction)code_positionsiterator, METH_NOARGS}, {"co_positions", (PyCFunction)code_positionsiterator, METH_NOARGS},
CODE_REPLACE_METHODDEF CODE_REPLACE_METHODDEF
CODE__VARNAME_FROM_OPARG_METHODDEF CODE__VARNAME_FROM_OPARG_METHODDEF
{"__replace__", _PyCFunction_CAST(code_replace), METH_FASTCALL|METH_KEYWORDS},
{NULL, NULL} /* sentinel */ {NULL, NULL} /* sentinel */
}; };