mirror of https://github.com/python/cpython
gh-108751: Add copy.replace() function (GH-108752)
It creates a modified copy of an object by calling the object's __replace__() method. It is a generalization of dataclasses.replace(), named tuple's _replace() method and replace() methods in various classes, and supports all these stdlib classes.
This commit is contained in:
parent
9f0c0a46f0
commit
6f3c138dfa
|
@ -979,6 +979,8 @@ field names, the method and attribute names start with an underscore.
|
|||
>>> for partnum, record in inventory.items():
|
||||
... inventory[partnum] = record._replace(price=newprices[partnum], timestamp=time.now())
|
||||
|
||||
Named tuples are also supported by generic function :func:`copy.replace`.
|
||||
|
||||
.. attribute:: somenamedtuple._fields
|
||||
|
||||
Tuple of strings listing the field names. Useful for introspection
|
||||
|
|
|
@ -17,14 +17,22 @@ operations (explained below).
|
|||
|
||||
Interface summary:
|
||||
|
||||
.. function:: copy(x)
|
||||
.. function:: copy(obj)
|
||||
|
||||
Return a shallow copy of *x*.
|
||||
Return a shallow copy of *obj*.
|
||||
|
||||
|
||||
.. function:: deepcopy(x[, memo])
|
||||
.. function:: deepcopy(obj[, memo])
|
||||
|
||||
Return a deep copy of *x*.
|
||||
Return a deep copy of *obj*.
|
||||
|
||||
|
||||
.. function:: replace(obj, /, **changes)
|
||||
|
||||
Creates a new object of the same type as *obj*, replacing fields with values
|
||||
from *changes*.
|
||||
|
||||
.. versionadded:: 3.13
|
||||
|
||||
|
||||
.. exception:: Error
|
||||
|
@ -89,6 +97,20 @@ with the component as first argument and the memo dictionary as second argument.
|
|||
The memo dictionary should be treated as an opaque object.
|
||||
|
||||
|
||||
.. index::
|
||||
single: __replace__() (replace protocol)
|
||||
|
||||
Function :func:`replace` is more limited than :func:`copy` and :func:`deepcopy`,
|
||||
and only supports named tuples created by :func:`~collections.namedtuple`,
|
||||
:mod:`dataclasses`, and other classes which define method :meth:`!__replace__`.
|
||||
|
||||
.. method:: __replace__(self, /, **changes)
|
||||
:noindex:
|
||||
|
||||
:meth:`!__replace__` should create a new object of the same type,
|
||||
replacing fields with values from *changes*.
|
||||
|
||||
|
||||
.. seealso::
|
||||
|
||||
Module :mod:`pickle`
|
||||
|
|
|
@ -456,6 +456,8 @@ Module contents
|
|||
``replace()`` (or similarly named) method which handles instance
|
||||
copying.
|
||||
|
||||
Dataclass instances are also supported by generic function :func:`copy.replace`.
|
||||
|
||||
.. function:: is_dataclass(obj)
|
||||
|
||||
Return ``True`` if its parameter is a dataclass or an instance of one,
|
||||
|
|
|
@ -652,6 +652,9 @@ Instance methods:
|
|||
>>> d.replace(day=26)
|
||||
datetime.date(2002, 12, 26)
|
||||
|
||||
:class:`date` objects are also supported by generic function
|
||||
:func:`copy.replace`.
|
||||
|
||||
|
||||
.. method:: date.timetuple()
|
||||
|
||||
|
@ -1251,6 +1254,9 @@ Instance methods:
|
|||
``tzinfo=None`` can be specified to create a naive datetime from an aware
|
||||
datetime with no conversion of date and time data.
|
||||
|
||||
:class:`datetime` objects are also supported by generic function
|
||||
:func:`copy.replace`.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
Added the ``fold`` argument.
|
||||
|
||||
|
@ -1827,6 +1833,9 @@ Instance methods:
|
|||
``tzinfo=None`` can be specified to create a naive :class:`.time` from an
|
||||
aware :class:`.time`, without conversion of the time data.
|
||||
|
||||
:class:`time` objects are also supported by generic function
|
||||
:func:`copy.replace`.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
Added the ``fold`` argument.
|
||||
|
||||
|
|
|
@ -689,8 +689,8 @@ function.
|
|||
The optional *return_annotation* argument, can be an arbitrary Python object,
|
||||
is the "return" annotation of the callable.
|
||||
|
||||
Signature objects are *immutable*. Use :meth:`Signature.replace` to make a
|
||||
modified copy.
|
||||
Signature objects are *immutable*. Use :meth:`Signature.replace` or
|
||||
:func:`copy.replace` to make a modified copy.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
Signature objects are picklable and :term:`hashable`.
|
||||
|
@ -746,6 +746,9 @@ function.
|
|||
>>> str(new_sig)
|
||||
"(a, b) -> 'new return anno'"
|
||||
|
||||
Signature objects are also supported by generic function
|
||||
:func:`copy.replace`.
|
||||
|
||||
.. classmethod:: Signature.from_callable(obj, *, follow_wrapped=True, globalns=None, localns=None)
|
||||
|
||||
Return a :class:`Signature` (or its subclass) object for a given callable
|
||||
|
@ -769,7 +772,7 @@ function.
|
|||
.. class:: Parameter(name, kind, *, default=Parameter.empty, annotation=Parameter.empty)
|
||||
|
||||
Parameter objects are *immutable*. Instead of modifying a Parameter object,
|
||||
you can use :meth:`Parameter.replace` to create a modified copy.
|
||||
you can use :meth:`Parameter.replace` or :func:`copy.replace` to create a modified copy.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
Parameter objects are picklable and :term:`hashable`.
|
||||
|
@ -892,6 +895,8 @@ function.
|
|||
>>> str(param.replace(default=Parameter.empty, annotation='spam'))
|
||||
"foo:'spam'"
|
||||
|
||||
Parameter objects are also supported by generic function :func:`copy.replace`.
|
||||
|
||||
.. versionchanged:: 3.4
|
||||
In Python 3.3 Parameter objects were allowed to have ``name`` set
|
||||
to ``None`` if their ``kind`` was set to ``POSITIONAL_ONLY``.
|
||||
|
|
|
@ -200,6 +200,8 @@ Standard names are defined for the following types:
|
|||
|
||||
Return a copy of the code object with new values for the specified fields.
|
||||
|
||||
Code objects are also supported by generic function :func:`copy.replace`.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
|
||||
.. data:: CellType
|
||||
|
|
|
@ -115,6 +115,18 @@ array
|
|||
It can be used instead of ``'u'`` type code, which is deprecated.
|
||||
(Contributed by Inada Naoki in :gh:`80480`.)
|
||||
|
||||
copy
|
||||
----
|
||||
|
||||
* Add :func:`copy.replace` function which allows to create a modified copy of
|
||||
an object, which is especially usefule for immutable objects.
|
||||
It supports named tuples created with the factory function
|
||||
:func:`collections.namedtuple`, :class:`~dataclasses.dataclass` instances,
|
||||
various :mod:`datetime` objects, :class:`~inspect.Signature` objects,
|
||||
:class:`~inspect.Parameter` objects, :ref:`code object <code-objects>`, and
|
||||
any user classes which define the :meth:`!__replace__` method.
|
||||
(Contributed by Serhiy Storchaka in :gh:`108751`.)
|
||||
|
||||
dbm
|
||||
---
|
||||
|
||||
|
|
|
@ -1112,6 +1112,8 @@ class date:
|
|||
day = self._day
|
||||
return type(self)(year, month, day)
|
||||
|
||||
__replace__ = replace
|
||||
|
||||
# Comparisons of date objects with other.
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -1637,6 +1639,8 @@ class time:
|
|||
fold = self._fold
|
||||
return type(self)(hour, minute, second, microsecond, tzinfo, fold=fold)
|
||||
|
||||
__replace__ = replace
|
||||
|
||||
# Pickle support.
|
||||
|
||||
def _getstate(self, protocol=3):
|
||||
|
@ -1983,6 +1987,8 @@ class datetime(date):
|
|||
return type(self)(year, month, day, hour, minute, second,
|
||||
microsecond, tzinfo, fold=fold)
|
||||
|
||||
__replace__ = replace
|
||||
|
||||
def _local_timezone(self):
|
||||
if self.tzinfo is None:
|
||||
ts = self._mktime()
|
||||
|
|
|
@ -495,6 +495,7 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
|
|||
'_field_defaults': field_defaults,
|
||||
'__new__': __new__,
|
||||
'_make': _make,
|
||||
'__replace__': _replace,
|
||||
'_replace': _replace,
|
||||
'__repr__': __repr__,
|
||||
'_asdict': _asdict,
|
||||
|
|
13
Lib/copy.py
13
Lib/copy.py
|
@ -290,3 +290,16 @@ def _reconstruct(x, memo, func, args,
|
|||
return y
|
||||
|
||||
del types, weakref
|
||||
|
||||
|
||||
def replace(obj, /, **changes):
|
||||
"""Return a new object replacing specified fields with new values.
|
||||
|
||||
This is especially useful for immutable objects, like named tuples or
|
||||
frozen dataclasses.
|
||||
"""
|
||||
cls = obj.__class__
|
||||
func = getattr(cls, '__replace__', None)
|
||||
if func is None:
|
||||
raise TypeError(f"replace() does not support {cls.__name__} objects")
|
||||
return func(obj, **changes)
|
||||
|
|
|
@ -1073,6 +1073,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
|||
globals,
|
||||
slots,
|
||||
))
|
||||
_set_new_attribute(cls, '__replace__', _replace)
|
||||
|
||||
# Get the fields as a list, and include only real fields. This is
|
||||
# used in all of the following methods.
|
||||
|
@ -1546,12 +1547,14 @@ def replace(obj, /, **changes):
|
|||
c1 = replace(c, x=3)
|
||||
assert c1.x == 3 and c1.y == 2
|
||||
"""
|
||||
|
||||
# We're going to mutate 'changes', but that's okay because it's a
|
||||
# new dict, even if called with 'replace(obj, **my_changes)'.
|
||||
|
||||
if not _is_dataclass_instance(obj):
|
||||
raise TypeError("replace() should be called on dataclass instances")
|
||||
return _replace(obj, **changes)
|
||||
|
||||
|
||||
def _replace(obj, /, **changes):
|
||||
# We're going to mutate 'changes', but that's okay because it's a
|
||||
# new dict, even if called with 'replace(obj, **my_changes)'.
|
||||
|
||||
# It's an error to have init=False fields in 'changes'.
|
||||
# If a field is not in 'changes', read its value from the provided obj.
|
||||
|
|
|
@ -2870,6 +2870,8 @@ class Parameter:
|
|||
|
||||
return formatted
|
||||
|
||||
__replace__ = replace
|
||||
|
||||
def __repr__(self):
|
||||
return '<{} "{}">'.format(self.__class__.__name__, self)
|
||||
|
||||
|
@ -3130,6 +3132,8 @@ class Signature:
|
|||
return type(self)(parameters,
|
||||
return_annotation=return_annotation)
|
||||
|
||||
__replace__ = replace
|
||||
|
||||
def _hash_basis(self):
|
||||
params = tuple(param for param in self.parameters.values()
|
||||
if param.kind != _KEYWORD_ONLY)
|
||||
|
|
|
@ -1699,22 +1699,23 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
|
|||
cls = self.theclass
|
||||
args = [1, 2, 3]
|
||||
base = cls(*args)
|
||||
self.assertEqual(base, base.replace())
|
||||
self.assertEqual(base.replace(), base)
|
||||
self.assertEqual(copy.replace(base), base)
|
||||
|
||||
i = 0
|
||||
for name, newval in (("year", 2),
|
||||
("month", 3),
|
||||
("day", 4)):
|
||||
changes = (("year", 2),
|
||||
("month", 3),
|
||||
("day", 4))
|
||||
for i, (name, newval) in enumerate(changes):
|
||||
newargs = args[:]
|
||||
newargs[i] = newval
|
||||
expected = cls(*newargs)
|
||||
got = base.replace(**{name: newval})
|
||||
self.assertEqual(expected, got)
|
||||
i += 1
|
||||
self.assertEqual(base.replace(**{name: newval}), expected)
|
||||
self.assertEqual(copy.replace(base, **{name: newval}), expected)
|
||||
|
||||
# Out of bounds.
|
||||
base = cls(2000, 2, 29)
|
||||
self.assertRaises(ValueError, base.replace, year=2001)
|
||||
self.assertRaises(ValueError, copy.replace, base, year=2001)
|
||||
|
||||
def test_subclass_replace(self):
|
||||
class DateSubclass(self.theclass):
|
||||
|
@ -1722,6 +1723,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
|
|||
|
||||
dt = DateSubclass(2012, 1, 1)
|
||||
self.assertIs(type(dt.replace(year=2013)), DateSubclass)
|
||||
self.assertIs(type(copy.replace(dt, year=2013)), DateSubclass)
|
||||
|
||||
def test_subclass_date(self):
|
||||
|
||||
|
@ -2856,26 +2858,27 @@ class TestDateTime(TestDate):
|
|||
cls = self.theclass
|
||||
args = [1, 2, 3, 4, 5, 6, 7]
|
||||
base = cls(*args)
|
||||
self.assertEqual(base, base.replace())
|
||||
self.assertEqual(base.replace(), base)
|
||||
self.assertEqual(copy.replace(base), base)
|
||||
|
||||
i = 0
|
||||
for name, newval in (("year", 2),
|
||||
("month", 3),
|
||||
("day", 4),
|
||||
("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8)):
|
||||
changes = (("year", 2),
|
||||
("month", 3),
|
||||
("day", 4),
|
||||
("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8))
|
||||
for i, (name, newval) in enumerate(changes):
|
||||
newargs = args[:]
|
||||
newargs[i] = newval
|
||||
expected = cls(*newargs)
|
||||
got = base.replace(**{name: newval})
|
||||
self.assertEqual(expected, got)
|
||||
i += 1
|
||||
self.assertEqual(base.replace(**{name: newval}), expected)
|
||||
self.assertEqual(copy.replace(base, **{name: newval}), expected)
|
||||
|
||||
# Out of bounds.
|
||||
base = cls(2000, 2, 29)
|
||||
self.assertRaises(ValueError, base.replace, year=2001)
|
||||
self.assertRaises(ValueError, copy.replace, base, year=2001)
|
||||
|
||||
@support.run_with_tz('EDT4')
|
||||
def test_astimezone(self):
|
||||
|
@ -3671,19 +3674,19 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
|
|||
cls = self.theclass
|
||||
args = [1, 2, 3, 4]
|
||||
base = cls(*args)
|
||||
self.assertEqual(base, base.replace())
|
||||
self.assertEqual(base.replace(), base)
|
||||
self.assertEqual(copy.replace(base), base)
|
||||
|
||||
i = 0
|
||||
for name, newval in (("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8)):
|
||||
changes = (("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8))
|
||||
for i, (name, newval) in enumerate(changes):
|
||||
newargs = args[:]
|
||||
newargs[i] = newval
|
||||
expected = cls(*newargs)
|
||||
got = base.replace(**{name: newval})
|
||||
self.assertEqual(expected, got)
|
||||
i += 1
|
||||
self.assertEqual(base.replace(**{name: newval}), expected)
|
||||
self.assertEqual(copy.replace(base, **{name: newval}), expected)
|
||||
|
||||
# Out of bounds.
|
||||
base = cls(1)
|
||||
|
@ -3691,6 +3694,10 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
|
|||
self.assertRaises(ValueError, base.replace, minute=-1)
|
||||
self.assertRaises(ValueError, base.replace, second=100)
|
||||
self.assertRaises(ValueError, base.replace, microsecond=1000000)
|
||||
self.assertRaises(ValueError, copy.replace, base, hour=24)
|
||||
self.assertRaises(ValueError, copy.replace, base, minute=-1)
|
||||
self.assertRaises(ValueError, copy.replace, base, second=100)
|
||||
self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
|
||||
|
||||
def test_subclass_replace(self):
|
||||
class TimeSubclass(self.theclass):
|
||||
|
@ -3698,6 +3705,7 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
|
|||
|
||||
ctime = TimeSubclass(12, 30)
|
||||
self.assertIs(type(ctime.replace(hour=10)), TimeSubclass)
|
||||
self.assertIs(type(copy.replace(ctime, hour=10)), TimeSubclass)
|
||||
|
||||
def test_subclass_time(self):
|
||||
|
||||
|
@ -4085,31 +4093,37 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
|
|||
zm200 = FixedOffset(timedelta(minutes=-200), "-200")
|
||||
args = [1, 2, 3, 4, z100]
|
||||
base = cls(*args)
|
||||
self.assertEqual(base, base.replace())
|
||||
self.assertEqual(base.replace(), base)
|
||||
self.assertEqual(copy.replace(base), base)
|
||||
|
||||
i = 0
|
||||
for name, newval in (("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8),
|
||||
("tzinfo", zm200)):
|
||||
changes = (("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8),
|
||||
("tzinfo", zm200))
|
||||
for i, (name, newval) in enumerate(changes):
|
||||
newargs = args[:]
|
||||
newargs[i] = newval
|
||||
expected = cls(*newargs)
|
||||
got = base.replace(**{name: newval})
|
||||
self.assertEqual(expected, got)
|
||||
i += 1
|
||||
self.assertEqual(base.replace(**{name: newval}), expected)
|
||||
self.assertEqual(copy.replace(base, **{name: newval}), expected)
|
||||
|
||||
# Ensure we can get rid of a tzinfo.
|
||||
self.assertEqual(base.tzname(), "+100")
|
||||
base2 = base.replace(tzinfo=None)
|
||||
self.assertIsNone(base2.tzinfo)
|
||||
self.assertIsNone(base2.tzname())
|
||||
base22 = copy.replace(base, tzinfo=None)
|
||||
self.assertIsNone(base22.tzinfo)
|
||||
self.assertIsNone(base22.tzname())
|
||||
|
||||
# Ensure we can add one.
|
||||
base3 = base2.replace(tzinfo=z100)
|
||||
self.assertEqual(base, base3)
|
||||
self.assertIs(base.tzinfo, base3.tzinfo)
|
||||
base32 = copy.replace(base22, tzinfo=z100)
|
||||
self.assertEqual(base, base32)
|
||||
self.assertIs(base.tzinfo, base32.tzinfo)
|
||||
|
||||
# Out of bounds.
|
||||
base = cls(1)
|
||||
|
@ -4117,6 +4131,10 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
|
|||
self.assertRaises(ValueError, base.replace, minute=-1)
|
||||
self.assertRaises(ValueError, base.replace, second=100)
|
||||
self.assertRaises(ValueError, base.replace, microsecond=1000000)
|
||||
self.assertRaises(ValueError, copy.replace, base, hour=24)
|
||||
self.assertRaises(ValueError, copy.replace, base, minute=-1)
|
||||
self.assertRaises(ValueError, copy.replace, base, second=100)
|
||||
self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
|
||||
|
||||
def test_mixed_compare(self):
|
||||
t1 = self.theclass(1, 2, 3)
|
||||
|
@ -4885,38 +4903,45 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
|
|||
zm200 = FixedOffset(timedelta(minutes=-200), "-200")
|
||||
args = [1, 2, 3, 4, 5, 6, 7, z100]
|
||||
base = cls(*args)
|
||||
self.assertEqual(base, base.replace())
|
||||
self.assertEqual(base.replace(), base)
|
||||
self.assertEqual(copy.replace(base), base)
|
||||
|
||||
i = 0
|
||||
for name, newval in (("year", 2),
|
||||
("month", 3),
|
||||
("day", 4),
|
||||
("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8),
|
||||
("tzinfo", zm200)):
|
||||
changes = (("year", 2),
|
||||
("month", 3),
|
||||
("day", 4),
|
||||
("hour", 5),
|
||||
("minute", 6),
|
||||
("second", 7),
|
||||
("microsecond", 8),
|
||||
("tzinfo", zm200))
|
||||
for i, (name, newval) in enumerate(changes):
|
||||
newargs = args[:]
|
||||
newargs[i] = newval
|
||||
expected = cls(*newargs)
|
||||
got = base.replace(**{name: newval})
|
||||
self.assertEqual(expected, got)
|
||||
i += 1
|
||||
self.assertEqual(base.replace(**{name: newval}), expected)
|
||||
self.assertEqual(copy.replace(base, **{name: newval}), expected)
|
||||
|
||||
# Ensure we can get rid of a tzinfo.
|
||||
self.assertEqual(base.tzname(), "+100")
|
||||
base2 = base.replace(tzinfo=None)
|
||||
self.assertIsNone(base2.tzinfo)
|
||||
self.assertIsNone(base2.tzname())
|
||||
base22 = copy.replace(base, tzinfo=None)
|
||||
self.assertIsNone(base22.tzinfo)
|
||||
self.assertIsNone(base22.tzname())
|
||||
|
||||
# Ensure we can add one.
|
||||
base3 = base2.replace(tzinfo=z100)
|
||||
self.assertEqual(base, base3)
|
||||
self.assertIs(base.tzinfo, base3.tzinfo)
|
||||
base32 = copy.replace(base22, tzinfo=z100)
|
||||
self.assertEqual(base, base32)
|
||||
self.assertIs(base.tzinfo, base32.tzinfo)
|
||||
|
||||
# Out of bounds.
|
||||
base = cls(2000, 2, 29)
|
||||
self.assertRaises(ValueError, base.replace, year=2001)
|
||||
self.assertRaises(ValueError, copy.replace, base, year=2001)
|
||||
|
||||
def test_more_astimezone(self):
|
||||
# The inherited test_astimezone covered some trivial and error cases.
|
||||
|
|
|
@ -125,6 +125,7 @@ consts: ('None',)
|
|||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import sys
|
||||
import threading
|
||||
|
@ -280,11 +281,17 @@ class CodeTest(unittest.TestCase):
|
|||
with self.subTest(attr=attr, value=value):
|
||||
new_code = code.replace(**{attr: value})
|
||||
self.assertEqual(getattr(new_code, attr), value)
|
||||
new_code = copy.replace(code, **{attr: value})
|
||||
self.assertEqual(getattr(new_code, attr), value)
|
||||
|
||||
new_code = code.replace(co_varnames=code2.co_varnames,
|
||||
co_nlocals=code2.co_nlocals)
|
||||
self.assertEqual(new_code.co_varnames, code2.co_varnames)
|
||||
self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
|
||||
new_code = copy.replace(code, co_varnames=code2.co_varnames,
|
||||
co_nlocals=code2.co_nlocals)
|
||||
self.assertEqual(new_code.co_varnames, code2.co_varnames)
|
||||
self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
|
||||
|
||||
def test_nlocals_mismatch(self):
|
||||
def func():
|
||||
|
|
|
@ -4,7 +4,7 @@ import copy
|
|||
import copyreg
|
||||
import weakref
|
||||
import abc
|
||||
from operator import le, lt, ge, gt, eq, ne
|
||||
from operator import le, lt, ge, gt, eq, ne, attrgetter
|
||||
|
||||
import unittest
|
||||
from test import support
|
||||
|
@ -899,7 +899,71 @@ class TestCopy(unittest.TestCase):
|
|||
g.b()
|
||||
|
||||
|
||||
class TestReplace(unittest.TestCase):
|
||||
|
||||
def test_unsupported(self):
|
||||
self.assertRaises(TypeError, copy.replace, 1)
|
||||
self.assertRaises(TypeError, copy.replace, [])
|
||||
self.assertRaises(TypeError, copy.replace, {})
|
||||
def f(): pass
|
||||
self.assertRaises(TypeError, copy.replace, f)
|
||||
class A: pass
|
||||
self.assertRaises(TypeError, copy.replace, A)
|
||||
self.assertRaises(TypeError, copy.replace, A())
|
||||
|
||||
def test_replace_method(self):
|
||||
class A:
|
||||
def __new__(cls, x, y=0):
|
||||
self = object.__new__(cls)
|
||||
self.x = x
|
||||
self.y = y
|
||||
return self
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.z = self.x + self.y
|
||||
|
||||
def __replace__(self, **changes):
|
||||
x = changes.get('x', self.x)
|
||||
y = changes.get('y', self.y)
|
||||
return type(self)(x, y)
|
||||
|
||||
attrs = attrgetter('x', 'y', 'z')
|
||||
a = A(11, 22)
|
||||
self.assertEqual(attrs(copy.replace(a)), (11, 22, 33))
|
||||
self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23))
|
||||
self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13))
|
||||
self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3))
|
||||
|
||||
def test_namedtuple(self):
|
||||
from collections import namedtuple
|
||||
Point = namedtuple('Point', 'x y', defaults=(0,))
|
||||
p = Point(11, 22)
|
||||
self.assertEqual(copy.replace(p), (11, 22))
|
||||
self.assertEqual(copy.replace(p, x=1), (1, 22))
|
||||
self.assertEqual(copy.replace(p, y=2), (11, 2))
|
||||
self.assertEqual(copy.replace(p, x=1, y=2), (1, 2))
|
||||
with self.assertRaisesRegex(ValueError, 'unexpected field name'):
|
||||
copy.replace(p, x=1, error=2)
|
||||
|
||||
def test_dataclass(self):
|
||||
from dataclasses import dataclass
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
y: int = 0
|
||||
|
||||
attrs = attrgetter('x', 'y')
|
||||
c = C(11, 22)
|
||||
self.assertEqual(attrs(copy.replace(c)), (11, 22))
|
||||
self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22))
|
||||
self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2))
|
||||
self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2))
|
||||
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
|
||||
copy.replace(c, x=1, error=2)
|
||||
|
||||
|
||||
def global_foo(x, y): return x+y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import builtins
|
||||
import collections
|
||||
import copy
|
||||
import datetime
|
||||
import functools
|
||||
import importlib
|
||||
|
@ -3830,6 +3831,28 @@ class TestSignatureObject(unittest.TestCase):
|
|||
P('bar', P.VAR_POSITIONAL)])),
|
||||
'(foo, /, *bar)')
|
||||
|
||||
def test_signature_replace_parameters(self):
|
||||
def test(a, b) -> 42:
|
||||
pass
|
||||
|
||||
sig = inspect.signature(test)
|
||||
parameters = sig.parameters
|
||||
sig = sig.replace(parameters=list(parameters.values())[1:])
|
||||
self.assertEqual(list(sig.parameters), ['b'])
|
||||
self.assertEqual(sig.parameters['b'], parameters['b'])
|
||||
self.assertEqual(sig.return_annotation, 42)
|
||||
sig = sig.replace(parameters=())
|
||||
self.assertEqual(dict(sig.parameters), {})
|
||||
|
||||
sig = inspect.signature(test)
|
||||
parameters = sig.parameters
|
||||
sig = copy.replace(sig, parameters=list(parameters.values())[1:])
|
||||
self.assertEqual(list(sig.parameters), ['b'])
|
||||
self.assertEqual(sig.parameters['b'], parameters['b'])
|
||||
self.assertEqual(sig.return_annotation, 42)
|
||||
sig = copy.replace(sig, parameters=())
|
||||
self.assertEqual(dict(sig.parameters), {})
|
||||
|
||||
def test_signature_replace_anno(self):
|
||||
def test() -> 42:
|
||||
pass
|
||||
|
@ -3843,6 +3866,15 @@ class TestSignatureObject(unittest.TestCase):
|
|||
self.assertEqual(sig.return_annotation, 42)
|
||||
self.assertEqual(sig, inspect.signature(test))
|
||||
|
||||
sig = inspect.signature(test)
|
||||
sig = copy.replace(sig, return_annotation=None)
|
||||
self.assertIs(sig.return_annotation, None)
|
||||
sig = copy.replace(sig, return_annotation=sig.empty)
|
||||
self.assertIs(sig.return_annotation, sig.empty)
|
||||
sig = copy.replace(sig, return_annotation=42)
|
||||
self.assertEqual(sig.return_annotation, 42)
|
||||
self.assertEqual(sig, inspect.signature(test))
|
||||
|
||||
def test_signature_replaced(self):
|
||||
def test():
|
||||
pass
|
||||
|
@ -4187,41 +4219,66 @@ class TestParameterObject(unittest.TestCase):
|
|||
p = inspect.Parameter('foo', default=42,
|
||||
kind=inspect.Parameter.KEYWORD_ONLY)
|
||||
|
||||
self.assertIsNot(p, p.replace())
|
||||
self.assertEqual(p, p.replace())
|
||||
self.assertIsNot(p.replace(), p)
|
||||
self.assertEqual(p.replace(), p)
|
||||
self.assertIsNot(copy.replace(p), p)
|
||||
self.assertEqual(copy.replace(p), p)
|
||||
|
||||
p2 = p.replace(annotation=1)
|
||||
self.assertEqual(p2.annotation, 1)
|
||||
p2 = p2.replace(annotation=p2.empty)
|
||||
self.assertEqual(p, p2)
|
||||
self.assertEqual(p2, p)
|
||||
p3 = copy.replace(p, annotation=1)
|
||||
self.assertEqual(p3.annotation, 1)
|
||||
p3 = copy.replace(p3, annotation=p3.empty)
|
||||
self.assertEqual(p3, p)
|
||||
|
||||
p2 = p2.replace(name='bar')
|
||||
self.assertEqual(p2.name, 'bar')
|
||||
self.assertNotEqual(p2, p)
|
||||
p3 = copy.replace(p3, name='bar')
|
||||
self.assertEqual(p3.name, 'bar')
|
||||
self.assertNotEqual(p3, p)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'name is a required attribute'):
|
||||
p2 = p2.replace(name=p2.empty)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'name is a required attribute'):
|
||||
p3 = copy.replace(p3, name=p3.empty)
|
||||
|
||||
p2 = p2.replace(name='foo', default=None)
|
||||
self.assertIs(p2.default, None)
|
||||
self.assertNotEqual(p2, p)
|
||||
p3 = copy.replace(p3, name='foo', default=None)
|
||||
self.assertIs(p3.default, None)
|
||||
self.assertNotEqual(p3, p)
|
||||
|
||||
p2 = p2.replace(name='foo', default=p2.empty)
|
||||
self.assertIs(p2.default, p2.empty)
|
||||
|
||||
p3 = copy.replace(p3, name='foo', default=p3.empty)
|
||||
self.assertIs(p3.default, p3.empty)
|
||||
|
||||
p2 = p2.replace(default=42, kind=p2.POSITIONAL_OR_KEYWORD)
|
||||
self.assertEqual(p2.kind, p2.POSITIONAL_OR_KEYWORD)
|
||||
self.assertNotEqual(p2, p)
|
||||
p3 = copy.replace(p3, default=42, kind=p3.POSITIONAL_OR_KEYWORD)
|
||||
self.assertEqual(p3.kind, p3.POSITIONAL_OR_KEYWORD)
|
||||
self.assertNotEqual(p3, p)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"value <class 'inspect._empty'> "
|
||||
"is not a valid Parameter.kind"):
|
||||
p2 = p2.replace(kind=p2.empty)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"value <class 'inspect._empty'> "
|
||||
"is not a valid Parameter.kind"):
|
||||
p3 = copy.replace(p3, kind=p3.empty)
|
||||
|
||||
p2 = p2.replace(kind=p2.KEYWORD_ONLY)
|
||||
self.assertEqual(p2, p)
|
||||
p3 = copy.replace(p3, kind=p3.KEYWORD_ONLY)
|
||||
self.assertEqual(p3, p)
|
||||
|
||||
def test_signature_parameter_positional_only(self):
|
||||
with self.assertRaisesRegex(TypeError, 'name must be a str'):
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Add :func:`copy.replace` function which allows to create a modified copy of
|
||||
an object. It supports named tuples, dataclasses, and many other objects.
|
|
@ -3590,6 +3590,8 @@ static PyMethodDef date_methods[] = {
|
|||
{"replace", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS,
|
||||
PyDoc_STR("Return date with new specified fields.")},
|
||||
|
||||
{"__replace__", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS},
|
||||
|
||||
{"__reduce__", (PyCFunction)date_reduce, METH_NOARGS,
|
||||
PyDoc_STR("__reduce__() -> (cls, state)")},
|
||||
|
||||
|
@ -4719,6 +4721,8 @@ static PyMethodDef time_methods[] = {
|
|||
{"replace", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS,
|
||||
PyDoc_STR("Return time with new specified fields.")},
|
||||
|
||||
{"__replace__", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS},
|
||||
|
||||
{"fromisoformat", (PyCFunction)time_fromisoformat, METH_O | METH_CLASS,
|
||||
PyDoc_STR("string -> time from a string in ISO 8601 format")},
|
||||
|
||||
|
@ -6579,6 +6583,8 @@ static PyMethodDef datetime_methods[] = {
|
|||
{"replace", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS,
|
||||
PyDoc_STR("Return datetime with new specified fields.")},
|
||||
|
||||
{"__replace__", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS},
|
||||
|
||||
{"astimezone", _PyCFunction_CAST(datetime_astimezone), METH_VARARGS | METH_KEYWORDS,
|
||||
PyDoc_STR("tz -> convert to local time in new timezone tz\n")},
|
||||
|
||||
|
|
|
@ -2145,6 +2145,7 @@ static struct PyMethodDef code_methods[] = {
|
|||
{"co_positions", (PyCFunction)code_positionsiterator, METH_NOARGS},
|
||||
CODE_REPLACE_METHODDEF
|
||||
CODE__VARNAME_FROM_OPARG_METHODDEF
|
||||
{"__replace__", _PyCFunction_CAST(code_replace), METH_FASTCALL|METH_KEYWORDS},
|
||||
{NULL, NULL} /* sentinel */
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue