@@ -314,16 +314,17 @@ def load_frame(self, frame_size):
314
314
# Tools used for pickling.
315
315
316
316
def _getattribute (obj , name ):
317
+ top = obj
317
318
for subpath in name .split ('.' ):
318
319
if subpath == '<locals>' :
319
320
raise AttributeError ("Can't get local attribute {!r} on {!r}"
320
- .format (name , obj ))
321
+ .format (name , top ))
321
322
try :
322
323
parent = obj
323
324
obj = getattr (obj , subpath )
324
325
except AttributeError :
325
326
raise AttributeError ("Can't get attribute {!r} on {!r}"
326
- .format (name , obj )) from None
327
+ .format (name , top )) from None
327
328
return obj , parent
328
329
329
330
def whichmodule (obj , name ):
@@ -396,6 +397,8 @@ def decode_long(data):
396
397
return int .from_bytes (data , byteorder = 'little' , signed = True )
397
398
398
399
400
+ _NoValue = object ()
401
+
399
402
# Pickling machinery
400
403
401
404
class _Pickler :
@@ -530,10 +533,11 @@ def save(self, obj, save_persistent_id=True):
530
533
self .framer .commit_frame ()
531
534
532
535
# Check for persistent id (defined by a subclass)
533
- pid = self .persistent_id (obj )
534
- if pid is not None and save_persistent_id :
535
- self .save_pers (pid )
536
- return
536
+ if save_persistent_id :
537
+ pid = self .persistent_id (obj )
538
+ if pid is not None :
539
+ self .save_pers (pid )
540
+ return
537
541
538
542
# Check the memo
539
543
x = self .memo .get (id (obj ))
@@ -542,8 +546,8 @@ def save(self, obj, save_persistent_id=True):
542
546
return
543
547
544
548
rv = NotImplemented
545
- reduce = getattr (self , "reducer_override" , None )
546
- if reduce is not None :
549
+ reduce = getattr (self , "reducer_override" , _NoValue )
550
+ if reduce is not _NoValue :
547
551
rv = reduce (obj )
548
552
549
553
if rv is NotImplemented :
@@ -556,8 +560,8 @@ def save(self, obj, save_persistent_id=True):
556
560
557
561
# Check private dispatch table if any, or else
558
562
# copyreg.dispatch_table
559
- reduce = getattr (self , 'dispatch_table' , dispatch_table ).get (t )
560
- if reduce is not None :
563
+ reduce = getattr (self , 'dispatch_table' , dispatch_table ).get (t , _NoValue )
564
+ if reduce is not _NoValue :
561
565
rv = reduce (obj )
562
566
else :
563
567
# Check for a class with a custom metaclass; treat as regular
@@ -567,12 +571,12 @@ def save(self, obj, save_persistent_id=True):
567
571
return
568
572
569
573
# Check for a __reduce_ex__ method, fall back to __reduce__
570
- reduce = getattr (obj , "__reduce_ex__" , None )
571
- if reduce is not None :
574
+ reduce = getattr (obj , "__reduce_ex__" , _NoValue )
575
+ if reduce is not _NoValue :
572
576
rv = reduce (self .proto )
573
577
else :
574
- reduce = getattr (obj , "__reduce__" , None )
575
- if reduce is not None :
578
+ reduce = getattr (obj , "__reduce__" , _NoValue )
579
+ if reduce is not _NoValue :
576
580
rv = reduce ()
577
581
else :
578
582
raise PicklingError ("Can't pickle %r object: %r" %
@@ -780,14 +784,10 @@ def save_float(self, obj):
780
784
self .write (FLOAT + repr (obj ).encode ("ascii" ) + b'\n ' )
781
785
dispatch [float ] = save_float
782
786
783
- def save_bytes (self , obj ):
784
- if self .proto < 3 :
785
- if not obj : # bytes object is empty
786
- self .save_reduce (bytes , (), obj = obj )
787
- else :
788
- self .save_reduce (codecs .encode ,
789
- (str (obj , 'latin1' ), 'latin1' ), obj = obj )
790
- return
787
+ def _save_bytes_no_memo (self , obj ):
788
+ # helper for writing bytes objects for protocol >= 3
789
+ # without memoizing them
790
+ assert self .proto >= 3
791
791
n = len (obj )
792
792
if n <= 0xff :
793
793
self .write (SHORT_BINBYTES + pack ("<B" , n ) + obj )
@@ -797,28 +797,44 @@ def save_bytes(self, obj):
797
797
self ._write_large_bytes (BINBYTES + pack ("<I" , n ), obj )
798
798
else :
799
799
self .write (BINBYTES + pack ("<I" , n ) + obj )
800
+
801
def save_bytes(self, obj):
    # Protocol 3 and later have native bytes opcodes: emit them through
    # the non-memoizing helper, then record the object in the memo.
    if self.proto >= 3:
        self._save_bytes_no_memo(obj)
        self.memoize(obj)
        return

    # Older protocols: express the bytes object as a reduce call.
    if obj:
        # Non-empty: rebuilt by latin1-encoding an equivalent str.
        latin1_text = str(obj, 'latin1')
        self.save_reduce(codecs.encode, (latin1_text, 'latin1'), obj=obj)
    else:
        # bytes object is empty
        self.save_reduce(bytes, (), obj=obj)
801
811
dispatch [bytes ] = save_bytes
802
812
813
def _save_bytearray_no_memo(self, obj):
    # Helper that writes a bytearray object for protocol >= 5
    # without adding it to the memo.
    assert self.proto >= 5
    size = len(obj)
    header = BYTEARRAY8 + pack("<Q", size)
    if size < self.framer._FRAME_SIZE_TARGET:
        # Below the frame size target: header and payload go out in a
        # single regular write.
        self.write(header + obj)
    else:
        # At or above the frame size target: hand header and payload
        # to the large-bytes writer separately.
        self._write_large_bytes(header, obj)
822
+
803
823
def save_bytearray(self, obj):
    """Pickle the bytearray *obj*.

    For protocols below 5 the object is written as a reduce call:
    ``bytearray()`` when *obj* is empty, otherwise
    ``bytearray(bytes(obj))``.  For protocol 5 and later the payload is
    emitted by ``_save_bytearray_no_memo`` and the object is then
    memoized.
    """
    if self.proto < 5:
        if not obj:  # bytearray is empty
            self.save_reduce(bytearray, (), obj=obj)
        else:
            self.save_reduce(bytearray, (bytes(obj),), obj=obj)
        return
    # Delegate the actual write to the shared helper instead of
    # duplicating the BYTEARRAY8 logic inline; memoization stays here.
    self._save_bytearray_no_memo(obj)
    self.memoize(obj)
816
832
dispatch [bytearray ] = save_bytearray
817
833
818
834
if _HAVE_PICKLE_BUFFER :
819
835
def save_picklebuffer (self , obj ):
820
836
if self .proto < 5 :
821
- raise PicklingError ("PickleBuffer can only pickled with "
837
+ raise PicklingError ("PickleBuffer can only be pickled with "
822
838
"protocol >= 5" )
823
839
with obj .raw () as m :
824
840
if not m .contiguous :
@@ -830,10 +846,18 @@ def save_picklebuffer(self, obj):
830
846
if in_band :
831
847
# Write data in-band
832
848
# XXX The C implementation avoids a copy here
849
+ buf = m .tobytes ()
850
+ in_memo = id (buf ) in self .memo
833
851
if m .readonly :
834
- self .save_bytes (m .tobytes ())
852
+ if in_memo :
853
+ self ._save_bytes_no_memo (buf )
854
+ else :
855
+ self .save_bytes (buf )
835
856
else :
836
- self .save_bytearray (m .tobytes ())
857
+ if in_memo :
858
+ self ._save_bytearray_no_memo (buf )
859
+ else :
860
+ self .save_bytearray (buf )
837
861
else :
838
862
# Write data out-of-band
839
863
self .write (NEXT_BUFFER )
@@ -1070,11 +1094,16 @@ def save_global(self, obj, name=None):
1070
1094
(obj , module_name , name ))
1071
1095
1072
1096
if self .proto >= 2 :
1073
- code = _extension_registry .get ((module_name , name ))
1074
- if code :
1075
- assert code > 0
1097
+ code = _extension_registry .get ((module_name , name ), _NoValue )
1098
+ if code is not _NoValue :
1076
1099
if code <= 0xff :
1077
- write (EXT1 + pack ("<B" , code ))
1100
+ data = pack ("<B" , code )
1101
+ if data == b'\0 ' :
1102
+ # Should never happen in normal circumstances,
1103
+ # since the type and the value of the code are
1104
+ # checked in copyreg.add_extension().
1105
+ raise RuntimeError ("extension code 0 is out of range" )
1106
+ write (EXT1 + data )
1078
1107
elif code <= 0xffff :
1079
1108
write (EXT2 + pack ("<H" , code ))
1080
1109
else :
@@ -1088,11 +1117,35 @@ def save_global(self, obj, name=None):
1088
1117
self .save (module_name )
1089
1118
self .save (name )
1090
1119
write (STACK_GLOBAL )
1091
- elif parent is not module :
1092
- self .save_reduce (getattr , (parent , lastname ))
1093
- elif self .proto >= 3 :
1094
- write (GLOBAL + bytes (module_name , "utf-8" ) + b'\n ' +
1095
- bytes (name , "utf-8" ) + b'\n ' )
1120
+ elif '.' in name :
1121
+ # In protocol < 4, objects with multi-part __qualname__
1122
+ # are represented as
1123
+ # getattr(getattr(..., attrname1), attrname2).
1124
+ dotted_path = name .split ('.' )
1125
+ name = dotted_path .pop (0 )
1126
+ save = self .save
1127
+ for attrname in dotted_path :
1128
+ save (getattr )
1129
+ if self .proto < 2 :
1130
+ write (MARK )
1131
+ self ._save_toplevel_by_name (module_name , name )
1132
+ for attrname in dotted_path :
1133
+ save (attrname )
1134
+ if self .proto < 2 :
1135
+ write (TUPLE )
1136
+ else :
1137
+ write (TUPLE2 )
1138
+ write (REDUCE )
1139
+ else :
1140
+ self ._save_toplevel_by_name (module_name , name )
1141
+
1142
+ self .memoize (obj )
1143
+
1144
+ def _save_toplevel_by_name (self , module_name , name ):
1145
+ if self .proto >= 3 :
1146
+ # Non-ASCII identifiers are supported only with protocols >= 3.
1147
+ self .write (GLOBAL + bytes (module_name , "utf-8" ) + b'\n ' +
1148
+ bytes (name , "utf-8" ) + b'\n ' )
1096
1149
else :
1097
1150
if self .fix_imports :
1098
1151
r_name_mapping = _compat_pickle .REVERSE_NAME_MAPPING
@@ -1102,14 +1155,12 @@ def save_global(self, obj, name=None):
1102
1155
elif module_name in r_import_mapping :
1103
1156
module_name = r_import_mapping [module_name ]
1104
1157
try :
1105
- write (GLOBAL + bytes (module_name , "ascii" ) + b'\n ' +
1106
- bytes (name , "ascii" ) + b'\n ' )
1158
+ self . write (GLOBAL + bytes (module_name , "ascii" ) + b'\n ' +
1159
+ bytes (name , "ascii" ) + b'\n ' )
1107
1160
except UnicodeEncodeError :
1108
1161
raise PicklingError (
1109
1162
"can't pickle global identifier '%s.%s' using "
1110
- "pickle protocol %i" % (module , name , self .proto )) from None
1111
-
1112
- self .memoize (obj )
1163
+ "pickle protocol %i" % (module_name , name , self .proto )) from None
1113
1164
1114
1165
def save_type (self , obj ):
1115
1166
if obj is type (None ):
@@ -1546,9 +1597,8 @@ def load_ext4(self):
1546
1597
dispatch [EXT4 [0 ]] = load_ext4
1547
1598
1548
1599
def get_extension (self , code ):
1549
- nil = []
1550
- obj = _extension_cache .get (code , nil )
1551
- if obj is not nil :
1600
+ obj = _extension_cache .get (code , _NoValue )
1601
+ if obj is not _NoValue :
1552
1602
self .append (obj )
1553
1603
return
1554
1604
key = _inverted_registry .get (code )
@@ -1705,8 +1755,8 @@ def load_build(self):
1705
1755
stack = self .stack
1706
1756
state = stack .pop ()
1707
1757
inst = stack [- 1 ]
1708
- setstate = getattr (inst , "__setstate__" , None )
1709
- if setstate is not None :
1758
+ setstate = getattr (inst , "__setstate__" , _NoValue )
1759
+ if setstate is not _NoValue :
1710
1760
setstate (state )
1711
1761
return
1712
1762
slotstate = None
0 commit comments