@@ -174,37 +174,38 @@ def __init__(self, value):
174
174
175
175
# Pickling machinery
176
176
177
- class Pickler :
177
+ class _Pickler :
178
178
179
179
def __init__ (self , file , protocol = None ):
180
180
"""This takes a binary file for writing a pickle data stream.
181
181
182
182
All protocols now read and write bytes.
183
183
184
184
The optional protocol argument tells the pickler to use the
185
- given protocol; supported protocols are 0, 1, 2. The default
186
- protocol is 2; it's been supported for many years now.
187
-
188
- Protocol 1 is more efficient than protocol 0; protocol 2 is
189
- more efficient than protocol 1.
185
+ given protocol; supported protocols are 0, 1, 2, 3. The default
186
+ protocol is 3; a backward-incompatible protocol designed for
187
+ Python 3.0.
190
188
191
189
Specifying a negative protocol version selects the highest
192
190
protocol version supported. The higher the protocol used, the
193
191
more recent the version of Python needed to read the pickle
194
192
produced.
195
193
196
- The file parameter must have a write() method that accepts a single
197
- string argument. It can thus be an open file object, a StringIO
198
- object, or any other custom object that meets this interface.
199
-
194
+ The file argument must have a write() method that accepts a single
195
+ bytes argument. It can thus be a file object opened for binary
196
+ writing, a io.BytesIO instance, or any other custom object that
197
+ meets this interface.
200
198
"""
201
199
if protocol is None :
202
200
protocol = DEFAULT_PROTOCOL
203
201
if protocol < 0 :
204
202
protocol = HIGHEST_PROTOCOL
205
203
elif not 0 <= protocol <= HIGHEST_PROTOCOL :
206
204
raise ValueError ("pickle protocol must be <= %d" % HIGHEST_PROTOCOL )
207
- self .write = file .write
205
+ try :
206
+ self .write = file .write
207
+ except AttributeError :
208
+ raise TypeError ("file must have a 'write' attribute" )
208
209
self .memo = {}
209
210
self .proto = int (protocol )
210
211
self .bin = protocol >= 1
@@ -270,10 +271,10 @@ def get(self, i, pack=struct.pack):
270
271
271
272
return GET + repr (i ).encode ("ascii" ) + b'\n '
272
273
273
- def save (self , obj ):
274
+ def save (self , obj , save_persistent_id = True ):
274
275
# Check for persistent id (defined by a subclass)
275
276
pid = self .persistent_id (obj )
276
- if pid :
277
+ if pid is not None and save_persistent_id :
277
278
self .save_pers (pid )
278
279
return
279
280
@@ -341,7 +342,7 @@ def persistent_id(self, obj):
341
342
def save_pers (self , pid ):
342
343
# Save a persistent id reference
343
344
if self .bin :
344
- self .save (pid )
345
+ self .save (pid , save_persistent_id = False )
345
346
self .write (BINPERSID )
346
347
else :
347
348
self .write (PERSID + str (pid ).encode ("ascii" ) + b'\n ' )
@@ -350,13 +351,13 @@ def save_reduce(self, func, args, state=None,
350
351
listitems = None , dictitems = None , obj = None ):
351
352
# This API is called by some subclasses
352
353
353
- # Assert that args is a tuple or None
354
+ # Assert that args is a tuple
354
355
if not isinstance (args , tuple ):
355
- raise PicklingError ("args from reduce () should be a tuple" )
356
+ raise PicklingError ("args from save_reduce () should be a tuple" )
356
357
357
358
# Assert that func is callable
358
359
if not hasattr (func , '__call__' ):
359
- raise PicklingError ("func from reduce should be callable" )
360
+ raise PicklingError ("func from save_reduce() should be callable" )
360
361
361
362
save = self .save
362
363
write = self .write
@@ -438,31 +439,6 @@ def save_bool(self, obj):
438
439
self .write (obj and TRUE or FALSE )
439
440
dispatch [bool ] = save_bool
440
441
441
- def save_int (self , obj , pack = struct .pack ):
442
- if self .bin :
443
- # If the int is small enough to fit in a signed 4-byte 2's-comp
444
- # format, we can store it more efficiently than the general
445
- # case.
446
- # First one- and two-byte unsigned ints:
447
- if obj >= 0 :
448
- if obj <= 0xff :
449
- self .write (BININT1 + bytes ([obj ]))
450
- return
451
- if obj <= 0xffff :
452
- self .write (BININT2 + bytes ([obj & 0xff , obj >> 8 ]))
453
- return
454
- # Next check for 4-byte signed ints:
455
- high_bits = obj >> 31 # note that Python shift sign-extends
456
- if high_bits == 0 or high_bits == - 1 :
457
- # All high bits are copies of bit 2**31, so the value
458
- # fits in a 4-byte signed int.
459
- self .write (BININT + pack ("<i" , obj ))
460
- return
461
- # Text pickle, or int too big to fit in signed 4-byte format.
462
- self .write (INT + repr (obj ).encode ("ascii" ) + b'\n ' )
463
- # XXX save_int is merged into save_long
464
- # dispatch[int] = save_int
465
-
466
442
def save_long (self , obj , pack = struct .pack ):
467
443
if self .bin :
468
444
# If the int is small enough to fit in a signed 4-byte 2's-comp
@@ -503,7 +479,7 @@ def save_float(self, obj, pack=struct.pack):
503
479
504
480
def save_bytes (self , obj , pack = struct .pack ):
505
481
if self .proto < 3 :
506
- self .save_reduce (bytes , (list (obj ),))
482
+ self .save_reduce (bytes , (list (obj ),), obj = obj )
507
483
return
508
484
n = len (obj )
509
485
if n < 256 :
@@ -579,12 +555,6 @@ def save_tuple(self, obj):
579
555
580
556
dispatch [tuple ] = save_tuple
581
557
582
- # save_empty_tuple() isn't used by anything in Python 2.3. However, I
583
- # found a Pickler subclass in Zope3 that calls it, so it's not harmless
584
- # to remove it.
585
- def save_empty_tuple (self , obj ):
586
- self .write (EMPTY_TUPLE )
587
-
588
558
def save_list (self , obj ):
589
559
write = self .write
590
560
@@ -696,7 +666,7 @@ def save_global(self, obj, name=None, pack=struct.pack):
696
666
module = whichmodule (obj , name )
697
667
698
668
try :
699
- __import__ (module )
669
+ __import__ (module , level = 0 )
700
670
mod = sys .modules [module ]
701
671
klass = getattr (mod , name )
702
672
except (ImportError , KeyError , AttributeError ):
@@ -720,9 +690,19 @@ def save_global(self, obj, name=None, pack=struct.pack):
720
690
else :
721
691
write (EXT4 + pack ("<i" , code ))
722
692
return
693
+ # Non-ASCII identifiers are supported only with protocols >= 3.
694
+ if self .proto >= 3 :
695
+ write (GLOBAL + bytes (module , "utf-8" ) + b'\n ' +
696
+ bytes (name , "utf-8" ) + b'\n ' )
697
+ else :
698
+ try :
699
+ write (GLOBAL + bytes (module , "ascii" ) + b'\n ' +
700
+ bytes (name , "ascii" ) + b'\n ' )
701
+ except UnicodeEncodeError :
702
+ raise PicklingError (
703
+ "can't pickle global identifier '%s.%s' using "
704
+ "pickle protocol %i" % (module , name , self .proto ))
723
705
724
- write (GLOBAL + bytes (module , "utf-8" ) + b'\n ' +
725
- bytes (name , "utf-8" ) + b'\n ' )
726
706
self .memoize (obj )
727
707
728
708
dispatch [FunctionType ] = save_global
@@ -781,7 +761,7 @@ def whichmodule(func, funcname):
781
761
782
762
# Unpickling machinery
783
763
784
- class Unpickler :
764
+ class _Unpickler :
785
765
786
766
def __init__ (self , file , * , encoding = "ASCII" , errors = "strict" ):
787
767
"""This takes a binary file for reading a pickle data stream.
@@ -841,6 +821,9 @@ def marker(self):
841
821
while stack [k ] is not mark : k = k - 1
842
822
return k
843
823
824
+ def persistent_load (self , pid ):
825
+ raise UnpickingError ("unsupported persistent id encountered" )
826
+
844
827
dispatch = {}
845
828
846
829
def load_proto (self ):
@@ -850,7 +833,7 @@ def load_proto(self):
850
833
dispatch [PROTO [0 ]] = load_proto
851
834
852
835
def load_persid (self ):
853
- pid = self .readline ()[:- 1 ]
836
+ pid = self .readline ()[:- 1 ]. decode ( "ascii" )
854
837
self .append (self .persistent_load (pid ))
855
838
dispatch [PERSID [0 ]] = load_persid
856
839
@@ -879,9 +862,9 @@ def load_int(self):
879
862
val = True
880
863
else :
881
864
try :
882
- val = int (data )
865
+ val = int (data , 0 )
883
866
except ValueError :
884
- val = int (data )
867
+ val = int (data , 0 )
885
868
self .append (val )
886
869
dispatch [INT [0 ]] = load_int
887
870
@@ -933,7 +916,8 @@ def load_string(self):
933
916
break
934
917
else :
935
918
raise ValueError ("insecure string pickle: %r" % orig )
936
- self .append (codecs .escape_decode (rep )[0 ])
919
+ self .append (codecs .escape_decode (rep )[0 ]
920
+ .decode (self .encoding , self .errors ))
937
921
dispatch [STRING [0 ]] = load_string
938
922
939
923
def load_binstring (self ):
@@ -975,7 +959,7 @@ def load_tuple(self):
975
959
dispatch [TUPLE [0 ]] = load_tuple
976
960
977
961
def load_empty_tuple (self ):
978
- self .stack . append (())
962
+ self .append (())
979
963
dispatch [EMPTY_TUPLE [0 ]] = load_empty_tuple
980
964
981
965
def load_tuple1 (self ):
@@ -991,11 +975,11 @@ def load_tuple3(self):
991
975
dispatch [TUPLE3 [0 ]] = load_tuple3
992
976
993
977
def load_empty_list (self ):
994
- self .stack . append ([])
978
+ self .append ([])
995
979
dispatch [EMPTY_LIST [0 ]] = load_empty_list
996
980
997
981
def load_empty_dictionary (self ):
998
- self .stack . append ({})
982
+ self .append ({})
999
983
dispatch [EMPTY_DICT [0 ]] = load_empty_dictionary
1000
984
1001
985
def load_list (self ):
@@ -1022,13 +1006,13 @@ def load_dict(self):
1022
1006
def _instantiate (self , klass , k ):
1023
1007
args = tuple (self .stack [k + 1 :])
1024
1008
del self .stack [k :]
1025
- instantiated = 0
1009
+ instantiated = False
1026
1010
if (not args and
1027
1011
isinstance (klass , type ) and
1028
1012
not hasattr (klass , "__getinitargs__" )):
1029
1013
value = _EmptyClass ()
1030
1014
value .__class__ = klass
1031
- instantiated = 1
1015
+ instantiated = True
1032
1016
if not instantiated :
1033
1017
try :
1034
1018
value = klass (* args )
@@ -1038,8 +1022,8 @@ def _instantiate(self, klass, k):
1038
1022
self .append (value )
1039
1023
1040
1024
def load_inst (self ):
1041
- module = self .readline ()[:- 1 ]
1042
- name = self .readline ()[:- 1 ]
1025
+ module = self .readline ()[:- 1 ]. decode ( "ascii" )
1026
+ name = self .readline ()[:- 1 ]. decode ( "ascii" )
1043
1027
klass = self .find_class (module , name )
1044
1028
self ._instantiate (klass , self .marker ())
1045
1029
dispatch [INST [0 ]] = load_inst
@@ -1059,8 +1043,8 @@ def load_newobj(self):
1059
1043
dispatch [NEWOBJ [0 ]] = load_newobj
1060
1044
1061
1045
def load_global (self ):
1062
- module = self .readline ()[:- 1 ]
1063
- name = self .readline ()[:- 1 ]
1046
+ module = self .readline ()[:- 1 ]. decode ( "utf-8" )
1047
+ name = self .readline ()[:- 1 ]. decode ( "utf-8" )
1064
1048
klass = self .find_class (module , name )
1065
1049
self .append (klass )
1066
1050
dispatch [GLOBAL [0 ]] = load_global
@@ -1095,11 +1079,7 @@ def get_extension(self, code):
1095
1079
1096
1080
def find_class (self , module , name ):
1097
1081
# Subclasses may override this
1098
- if isinstance (module , bytes_types ):
1099
- module = module .decode ("utf-8" )
1100
- if isinstance (name , bytes_types ):
1101
- name = name .decode ("utf-8" )
1102
- __import__ (module )
1082
+ __import__ (module , level = 0 )
1103
1083
mod = sys .modules [module ]
1104
1084
klass = getattr (mod , name )
1105
1085
return klass
@@ -1131,31 +1111,33 @@ def load_dup(self):
1131
1111
dispatch [DUP [0 ]] = load_dup
1132
1112
1133
1113
def load_get (self ):
1134
- self .append (self .memo [self .readline ()[:- 1 ].decode ("ascii" )])
1114
+ i = int (self .readline ()[:- 1 ])
1115
+ self .append (self .memo [i ])
1135
1116
dispatch [GET [0 ]] = load_get
1136
1117
1137
1118
def load_binget (self ):
1138
- i = ord ( self .read (1 ))
1139
- self .append (self .memo [repr ( i ) ])
1119
+ i = self .read (1 )[ 0 ]
1120
+ self .append (self .memo [i ])
1140
1121
dispatch [BINGET [0 ]] = load_binget
1141
1122
1142
1123
def load_long_binget (self ):
1143
1124
i = mloads (b'i' + self .read (4 ))
1144
- self .append (self .memo [repr ( i ) ])
1125
+ self .append (self .memo [i ])
1145
1126
dispatch [LONG_BINGET [0 ]] = load_long_binget
1146
1127
1147
1128
def load_put (self ):
1148
- self .memo [self .readline ()[:- 1 ].decode ("ascii" )] = self .stack [- 1 ]
1129
+ i = int (self .readline ()[:- 1 ])
1130
+ self .memo [i ] = self .stack [- 1 ]
1149
1131
dispatch [PUT [0 ]] = load_put
1150
1132
1151
1133
def load_binput (self ):
1152
- i = ord ( self .read (1 ))
1153
- self .memo [repr ( i ) ] = self .stack [- 1 ]
1134
+ i = self .read (1 )[ 0 ]
1135
+ self .memo [i ] = self .stack [- 1 ]
1154
1136
dispatch [BINPUT [0 ]] = load_binput
1155
1137
1156
1138
def load_long_binput (self ):
1157
1139
i = mloads (b'i' + self .read (4 ))
1158
- self .memo [repr ( i ) ] = self .stack [- 1 ]
1140
+ self .memo [i ] = self .stack [- 1 ]
1159
1141
dispatch [LONG_BINPUT [0 ]] = load_long_binput
1160
1142
1161
1143
def load_append (self ):
@@ -1321,6 +1303,12 @@ def decode_long(data):
1321
1303
n -= 1 << (nbytes * 8 )
1322
1304
return n
1323
1305
1306
+ # Use the faster _pickle if possible
1307
+ try :
1308
+ from _pickle import *
1309
+ except ImportError :
1310
+ Pickler , Unpickler = _Pickler , _Unpickler
1311
+
1324
1312
# Shorthands
1325
1313
1326
1314
def dump (obj , file , protocol = None ):
@@ -1333,14 +1321,14 @@ def dumps(obj, protocol=None):
1333
1321
assert isinstance (res , bytes_types )
1334
1322
return res
1335
1323
1336
- def load (file ):
1337
- return Unpickler (file ).load ()
1324
+ def load (file , * , encoding = "ASCII" , errors = "strict" ):
1325
+ return Unpickler (file , encoding = encoding , errors = errors ).load ()
1338
1326
1339
- def loads (s ):
1327
+ def loads (s , * , encoding = "ASCII" , errors = "strict" ):
1340
1328
if isinstance (s , str ):
1341
1329
raise TypeError ("Can't load pickle from unicode string" )
1342
1330
file = io .BytesIO (s )
1343
- return Unpickler (file ).load ()
1331
+ return Unpickler (file , encoding = encoding , errors = errors ).load ()
1344
1332
1345
1333
# Doctest
1346
1334
0 commit comments