Skip to content

Update pickle{tools,}.py from 3.13.5 #6064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 99 additions & 49 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,17 @@ def load_frame(self, frame_size):
# Tools used for pickling.

def _getattribute(obj, name):
top = obj
for subpath in name.split('.'):
if subpath == '<locals>':
raise AttributeError("Can't get local attribute {!r} on {!r}"
.format(name, obj))
.format(name, top))
try:
parent = obj
obj = getattr(obj, subpath)
except AttributeError:
raise AttributeError("Can't get attribute {!r} on {!r}"
.format(name, obj)) from None
.format(name, top)) from None
return obj, parent

def whichmodule(obj, name):
Expand Down Expand Up @@ -396,6 +397,8 @@ def decode_long(data):
return int.from_bytes(data, byteorder='little', signed=True)


_NoValue = object()

# Pickling machinery

class _Pickler:
Expand Down Expand Up @@ -530,10 +533,11 @@ def save(self, obj, save_persistent_id=True):
self.framer.commit_frame()

# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid is not None and save_persistent_id:
self.save_pers(pid)
return
if save_persistent_id:
pid = self.persistent_id(obj)
if pid is not None:
self.save_pers(pid)
return

# Check the memo
x = self.memo.get(id(obj))
Expand All @@ -542,8 +546,8 @@ def save(self, obj, save_persistent_id=True):
return

rv = NotImplemented
reduce = getattr(self, "reducer_override", None)
if reduce is not None:
reduce = getattr(self, "reducer_override", _NoValue)
if reduce is not _NoValue:
rv = reduce(obj)

if rv is NotImplemented:
Expand All @@ -556,8 +560,8 @@ def save(self, obj, save_persistent_id=True):

# Check private dispatch table if any, or else
# copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
if reduce is not None:
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t, _NoValue)
if reduce is not _NoValue:
rv = reduce(obj)
else:
# Check for a class with a custom metaclass; treat as regular
Expand All @@ -567,12 +571,12 @@ def save(self, obj, save_persistent_id=True):
return

# Check for a __reduce_ex__ method, fall back to __reduce__
reduce = getattr(obj, "__reduce_ex__", None)
if reduce is not None:
reduce = getattr(obj, "__reduce_ex__", _NoValue)
if reduce is not _NoValue:
rv = reduce(self.proto)
else:
reduce = getattr(obj, "__reduce__", None)
if reduce is not None:
reduce = getattr(obj, "__reduce__", _NoValue)
if reduce is not _NoValue:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
Expand Down Expand Up @@ -780,14 +784,10 @@ def save_float(self, obj):
self.write(FLOAT + repr(obj).encode("ascii") + b'\n')
dispatch[float] = save_float

def save_bytes(self, obj):
if self.proto < 3:
if not obj: # bytes object is empty
self.save_reduce(bytes, (), obj=obj)
else:
self.save_reduce(codecs.encode,
(str(obj, 'latin1'), 'latin1'), obj=obj)
return
def _save_bytes_no_memo(self, obj):
# helper for writing bytes objects for protocol >= 3
# without memoizing them
assert self.proto >= 3
n = len(obj)
if n <= 0xff:
self.write(SHORT_BINBYTES + pack("<B", n) + obj)
Expand All @@ -797,28 +797,44 @@ def save_bytes(self, obj):
self._write_large_bytes(BINBYTES + pack("<I", n), obj)
else:
self.write(BINBYTES + pack("<I", n) + obj)

def save_bytes(self, obj):
if self.proto < 3:
if not obj: # bytes object is empty
self.save_reduce(bytes, (), obj=obj)
else:
self.save_reduce(codecs.encode,
(str(obj, 'latin1'), 'latin1'), obj=obj)
return
self._save_bytes_no_memo(obj)
self.memoize(obj)
dispatch[bytes] = save_bytes

def _save_bytearray_no_memo(self, obj):
# helper for writing bytearray objects for protocol >= 5
# without memoizing them
assert self.proto >= 5
n = len(obj)
if n >= self.framer._FRAME_SIZE_TARGET:
self._write_large_bytes(BYTEARRAY8 + pack("<Q", n), obj)
else:
self.write(BYTEARRAY8 + pack("<Q", n) + obj)

def save_bytearray(self, obj):
if self.proto < 5:
if not obj: # bytearray is empty
self.save_reduce(bytearray, (), obj=obj)
else:
self.save_reduce(bytearray, (bytes(obj),), obj=obj)
return
n = len(obj)
if n >= self.framer._FRAME_SIZE_TARGET:
self._write_large_bytes(BYTEARRAY8 + pack("<Q", n), obj)
else:
self.write(BYTEARRAY8 + pack("<Q", n) + obj)
self._save_bytearray_no_memo(obj)
self.memoize(obj)
dispatch[bytearray] = save_bytearray

if _HAVE_PICKLE_BUFFER:
def save_picklebuffer(self, obj):
if self.proto < 5:
raise PicklingError("PickleBuffer can only pickled with "
raise PicklingError("PickleBuffer can only be pickled with "
"protocol >= 5")
with obj.raw() as m:
if not m.contiguous:
Expand All @@ -830,10 +846,18 @@ def save_picklebuffer(self, obj):
if in_band:
# Write data in-band
# XXX The C implementation avoids a copy here
buf = m.tobytes()
in_memo = id(buf) in self.memo
if m.readonly:
self.save_bytes(m.tobytes())
if in_memo:
self._save_bytes_no_memo(buf)
else:
self.save_bytes(buf)
else:
self.save_bytearray(m.tobytes())
if in_memo:
self._save_bytearray_no_memo(buf)
else:
self.save_bytearray(buf)
else:
# Write data out-of-band
self.write(NEXT_BUFFER)
Expand Down Expand Up @@ -1070,11 +1094,16 @@ def save_global(self, obj, name=None):
(obj, module_name, name))

if self.proto >= 2:
code = _extension_registry.get((module_name, name))
if code:
assert code > 0
code = _extension_registry.get((module_name, name), _NoValue)
if code is not _NoValue:
if code <= 0xff:
write(EXT1 + pack("<B", code))
data = pack("<B", code)
if data == b'\0':
# Should never happen in normal circumstances,
# since the type and the value of the code are
# checked in copyreg.add_extension().
raise RuntimeError("extension code 0 is out of range")
write(EXT1 + data)
elif code <= 0xffff:
write(EXT2 + pack("<H", code))
else:
Expand All @@ -1088,11 +1117,35 @@ def save_global(self, obj, name=None):
self.save(module_name)
self.save(name)
write(STACK_GLOBAL)
elif parent is not module:
self.save_reduce(getattr, (parent, lastname))
elif self.proto >= 3:
write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
elif '.' in name:
# In protocol < 4, objects with multi-part __qualname__
# are represented as
# getattr(getattr(..., attrname1), attrname2).
dotted_path = name.split('.')
name = dotted_path.pop(0)
save = self.save
for attrname in dotted_path:
save(getattr)
if self.proto < 2:
write(MARK)
self._save_toplevel_by_name(module_name, name)
for attrname in dotted_path:
save(attrname)
if self.proto < 2:
write(TUPLE)
else:
write(TUPLE2)
write(REDUCE)
else:
self._save_toplevel_by_name(module_name, name)

self.memoize(obj)

def _save_toplevel_by_name(self, module_name, name):
if self.proto >= 3:
# Non-ASCII identifiers are supported only with protocols >= 3.
self.write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else:
if self.fix_imports:
r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
Expand All @@ -1102,14 +1155,12 @@ def save_global(self, obj, name=None):
elif module_name in r_import_mapping:
module_name = r_import_mapping[module_name]
try:
write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
self.write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
"can't pickle global identifier '%s.%s' using "
"pickle protocol %i" % (module, name, self.proto)) from None

self.memoize(obj)
"pickle protocol %i" % (module_name, name, self.proto)) from None

def save_type(self, obj):
if obj is type(None):
Expand Down Expand Up @@ -1546,9 +1597,8 @@ def load_ext4(self):
dispatch[EXT4[0]] = load_ext4

def get_extension(self, code):
nil = []
obj = _extension_cache.get(code, nil)
if obj is not nil:
obj = _extension_cache.get(code, _NoValue)
if obj is not _NoValue:
self.append(obj)
return
key = _inverted_registry.get(code)
Expand Down Expand Up @@ -1705,8 +1755,8 @@ def load_build(self):
stack = self.stack
state = stack.pop()
inst = stack[-1]
setstate = getattr(inst, "__setstate__", None)
if setstate is not None:
setstate = getattr(inst, "__setstate__", _NoValue)
if setstate is not _NoValue:
setstate(state)
return
slotstate = None
Expand Down
11 changes: 7 additions & 4 deletions Lib/pickletools.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def read_uint8(f):
doc="Eight-byte unsigned integer, little-endian.")


def read_stringnl(f, decode=True, stripquotes=True):
def read_stringnl(f, decode=True, stripquotes=True, *, encoding='latin-1'):
r"""
>>> import io
>>> read_stringnl(io.BytesIO(b"'abcd'\nefg\n"))
Expand Down Expand Up @@ -356,7 +356,7 @@ def read_stringnl(f, decode=True, stripquotes=True):
raise ValueError("no string quotes around %r" % data)

if decode:
data = codecs.escape_decode(data)[0].decode("ascii")
data = codecs.escape_decode(data)[0].decode(encoding)
return data

stringnl = ArgumentDescriptor(
Expand All @@ -370,7 +370,7 @@ def read_stringnl(f, decode=True, stripquotes=True):
""")

def read_stringnl_noescape(f):
return read_stringnl(f, stripquotes=False)
return read_stringnl(f, stripquotes=False, encoding='utf-8')

stringnl_noescape = ArgumentDescriptor(
name='stringnl_noescape',
Expand Down Expand Up @@ -2513,7 +2513,10 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0):
# make a mild effort to align arguments
line += ' ' * (10 - len(opcode.name))
if arg is not None:
line += ' ' + repr(arg)
if opcode.name in ("STRING", "BINSTRING", "SHORT_BINSTRING"):
line += ' ' + ascii(arg)
else:
line += ' ' + repr(arg)
if markmsg:
line += ' ' + markmsg
if annotate:
Expand Down
Loading
Loading