Skip to content

Commit 72fc3c0

Browse files
authored
Update pickle{tools,}.py from 3.13.5 (#6064)
1 parent 566d9aa commit 72fc3c0

File tree

4 files changed

+106
-127
lines changed

4 files changed

+106
-127
lines changed

Lib/pickle.py

Lines changed: 99 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -314,16 +314,17 @@ def load_frame(self, frame_size):
314314
# Tools used for pickling.
315315

316316
def _getattribute(obj, name):
317+
top = obj
317318
for subpath in name.split('.'):
318319
if subpath == '<locals>':
319320
raise AttributeError("Can't get local attribute {!r} on {!r}"
320-
.format(name, obj))
321+
.format(name, top))
321322
try:
322323
parent = obj
323324
obj = getattr(obj, subpath)
324325
except AttributeError:
325326
raise AttributeError("Can't get attribute {!r} on {!r}"
326-
.format(name, obj)) from None
327+
.format(name, top)) from None
327328
return obj, parent
328329

329330
def whichmodule(obj, name):
@@ -396,6 +397,8 @@ def decode_long(data):
396397
return int.from_bytes(data, byteorder='little', signed=True)
397398

398399

400+
_NoValue = object()
401+
399402
# Pickling machinery
400403

401404
class _Pickler:
@@ -530,10 +533,11 @@ def save(self, obj, save_persistent_id=True):
530533
self.framer.commit_frame()
531534

532535
# Check for persistent id (defined by a subclass)
533-
pid = self.persistent_id(obj)
534-
if pid is not None and save_persistent_id:
535-
self.save_pers(pid)
536-
return
536+
if save_persistent_id:
537+
pid = self.persistent_id(obj)
538+
if pid is not None:
539+
self.save_pers(pid)
540+
return
537541

538542
# Check the memo
539543
x = self.memo.get(id(obj))
@@ -542,8 +546,8 @@ def save(self, obj, save_persistent_id=True):
542546
return
543547

544548
rv = NotImplemented
545-
reduce = getattr(self, "reducer_override", None)
546-
if reduce is not None:
549+
reduce = getattr(self, "reducer_override", _NoValue)
550+
if reduce is not _NoValue:
547551
rv = reduce(obj)
548552

549553
if rv is NotImplemented:
@@ -556,8 +560,8 @@ def save(self, obj, save_persistent_id=True):
556560

557561
# Check private dispatch table if any, or else
558562
# copyreg.dispatch_table
559-
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
560-
if reduce is not None:
563+
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t, _NoValue)
564+
if reduce is not _NoValue:
561565
rv = reduce(obj)
562566
else:
563567
# Check for a class with a custom metaclass; treat as regular
@@ -567,12 +571,12 @@ def save(self, obj, save_persistent_id=True):
567571
return
568572

569573
# Check for a __reduce_ex__ method, fall back to __reduce__
570-
reduce = getattr(obj, "__reduce_ex__", None)
571-
if reduce is not None:
574+
reduce = getattr(obj, "__reduce_ex__", _NoValue)
575+
if reduce is not _NoValue:
572576
rv = reduce(self.proto)
573577
else:
574-
reduce = getattr(obj, "__reduce__", None)
575-
if reduce is not None:
578+
reduce = getattr(obj, "__reduce__", _NoValue)
579+
if reduce is not _NoValue:
576580
rv = reduce()
577581
else:
578582
raise PicklingError("Can't pickle %r object: %r" %
@@ -780,14 +784,10 @@ def save_float(self, obj):
780784
self.write(FLOAT + repr(obj).encode("ascii") + b'\n')
781785
dispatch[float] = save_float
782786

783-
def save_bytes(self, obj):
784-
if self.proto < 3:
785-
if not obj: # bytes object is empty
786-
self.save_reduce(bytes, (), obj=obj)
787-
else:
788-
self.save_reduce(codecs.encode,
789-
(str(obj, 'latin1'), 'latin1'), obj=obj)
790-
return
787+
def _save_bytes_no_memo(self, obj):
788+
# helper for writing bytes objects for protocol >= 3
789+
# without memoizing them
790+
assert self.proto >= 3
791791
n = len(obj)
792792
if n <= 0xff:
793793
self.write(SHORT_BINBYTES + pack("<B", n) + obj)
@@ -797,28 +797,44 @@ def save_bytes(self, obj):
797797
self._write_large_bytes(BINBYTES + pack("<I", n), obj)
798798
else:
799799
self.write(BINBYTES + pack("<I", n) + obj)
800+
801+
def save_bytes(self, obj):
802+
if self.proto < 3:
803+
if not obj: # bytes object is empty
804+
self.save_reduce(bytes, (), obj=obj)
805+
else:
806+
self.save_reduce(codecs.encode,
807+
(str(obj, 'latin1'), 'latin1'), obj=obj)
808+
return
809+
self._save_bytes_no_memo(obj)
800810
self.memoize(obj)
801811
dispatch[bytes] = save_bytes
802812

813+
def _save_bytearray_no_memo(self, obj):
814+
# helper for writing bytearray objects for protocol >= 5
815+
# without memoizing them
816+
assert self.proto >= 5
817+
n = len(obj)
818+
if n >= self.framer._FRAME_SIZE_TARGET:
819+
self._write_large_bytes(BYTEARRAY8 + pack("<Q", n), obj)
820+
else:
821+
self.write(BYTEARRAY8 + pack("<Q", n) + obj)
822+
803823
def save_bytearray(self, obj):
804824
if self.proto < 5:
805825
if not obj: # bytearray is empty
806826
self.save_reduce(bytearray, (), obj=obj)
807827
else:
808828
self.save_reduce(bytearray, (bytes(obj),), obj=obj)
809829
return
810-
n = len(obj)
811-
if n >= self.framer._FRAME_SIZE_TARGET:
812-
self._write_large_bytes(BYTEARRAY8 + pack("<Q", n), obj)
813-
else:
814-
self.write(BYTEARRAY8 + pack("<Q", n) + obj)
830+
self._save_bytearray_no_memo(obj)
815831
self.memoize(obj)
816832
dispatch[bytearray] = save_bytearray
817833

818834
if _HAVE_PICKLE_BUFFER:
819835
def save_picklebuffer(self, obj):
820836
if self.proto < 5:
821-
raise PicklingError("PickleBuffer can only pickled with "
837+
raise PicklingError("PickleBuffer can only be pickled with "
822838
"protocol >= 5")
823839
with obj.raw() as m:
824840
if not m.contiguous:
@@ -830,10 +846,18 @@ def save_picklebuffer(self, obj):
830846
if in_band:
831847
# Write data in-band
832848
# XXX The C implementation avoids a copy here
849+
buf = m.tobytes()
850+
in_memo = id(buf) in self.memo
833851
if m.readonly:
834-
self.save_bytes(m.tobytes())
852+
if in_memo:
853+
self._save_bytes_no_memo(buf)
854+
else:
855+
self.save_bytes(buf)
835856
else:
836-
self.save_bytearray(m.tobytes())
857+
if in_memo:
858+
self._save_bytearray_no_memo(buf)
859+
else:
860+
self.save_bytearray(buf)
837861
else:
838862
# Write data out-of-band
839863
self.write(NEXT_BUFFER)
@@ -1070,11 +1094,16 @@ def save_global(self, obj, name=None):
10701094
(obj, module_name, name))
10711095

10721096
if self.proto >= 2:
1073-
code = _extension_registry.get((module_name, name))
1074-
if code:
1075-
assert code > 0
1097+
code = _extension_registry.get((module_name, name), _NoValue)
1098+
if code is not _NoValue:
10761099
if code <= 0xff:
1077-
write(EXT1 + pack("<B", code))
1100+
data = pack("<B", code)
1101+
if data == b'\0':
1102+
# Should never happen in normal circumstances,
1103+
# since the type and the value of the code are
1104+
# checked in copyreg.add_extension().
1105+
raise RuntimeError("extension code 0 is out of range")
1106+
write(EXT1 + data)
10781107
elif code <= 0xffff:
10791108
write(EXT2 + pack("<H", code))
10801109
else:
@@ -1088,11 +1117,35 @@ def save_global(self, obj, name=None):
10881117
self.save(module_name)
10891118
self.save(name)
10901119
write(STACK_GLOBAL)
1091-
elif parent is not module:
1092-
self.save_reduce(getattr, (parent, lastname))
1093-
elif self.proto >= 3:
1094-
write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
1095-
bytes(name, "utf-8") + b'\n')
1120+
elif '.' in name:
1121+
# In protocol < 4, objects with multi-part __qualname__
1122+
# are represented as
1123+
# getattr(getattr(..., attrname1), attrname2).
1124+
dotted_path = name.split('.')
1125+
name = dotted_path.pop(0)
1126+
save = self.save
1127+
for attrname in dotted_path:
1128+
save(getattr)
1129+
if self.proto < 2:
1130+
write(MARK)
1131+
self._save_toplevel_by_name(module_name, name)
1132+
for attrname in dotted_path:
1133+
save(attrname)
1134+
if self.proto < 2:
1135+
write(TUPLE)
1136+
else:
1137+
write(TUPLE2)
1138+
write(REDUCE)
1139+
else:
1140+
self._save_toplevel_by_name(module_name, name)
1141+
1142+
self.memoize(obj)
1143+
1144+
def _save_toplevel_by_name(self, module_name, name):
1145+
if self.proto >= 3:
1146+
# Non-ASCII identifiers are supported only with protocols >= 3.
1147+
self.write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
1148+
bytes(name, "utf-8") + b'\n')
10961149
else:
10971150
if self.fix_imports:
10981151
r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
@@ -1102,14 +1155,12 @@ def save_global(self, obj, name=None):
11021155
elif module_name in r_import_mapping:
11031156
module_name = r_import_mapping[module_name]
11041157
try:
1105-
write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
1106-
bytes(name, "ascii") + b'\n')
1158+
self.write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
1159+
bytes(name, "ascii") + b'\n')
11071160
except UnicodeEncodeError:
11081161
raise PicklingError(
11091162
"can't pickle global identifier '%s.%s' using "
1110-
"pickle protocol %i" % (module, name, self.proto)) from None
1111-
1112-
self.memoize(obj)
1163+
"pickle protocol %i" % (module_name, name, self.proto)) from None
11131164

11141165
def save_type(self, obj):
11151166
if obj is type(None):
@@ -1546,9 +1597,8 @@ def load_ext4(self):
15461597
dispatch[EXT4[0]] = load_ext4
15471598

15481599
def get_extension(self, code):
1549-
nil = []
1550-
obj = _extension_cache.get(code, nil)
1551-
if obj is not nil:
1600+
obj = _extension_cache.get(code, _NoValue)
1601+
if obj is not _NoValue:
15521602
self.append(obj)
15531603
return
15541604
key = _inverted_registry.get(code)
@@ -1705,8 +1755,8 @@ def load_build(self):
17051755
stack = self.stack
17061756
state = stack.pop()
17071757
inst = stack[-1]
1708-
setstate = getattr(inst, "__setstate__", None)
1709-
if setstate is not None:
1758+
setstate = getattr(inst, "__setstate__", _NoValue)
1759+
if setstate is not _NoValue:
17101760
setstate(state)
17111761
return
17121762
slotstate = None

Lib/pickletools.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def read_uint8(f):
312312
doc="Eight-byte unsigned integer, little-endian.")
313313

314314

315-
def read_stringnl(f, decode=True, stripquotes=True):
315+
def read_stringnl(f, decode=True, stripquotes=True, *, encoding='latin-1'):
316316
r"""
317317
>>> import io
318318
>>> read_stringnl(io.BytesIO(b"'abcd'\nefg\n"))
@@ -356,7 +356,7 @@ def read_stringnl(f, decode=True, stripquotes=True):
356356
raise ValueError("no string quotes around %r" % data)
357357

358358
if decode:
359-
data = codecs.escape_decode(data)[0].decode("ascii")
359+
data = codecs.escape_decode(data)[0].decode(encoding)
360360
return data
361361

362362
stringnl = ArgumentDescriptor(
@@ -370,7 +370,7 @@ def read_stringnl(f, decode=True, stripquotes=True):
370370
""")
371371

372372
def read_stringnl_noescape(f):
373-
return read_stringnl(f, stripquotes=False)
373+
return read_stringnl(f, stripquotes=False, encoding='utf-8')
374374

375375
stringnl_noescape = ArgumentDescriptor(
376376
name='stringnl_noescape',
@@ -2513,7 +2513,10 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0):
25132513
# make a mild effort to align arguments
25142514
line += ' ' * (10 - len(opcode.name))
25152515
if arg is not None:
2516-
line += ' ' + repr(arg)
2516+
if opcode.name in ("STRING", "BINSTRING", "SHORT_BINSTRING"):
2517+
line += ' ' + ascii(arg)
2518+
else:
2519+
line += ' ' + repr(arg)
25172520
if markmsg:
25182521
line += ' ' + markmsg
25192522
if annotate:

0 commit comments

Comments
 (0)