From a78dce4b33ee094141a966002bd3bac1e32c5262 Mon Sep 17 00:00:00 2001 From: ShaharNaveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Tue, 8 Jul 2025 14:34:41 +0300 Subject: [PATCH] Update copy from 3.13.5 --- Lib/copy.py | 30 +++++++++++----- Lib/test/test_copy.py | 84 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 105 insertions(+), 9 deletions(-) diff --git a/Lib/copy.py b/Lib/copy.py index da2908ef62..2a4606246a 100644 --- a/Lib/copy.py +++ b/Lib/copy.py @@ -4,8 +4,9 @@ import copy - x = copy.copy(y) # make a shallow copy of y - x = copy.deepcopy(y) # make a deep copy of y + x = copy.copy(y) # make a shallow copy of y + x = copy.deepcopy(y) # make a deep copy of y + x = copy.replace(y, a=1, b=2) # new object with fields replaced, as defined by `__replace__` For module specific errors, copy.Error is raised. @@ -56,7 +57,7 @@ class Error(Exception): pass error = Error # backward compatibility -__all__ = ["Error", "copy", "deepcopy"] +__all__ = ["Error", "copy", "deepcopy", "replace"] def copy(x): """Shallow copy operation on arbitrary Python objects. @@ -121,13 +122,13 @@ def deepcopy(x, memo=None, _nil=[]): See the module's __doc__ string for more info. """ + d = id(x) if memo is None: memo = {} - - d = id(x) - y = memo.get(d, _nil) - if y is not _nil: - return y + else: + y = memo.get(d, _nil) + if y is not _nil: + return y cls = type(x) @@ -290,3 +291,16 @@ def _reconstruct(x, memo, func, args, return y del types, weakref + + +def replace(obj, /, **changes): + """Return a new object replacing specified fields with new values. + + This is especially useful for immutable objects, like named tuples or + frozen dataclasses. + """ + cls = obj.__class__ + func = getattr(cls, '__replace__', None) + if func is None: + raise TypeError(f"replace() does not support {cls.__name__} objects") + return func(obj, **changes) diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index cf3dc57930..2f9d8ed9b6 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -4,7 +4,7 @@ import copyreg import weakref import abc -from operator import le, lt, ge, gt, eq, ne +from operator import le, lt, ge, gt, eq, ne, attrgetter import unittest from test import support @@ -903,7 +903,89 @@ def m(self): g.b() +class TestReplace(unittest.TestCase): + + def test_unsupported(self): + self.assertRaises(TypeError, copy.replace, 1) + self.assertRaises(TypeError, copy.replace, []) + self.assertRaises(TypeError, copy.replace, {}) + def f(): pass + self.assertRaises(TypeError, copy.replace, f) + class A: pass + self.assertRaises(TypeError, copy.replace, A) + self.assertRaises(TypeError, copy.replace, A()) + + def test_replace_method(self): + class A: + def __new__(cls, x, y=0): + self = object.__new__(cls) + self.x = x + self.y = y + return self + + def __init__(self, *args, **kwargs): + self.z = self.x + self.y + + def __replace__(self, **changes): + x = changes.get('x', self.x) + y = changes.get('y', self.y) + return type(self)(x, y) + + attrs = attrgetter('x', 'y', 'z') + a = A(11, 22) + self.assertEqual(attrs(copy.replace(a)), (11, 22, 33)) + self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23)) + self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13)) + self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3)) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_namedtuple(self): + from collections import namedtuple + from typing import NamedTuple + PointFromCall = namedtuple('Point', 'x y', defaults=(0,)) + class PointFromInheritance(PointFromCall): + pass + class PointFromClass(NamedTuple): + x: int + y: int = 0 + for Point in (PointFromCall, PointFromInheritance, PointFromClass): + with self.subTest(Point=Point): + p = Point(11, 22) + self.assertIsInstance(p, Point) + self.assertEqual(copy.replace(p), (11, 22)) + self.assertIsInstance(copy.replace(p), Point) + self.assertEqual(copy.replace(p, x=1), (1, 22)) + self.assertEqual(copy.replace(p, y=2), (11, 2)) + self.assertEqual(copy.replace(p, x=1, y=2), (1, 2)) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(p, x=1, error=2) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_dataclass(self): + from dataclasses import dataclass + @dataclass + class C: + x: int + y: int = 0 + + attrs = attrgetter('x', 'y') + c = C(11, 22) + self.assertEqual(attrs(copy.replace(c)), (11, 22)) + self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22)) + self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2)) + self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2)) + with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): + copy.replace(c, x=1, error=2) + + +class MiscTestCase(unittest.TestCase): + def test__all__(self): + support.check__all__(self, copy, not_exported={"dispatch_table", "error"}) + def global_foo(x, y): return x+y + if __name__ == "__main__": unittest.main()