diff --git a/Lib/pickle.py b/Lib/pickle.py index 6e3c61fd0b..550f8675f2 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -314,16 +314,17 @@ def load_frame(self, frame_size): # Tools used for pickling. def _getattribute(obj, name): + top = obj for subpath in name.split('.'): if subpath == '': raise AttributeError("Can't get local attribute {!r} on {!r}" - .format(name, obj)) + .format(name, top)) try: parent = obj obj = getattr(obj, subpath) except AttributeError: raise AttributeError("Can't get attribute {!r} on {!r}" - .format(name, obj)) from None + .format(name, top)) from None return obj, parent def whichmodule(obj, name): @@ -396,6 +397,8 @@ def decode_long(data): return int.from_bytes(data, byteorder='little', signed=True) +_NoValue = object() + # Pickling machinery class _Pickler: @@ -530,10 +533,11 @@ def save(self, obj, save_persistent_id=True): self.framer.commit_frame() # Check for persistent id (defined by a subclass) - pid = self.persistent_id(obj) - if pid is not None and save_persistent_id: - self.save_pers(pid) - return + if save_persistent_id: + pid = self.persistent_id(obj) + if pid is not None: + self.save_pers(pid) + return # Check the memo x = self.memo.get(id(obj)) @@ -542,8 +546,8 @@ def save(self, obj, save_persistent_id=True): return rv = NotImplemented - reduce = getattr(self, "reducer_override", None) - if reduce is not None: + reduce = getattr(self, "reducer_override", _NoValue) + if reduce is not _NoValue: rv = reduce(obj) if rv is NotImplemented: @@ -556,8 +560,8 @@ def save(self, obj, save_persistent_id=True): # Check private dispatch table if any, or else # copyreg.dispatch_table - reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) - if reduce is not None: + reduce = getattr(self, 'dispatch_table', dispatch_table).get(t, _NoValue) + if reduce is not _NoValue: rv = reduce(obj) else: # Check for a class with a custom metaclass; treat as regular @@ -567,12 +571,12 @@ def save(self, obj, save_persistent_id=True): return # Check for a __reduce_ex__ method, fall back to __reduce__ - reduce = getattr(obj, "__reduce_ex__", None) - if reduce is not None: + reduce = getattr(obj, "__reduce_ex__", _NoValue) + if reduce is not _NoValue: rv = reduce(self.proto) else: - reduce = getattr(obj, "__reduce__", None) - if reduce is not None: + reduce = getattr(obj, "__reduce__", _NoValue) + if reduce is not _NoValue: rv = reduce() else: raise PicklingError("Can't pickle %r object: %r" % @@ -780,14 +784,10 @@ def save_float(self, obj): self.write(FLOAT + repr(obj).encode("ascii") + b'\n') dispatch[float] = save_float - def save_bytes(self, obj): - if self.proto < 3: - if not obj: # bytes object is empty - self.save_reduce(bytes, (), obj=obj) - else: - self.save_reduce(codecs.encode, - (str(obj, 'latin1'), 'latin1'), obj=obj) - return + def _save_bytes_no_memo(self, obj): + # helper for writing bytes objects for protocol >= 3 + # without memoizing them + assert self.proto >= 3 n = len(obj) if n <= 0xff: self.write(SHORT_BINBYTES + pack("= 5 + # without memoizing them + assert self.proto >= 5 + n = len(obj) + if n >= self.framer._FRAME_SIZE_TARGET: + self._write_large_bytes(BYTEARRAY8 + pack("= self.framer._FRAME_SIZE_TARGET: - self._write_large_bytes(BYTEARRAY8 + pack("= 5") with obj.raw() as m: if not m.contiguous: @@ -830,10 +846,18 @@ def save_picklebuffer(self, obj): if in_band: # Write data in-band # XXX The C implementation avoids a copy here + buf = m.tobytes() + in_memo = id(buf) in self.memo if m.readonly: - self.save_bytes(m.tobytes()) + if in_memo: + self._save_bytes_no_memo(buf) + else: + self.save_bytes(buf) else: - self.save_bytearray(m.tobytes()) + if in_memo: + self._save_bytearray_no_memo(buf) + else: + self.save_bytearray(buf) else: # Write data out-of-band self.write(NEXT_BUFFER) @@ -1070,11 +1094,16 @@ def save_global(self, obj, name=None): (obj, module_name, name)) if self.proto >= 2: - code = _extension_registry.get((module_name, name)) - if code: - assert code > 0 + code = _extension_registry.get((module_name, name), _NoValue) + if code is not _NoValue: if code <= 0xff: - write(EXT1 + pack("= 3: - write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + - bytes(name, "utf-8") + b'\n') + elif '.' in name: + # In protocol < 4, objects with multi-part __qualname__ + # are represented as + # getattr(getattr(..., attrname1), attrname2). + dotted_path = name.split('.') + name = dotted_path.pop(0) + save = self.save + for attrname in dotted_path: + save(getattr) + if self.proto < 2: + write(MARK) + self._save_toplevel_by_name(module_name, name) + for attrname in dotted_path: + save(attrname) + if self.proto < 2: + write(TUPLE) + else: + write(TUPLE2) + write(REDUCE) + else: + self._save_toplevel_by_name(module_name, name) + + self.memoize(obj) + + def _save_toplevel_by_name(self, module_name, name): + if self.proto >= 3: + # Non-ASCII identifiers are supported only with protocols >= 3. + self.write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + + bytes(name, "utf-8") + b'\n') else: if self.fix_imports: r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING @@ -1102,14 +1155,12 @@ def save_global(self, obj, name=None): elif module_name in r_import_mapping: module_name = r_import_mapping[module_name] try: - write(GLOBAL + bytes(module_name, "ascii") + b'\n' + - bytes(name, "ascii") + b'\n') + self.write(GLOBAL + bytes(module_name, "ascii") + b'\n' + + bytes(name, "ascii") + b'\n') except UnicodeEncodeError: raise PicklingError( "can't pickle global identifier '%s.%s' using " - "pickle protocol %i" % (module, name, self.proto)) from None - - self.memoize(obj) + "pickle protocol %i" % (module_name, name, self.proto)) from None def save_type(self, obj): if obj is type(None): @@ -1546,9 +1597,8 @@ def load_ext4(self): dispatch[EXT4[0]] = load_ext4 def get_extension(self, code): - nil = [] - obj = _extension_cache.get(code, nil) - if obj is not nil: + obj = _extension_cache.get(code, _NoValue) + if obj is not _NoValue: self.append(obj) return key = _inverted_registry.get(code) @@ -1705,8 +1755,8 @@ def load_build(self): stack = self.stack state = stack.pop() inst = stack[-1] - setstate = getattr(inst, "__setstate__", None) - if setstate is not None: + setstate = getattr(inst, "__setstate__", _NoValue) + if setstate is not _NoValue: setstate(state) return slotstate = None diff --git a/Lib/pickletools.py b/Lib/pickletools.py index 51ee4a7a26..33a51492ea 100644 --- a/Lib/pickletools.py +++ b/Lib/pickletools.py @@ -312,7 +312,7 @@ def read_uint8(f): doc="Eight-byte unsigned integer, little-endian.") -def read_stringnl(f, decode=True, stripquotes=True): +def read_stringnl(f, decode=True, stripquotes=True, *, encoding='latin-1'): r""" >>> import io >>> read_stringnl(io.BytesIO(b"'abcd'\nefg\n")) @@ -356,7 +356,7 @@ def read_stringnl(f, decode=True, stripquotes=True): raise ValueError("no string quotes around %r" % data) if decode: - data = codecs.escape_decode(data)[0].decode("ascii") + data = codecs.escape_decode(data)[0].decode(encoding) return data stringnl = ArgumentDescriptor( @@ -370,7 +370,7 @@ def read_stringnl(f, decode=True, stripquotes=True): """) def read_stringnl_noescape(f): - return read_stringnl(f, stripquotes=False) + return read_stringnl(f, stripquotes=False, encoding='utf-8') stringnl_noescape = ArgumentDescriptor( name='stringnl_noescape', @@ -2513,7 +2513,10 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0): # make a mild effort to align arguments line += ' ' * (10 - len(opcode.name)) if arg is not None: - line += ' ' + repr(arg) + if opcode.name in ("STRING", "BINSTRING", "SHORT_BINSTRING"): + line += ' ' + ascii(arg) + else: + line += ' ' + repr(arg) if markmsg: line += ' ' + markmsg if annotate: diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index a9177ada39..070e277c2a 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -97,11 +97,6 @@ def dumps(self, arg, proto=None, **kwargs): def test_picklebuffer_error(self): # TODO(RUSTPYTHON): Remove this test when it passes return super().test_picklebuffer_error() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_reduce_ex_None(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_reduce_ex_None() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_bad_getattr(self): # TODO(RUSTPYTHON): Remove this test when it passes @@ -190,16 +185,6 @@ def test_oob_buffers_writable_to_readonly(self): # TODO(RUSTPYTHON): Remove this def test_optional_frames(self): # TODO(RUSTPYTHON): Remove this test when it passes return super().test_optional_frames() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_pickle_setstate_None(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_pickle_setstate_None() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_recursive_nested_names2(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_recursive_nested_names2() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_buffers_error(self): # TODO(RUSTPYTHON): Remove this test when it passes @@ -240,16 +225,6 @@ def test_oob_buffers_writable_to_readonly(self): # TODO(RUSTPYTHON): Remove this def test_optional_frames(self): # TODO(RUSTPYTHON): Remove this test when it passes return super().test_optional_frames() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_pickle_setstate_None(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_pickle_setstate_None() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_recursive_nested_names2(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_recursive_nested_names2() - class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests, BigmemPickleTests, unittest.TestCase): @@ -299,11 +274,6 @@ def test_load_python2_str_as_bytes(self): # TODO(RUSTPYTHON): Remove this test w def test_py_methods(self): # TODO(RUSTPYTHON): Remove this test when it passes return super().test_py_methods() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_recursive_nested_names2(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_recursive_nested_names2() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_oob_buffers_writable_to_readonly(self): # TODO(RUSTPYTHON): Remove this test when it passes @@ -344,11 +314,6 @@ def test_oob_buffers(self): # TODO(RUSTPYTHON): Remove this test when it passes def test_optional_frames(self): # TODO(RUSTPYTHON): Remove this test when it passes return super().test_optional_frames() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_pickle_setstate_None(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_pickle_setstate_None() - class PersistentPicklerUnpicklerMixin(object): def dumps(self, arg, proto=None): @@ -465,8 +430,6 @@ def persistent_load(pid): return pid check(PersUnpickler) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pickler_super(self): class PersPickler(self.pickler): def persistent_id(subself, obj): @@ -496,8 +459,6 @@ def persistent_load(subself, pid): self.assertEqual(unpickler.load(), 'abc') self.assertEqual(called, ['abc']) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pickler_instance_attribute(self): def persistent_id(obj): called.append(obj) @@ -532,8 +493,6 @@ def persistent_load(pid): del unpickler.persistent_load self.assertEqual(unpickler.persistent_load, old_persistent_load) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pickler_super_instance_attribute(self): class PersPickler(self.pickler): def persistent_id(subself, obj): @@ -582,11 +541,6 @@ class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittes pickler_class = pickle._Pickler unpickler_class = pickle._Unpickler - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_pickle_invalid_reducer_override(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_pickle_invalid_reducer_override() - class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): @@ -595,11 +549,6 @@ class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): def get_dispatch_table(self): return pickle.dispatch_table.copy() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_dispatch_table_None_item(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_dispatch_table_None_item() - class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): @@ -608,11 +557,6 @@ class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase): def get_dispatch_table(self): return collections.ChainMap({}, pickle.dispatch_table) - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_dispatch_table_None_item(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_dispatch_table_None_item() - class PyPicklerHookTests(AbstractHookTests, unittest.TestCase): class CustomPyPicklerClass(pickle._Pickler, diff --git a/Lib/test/test_pickletools.py b/Lib/test/test_pickletools.py index 2a19976ce2..6c38bef3d3 100644 --- a/Lib/test/test_pickletools.py +++ b/Lib/test/test_pickletools.py @@ -102,16 +102,6 @@ def test_oob_buffers_writable_to_readonly(self): # TODO(RUSTPYTHON): Remove this def test_optional_frames(self): # TODO(RUSTPYTHON): Remove this test when it passes return super().test_optional_frames() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_pickle_setstate_None(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_pickle_setstate_None() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_recursive_nested_names2(self): # TODO(RUSTPYTHON): Remove this test when it passes - return super().test_recursive_nested_names2() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_py_methods(self): # TODO(RUSTPYTHON): Remove this test when it passes @@ -436,8 +426,6 @@ def test_annotate(self): highest protocol among opcodes = 0 ''', annotate=20) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_string(self): self.check_dis(b"S'abc'\n.", '''\ 0: S STRING 'abc' @@ -469,8 +457,6 @@ def test_string_without_quotes(self): self.check_dis_error(b"S\"abc'\n.", '', r"""strinq quote b'"' not found at both ends of b'"abc\\''""") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_binstring(self): self.check_dis(b"T\x03\x00\x00\x00abc.", '''\ 0: T BINSTRING 'abc' @@ -483,8 +469,6 @@ def test_binstring(self): highest protocol among opcodes = 1 ''') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_short_binstring(self): self.check_dis(b"U\x03abc.", '''\ 0: U SHORT_BINSTRING 'abc' @@ -497,8 +481,6 @@ def test_short_binstring(self): highest protocol among opcodes = 1 ''') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_global(self): self.check_dis(b"cmodule\nname\n.", '''\ 0: c GLOBAL 'module name'