Commit | Line | Data |
---|---|---|
86530b38 AT |
1 | import unittest |
2 | import pickle | |
3 | import cPickle | |
4 | import pickletools | |
5 | import copy_reg | |
6 | ||
7 | from test.test_support import TestFailed, have_unicode, TESTFN | |
8 | ||
# Any test meant to cover every pickle protocol should wrap its body in
#     for proto in protocols:
# Both implementations are required to top out at protocol 2; the tests
# below hard-code expectations for protocols 0, 1 and 2.
assert pickle.HIGHEST_PROTOCOL == 2
assert cPickle.HIGHEST_PROTOCOL == 2
protocols = range(0, pickle.HIGHEST_PROTOCOL + 1)
14 | ||
15 | ||
# Report whether opcode `code` occurs anywhere in the pickle string.
# Scanning stops at the first occurrence.  (The `pickle` parameter is the
# pickled data; it hides the pickle module inside this function only.)
def opcode_in_pickle(code, pickle):
    found = False
    for opinfo, arg, pos in pickletools.genops(pickle):
        if opinfo.code == code:
            found = True
            break
    return found
22 | ||
# Count the occurrences of opcode `code` in the pickle string.  The whole
# pickle is always scanned.
def count_opcode(code, pickle):
    hits = [1 for opinfo, arg, pos in pickletools.genops(pickle)
            if opinfo.code == code]
    return len(hits)
30 | ||
31 | # We can't very well test the extension registry without putting known stuff | |
32 | # in it, but we have to be careful to restore its original state. Code | |
33 | # should do this: | |
34 | # | |
35 | # e = ExtensionSaver(extension_code) | |
36 | # try: | |
37 | # fiddle w/ the extension registry's stuff for extension_code | |
38 | # finally: | |
39 | # e.restore() | |
40 | ||
class ExtensionSaver:
    # Capture the (module, name) pair currently registered under `code`,
    # if any, and unregister it so a test can claim the code for itself.
    def __init__(self, code):
        self.code = code
        self.pair = copy_reg._inverted_registry.get(code)
        if self.pair is not None:
            module, name = self.pair
            copy_reg.remove_extension(module, name, code)

    # Unregister whatever holds the code now, then reinstate the
    # registration captured by __init__ (when there was one).
    def restore(self):
        code = self.code
        registry = copy_reg._inverted_registry
        if code in registry:
            module, name = registry[code]
            copy_reg.remove_extension(module, name, code)
        if self.pair is not None:
            module, name = self.pair
            copy_reg.add_extension(module, name, code)
61 | ||
# Classic class whose instances compare by their attribute dictionaries,
# so an unpickled copy compares equal to the object it was made from.
class C:
    def __cmp__(self, other):
        mine, theirs = self.__dict__, other.__dict__
        return cmp(mine, theirs)

# The canned pickles (DATA0..DATA2) name this class as "__main__ C", so it
# has to be reachable from __main__ and report that module name.
import __main__
setattr(__main__, "C", C)
C.__module__ = "__main__"
69 | ||
# int subclass that also records the string form of its constructor
# argument as an instance attribute.
class myint(int):
    def __init__(self, x):
        text = str(x)
        self.str = text
73 | ||
# Classic-class instance that supplies its constructor arguments through
# __getinitargs__.
class initarg(C):

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __getinitargs__(self):
        return (self.a, self.b)
82 | ||
# A do-nothing custom metaclass ...
class metaclass(type):
    pass

# ... and a new-style class that selects it (used by test_metaclass).
class use_metaclass(object):
    __metaclass__ = metaclass
88 | ||
89 | # DATA0 .. DATA2 are the pickles we expect under the various protocols, for | |
90 | # the object returned by create_data(). | |
91 | ||
92 | # break into multiple strings to avoid confusing font-lock-mode | |
93 | DATA0 = """(lp1 | |
94 | I0 | |
95 | aL1L | |
96 | aF2 | |
97 | ac__builtin__ | |
98 | complex | |
99 | p2 | |
100 | """ + \ | |
101 | """(F3 | |
102 | F0 | |
103 | tRp3 | |
104 | aI1 | |
105 | aI-1 | |
106 | aI255 | |
107 | aI-255 | |
108 | aI-256 | |
109 | aI65535 | |
110 | aI-65535 | |
111 | aI-65536 | |
112 | aI2147483647 | |
113 | aI-2147483647 | |
114 | aI-2147483648 | |
115 | a""" + \ | |
116 | """(S'abc' | |
117 | p4 | |
118 | g4 | |
119 | """ + \ | |
120 | """(i__main__ | |
121 | C | |
122 | p5 | |
123 | """ + \ | |
124 | """(dp6 | |
125 | S'foo' | |
126 | p7 | |
127 | I1 | |
128 | sS'bar' | |
129 | p8 | |
130 | I2 | |
131 | sbg5 | |
132 | tp9 | |
133 | ag9 | |
134 | aI5 | |
135 | a. | |
136 | """ | |
137 | ||
138 | # Disassembly of DATA0. | |
139 | DATA0_DIS = """\ | |
140 | 0: ( MARK | |
141 | 1: l LIST (MARK at 0) | |
142 | 2: p PUT 1 | |
143 | 5: I INT 0 | |
144 | 8: a APPEND | |
145 | 9: L LONG 1L | |
146 | 13: a APPEND | |
147 | 14: F FLOAT 2.0 | |
148 | 17: a APPEND | |
149 | 18: c GLOBAL '__builtin__ complex' | |
150 | 39: p PUT 2 | |
151 | 42: ( MARK | |
152 | 43: F FLOAT 3.0 | |
153 | 46: F FLOAT 0.0 | |
154 | 49: t TUPLE (MARK at 42) | |
155 | 50: R REDUCE | |
156 | 51: p PUT 3 | |
157 | 54: a APPEND | |
158 | 55: I INT 1 | |
159 | 58: a APPEND | |
160 | 59: I INT -1 | |
161 | 63: a APPEND | |
162 | 64: I INT 255 | |
163 | 69: a APPEND | |
164 | 70: I INT -255 | |
165 | 76: a APPEND | |
166 | 77: I INT -256 | |
167 | 83: a APPEND | |
168 | 84: I INT 65535 | |
169 | 91: a APPEND | |
170 | 92: I INT -65535 | |
171 | 100: a APPEND | |
172 | 101: I INT -65536 | |
173 | 109: a APPEND | |
174 | 110: I INT 2147483647 | |
175 | 122: a APPEND | |
176 | 123: I INT -2147483647 | |
177 | 136: a APPEND | |
178 | 137: I INT -2147483648 | |
179 | 150: a APPEND | |
180 | 151: ( MARK | |
181 | 152: S STRING 'abc' | |
182 | 159: p PUT 4 | |
183 | 162: g GET 4 | |
184 | 165: ( MARK | |
185 | 166: i INST '__main__ C' (MARK at 165) | |
186 | 178: p PUT 5 | |
187 | 181: ( MARK | |
188 | 182: d DICT (MARK at 181) | |
189 | 183: p PUT 6 | |
190 | 186: S STRING 'foo' | |
191 | 193: p PUT 7 | |
192 | 196: I INT 1 | |
193 | 199: s SETITEM | |
194 | 200: S STRING 'bar' | |
195 | 207: p PUT 8 | |
196 | 210: I INT 2 | |
197 | 213: s SETITEM | |
198 | 214: b BUILD | |
199 | 215: g GET 5 | |
200 | 218: t TUPLE (MARK at 151) | |
201 | 219: p PUT 9 | |
202 | 222: a APPEND | |
203 | 223: g GET 9 | |
204 | 226: a APPEND | |
205 | 227: I INT 5 | |
206 | 230: a APPEND | |
207 | 231: . STOP | |
208 | highest protocol among opcodes = 0 | |
209 | """ | |
210 | ||
211 | DATA1 = (']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00' | |
212 | 'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00' | |
213 | '\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff' | |
214 | '\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff' | |
215 | 'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00' | |
216 | '\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n' | |
217 | 'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh' | |
218 | '\x06tq\nh\nK\x05e.' | |
219 | ) | |
220 | ||
221 | # Disassembly of DATA1. | |
222 | DATA1_DIS = """\ | |
223 | 0: ] EMPTY_LIST | |
224 | 1: q BINPUT 1 | |
225 | 3: ( MARK | |
226 | 4: K BININT1 0 | |
227 | 6: L LONG 1L | |
228 | 10: G BINFLOAT 2.0 | |
229 | 19: c GLOBAL '__builtin__ complex' | |
230 | 40: q BINPUT 2 | |
231 | 42: ( MARK | |
232 | 43: G BINFLOAT 3.0 | |
233 | 52: G BINFLOAT 0.0 | |
234 | 61: t TUPLE (MARK at 42) | |
235 | 62: R REDUCE | |
236 | 63: q BINPUT 3 | |
237 | 65: K BININT1 1 | |
238 | 67: J BININT -1 | |
239 | 72: K BININT1 255 | |
240 | 74: J BININT -255 | |
241 | 79: J BININT -256 | |
242 | 84: M BININT2 65535 | |
243 | 87: J BININT -65535 | |
244 | 92: J BININT -65536 | |
245 | 97: J BININT 2147483647 | |
246 | 102: J BININT -2147483647 | |
247 | 107: J BININT -2147483648 | |
248 | 112: ( MARK | |
249 | 113: U SHORT_BINSTRING 'abc' | |
250 | 118: q BINPUT 4 | |
251 | 120: h BINGET 4 | |
252 | 122: ( MARK | |
253 | 123: c GLOBAL '__main__ C' | |
254 | 135: q BINPUT 5 | |
255 | 137: o OBJ (MARK at 122) | |
256 | 138: q BINPUT 6 | |
257 | 140: } EMPTY_DICT | |
258 | 141: q BINPUT 7 | |
259 | 143: ( MARK | |
260 | 144: U SHORT_BINSTRING 'foo' | |
261 | 149: q BINPUT 8 | |
262 | 151: K BININT1 1 | |
263 | 153: U SHORT_BINSTRING 'bar' | |
264 | 158: q BINPUT 9 | |
265 | 160: K BININT1 2 | |
266 | 162: u SETITEMS (MARK at 143) | |
267 | 163: b BUILD | |
268 | 164: h BINGET 6 | |
269 | 166: t TUPLE (MARK at 112) | |
270 | 167: q BINPUT 10 | |
271 | 169: h BINGET 10 | |
272 | 171: K BININT1 5 | |
273 | 173: e APPENDS (MARK at 3) | |
274 | 174: . STOP | |
275 | highest protocol among opcodes = 1 | |
276 | """ | |
277 | ||
278 | DATA2 = ('\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00' | |
279 | 'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00' | |
280 | '\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK' | |
281 | '\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff' | |
282 | 'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00' | |
283 | '\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo' | |
284 | 'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.') | |
285 | ||
286 | # Disassembly of DATA2. | |
287 | DATA2_DIS = """\ | |
288 | 0: \x80 PROTO 2 | |
289 | 2: ] EMPTY_LIST | |
290 | 3: q BINPUT 1 | |
291 | 5: ( MARK | |
292 | 6: K BININT1 0 | |
293 | 8: \x8a LONG1 1L | |
294 | 11: G BINFLOAT 2.0 | |
295 | 20: c GLOBAL '__builtin__ complex' | |
296 | 41: q BINPUT 2 | |
297 | 43: G BINFLOAT 3.0 | |
298 | 52: G BINFLOAT 0.0 | |
299 | 61: \x86 TUPLE2 | |
300 | 62: R REDUCE | |
301 | 63: q BINPUT 3 | |
302 | 65: K BININT1 1 | |
303 | 67: J BININT -1 | |
304 | 72: K BININT1 255 | |
305 | 74: J BININT -255 | |
306 | 79: J BININT -256 | |
307 | 84: M BININT2 65535 | |
308 | 87: J BININT -65535 | |
309 | 92: J BININT -65536 | |
310 | 97: J BININT 2147483647 | |
311 | 102: J BININT -2147483647 | |
312 | 107: J BININT -2147483648 | |
313 | 112: ( MARK | |
314 | 113: U SHORT_BINSTRING 'abc' | |
315 | 118: q BINPUT 4 | |
316 | 120: h BINGET 4 | |
317 | 122: ( MARK | |
318 | 123: c GLOBAL '__main__ C' | |
319 | 135: q BINPUT 5 | |
320 | 137: o OBJ (MARK at 122) | |
321 | 138: q BINPUT 6 | |
322 | 140: } EMPTY_DICT | |
323 | 141: q BINPUT 7 | |
324 | 143: ( MARK | |
325 | 144: U SHORT_BINSTRING 'foo' | |
326 | 149: q BINPUT 8 | |
327 | 151: K BININT1 1 | |
328 | 153: U SHORT_BINSTRING 'bar' | |
329 | 158: q BINPUT 9 | |
330 | 160: K BININT1 2 | |
331 | 162: u SETITEMS (MARK at 143) | |
332 | 163: b BUILD | |
333 | 164: h BINGET 6 | |
334 | 166: t TUPLE (MARK at 112) | |
335 | 167: q BINPUT 10 | |
336 | 169: h BINGET 10 | |
337 | 171: K BININT1 5 | |
338 | 173: e APPENDS (MARK at 5) | |
339 | 174: . STOP | |
340 | highest protocol among opcodes = 2 | |
341 | """ | |
342 | ||
343 | def create_data(): | |
344 | c = C() | |
345 | c.foo = 1 | |
346 | c.bar = 2 | |
347 | x = [0, 1L, 2.0, 3.0+0j] | |
348 | # Append some integer test cases at cPickle.c's internal size | |
349 | # cutoffs. | |
350 | uint1max = 0xff | |
351 | uint2max = 0xffff | |
352 | int4max = 0x7fffffff | |
353 | x.extend([1, -1, | |
354 | uint1max, -uint1max, -uint1max-1, | |
355 | uint2max, -uint2max, -uint2max-1, | |
356 | int4max, -int4max, -int4max-1]) | |
357 | y = ('abc', 'abc', c, c) | |
358 | x.append(y) | |
359 | x.append(y) | |
360 | x.append(5) | |
361 | return x | |
362 | ||
363 | class AbstractPickleTests(unittest.TestCase): | |
364 | # Subclass must define self.dumps, self.loads, self.error. | |
365 | ||
366 | _testdata = create_data() | |
367 | ||
368 | def setUp(self): | |
369 | pass | |
370 | ||
371 | def test_misc(self): | |
372 | # test various datatypes not tested by testdata | |
373 | for proto in protocols: | |
374 | x = myint(4) | |
375 | s = self.dumps(x, proto) | |
376 | y = self.loads(s) | |
377 | self.assertEqual(x, y) | |
378 | ||
379 | x = (1, ()) | |
380 | s = self.dumps(x, proto) | |
381 | y = self.loads(s) | |
382 | self.assertEqual(x, y) | |
383 | ||
384 | x = initarg(1, x) | |
385 | s = self.dumps(x, proto) | |
386 | y = self.loads(s) | |
387 | self.assertEqual(x, y) | |
388 | ||
389 | # XXX test __reduce__ protocol? | |
390 | ||
391 | def test_roundtrip_equality(self): | |
392 | expected = self._testdata | |
393 | for proto in protocols: | |
394 | s = self.dumps(expected, proto) | |
395 | got = self.loads(s) | |
396 | self.assertEqual(expected, got) | |
397 | ||
398 | def test_load_from_canned_string(self): | |
399 | expected = self._testdata | |
400 | for canned in DATA0, DATA1, DATA2: | |
401 | got = self.loads(canned) | |
402 | self.assertEqual(expected, got) | |
403 | ||
404 | # There are gratuitous differences between pickles produced by | |
405 | # pickle and cPickle, largely because cPickle starts PUT indices at | |
406 | # 1 and pickle starts them at 0. See XXX comment in cPickle's put2() -- | |
407 | # there's a comment with an exclamation point there whose meaning | |
408 | # is a mystery. cPickle also suppresses PUT for objects with a refcount | |
409 | # of 1. | |
410 | def dont_test_disassembly(self): | |
411 | from cStringIO import StringIO | |
412 | from pickletools import dis | |
413 | ||
414 | for proto, expected in (0, DATA0_DIS), (1, DATA1_DIS): | |
415 | s = self.dumps(self._testdata, proto) | |
416 | filelike = StringIO() | |
417 | dis(s, out=filelike) | |
418 | got = filelike.getvalue() | |
419 | self.assertEqual(expected, got) | |
420 | ||
421 | def test_recursive_list(self): | |
422 | l = [] | |
423 | l.append(l) | |
424 | for proto in protocols: | |
425 | s = self.dumps(l, proto) | |
426 | x = self.loads(s) | |
427 | self.assertEqual(len(x), 1) | |
428 | self.assert_(x is x[0]) | |
429 | ||
430 | def test_recursive_dict(self): | |
431 | d = {} | |
432 | d[1] = d | |
433 | for proto in protocols: | |
434 | s = self.dumps(d, proto) | |
435 | x = self.loads(s) | |
436 | self.assertEqual(x.keys(), [1]) | |
437 | self.assert_(x[1] is x) | |
438 | ||
439 | def test_recursive_inst(self): | |
440 | i = C() | |
441 | i.attr = i | |
442 | for proto in protocols: | |
443 | s = self.dumps(i, 2) | |
444 | x = self.loads(s) | |
445 | self.assertEqual(dir(x), dir(i)) | |
446 | self.assert_(x.attr is x) | |
447 | ||
448 | def test_recursive_multi(self): | |
449 | l = [] | |
450 | d = {1:l} | |
451 | i = C() | |
452 | i.attr = d | |
453 | l.append(i) | |
454 | for proto in protocols: | |
455 | s = self.dumps(l, proto) | |
456 | x = self.loads(s) | |
457 | self.assertEqual(len(x), 1) | |
458 | self.assertEqual(dir(x[0]), dir(i)) | |
459 | self.assertEqual(x[0].attr.keys(), [1]) | |
460 | self.assert_(x[0].attr[1] is x) | |
461 | ||
462 | def test_garyp(self): | |
463 | self.assertRaises(self.error, self.loads, 'garyp') | |
464 | ||
465 | def test_insecure_strings(self): | |
466 | insecure = ["abc", "2 + 2", # not quoted | |
467 | #"'abc' + 'def'", # not a single quoted string | |
468 | "'abc", # quote is not closed | |
469 | "'abc\"", # open quote and close quote don't match | |
470 | "'abc' ?", # junk after close quote | |
471 | "'\\'", # trailing backslash | |
472 | # some tests of the quoting rules | |
473 | #"'abc\"\''", | |
474 | #"'\\\\a\'\'\'\\\'\\\\\''", | |
475 | ] | |
476 | for s in insecure: | |
477 | buf = "S" + s + "\012p0\012." | |
478 | self.assertRaises(ValueError, self.loads, buf) | |
479 | ||
480 | if have_unicode: | |
481 | def test_unicode(self): | |
482 | endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'), | |
483 | unicode('<\n>'), unicode('<\\>')] | |
484 | for proto in protocols: | |
485 | for u in endcases: | |
486 | p = self.dumps(u, proto) | |
487 | u2 = self.loads(p) | |
488 | self.assertEqual(u2, u) | |
489 | ||
490 | def test_ints(self): | |
491 | import sys | |
492 | for proto in protocols: | |
493 | n = sys.maxint | |
494 | while n: | |
495 | for expected in (-n, n): | |
496 | s = self.dumps(expected, proto) | |
497 | n2 = self.loads(s) | |
498 | self.assertEqual(expected, n2) | |
499 | n = n >> 1 | |
500 | ||
501 | def test_maxint64(self): | |
502 | maxint64 = (1L << 63) - 1 | |
503 | data = 'I' + str(maxint64) + '\n.' | |
504 | got = self.loads(data) | |
505 | self.assertEqual(got, maxint64) | |
506 | ||
507 | # Try too with a bogus literal. | |
508 | data = 'I' + str(maxint64) + 'JUNK\n.' | |
509 | self.assertRaises(ValueError, self.loads, data) | |
510 | ||
511 | def test_long(self): | |
512 | for proto in protocols: | |
513 | # 256 bytes is where LONG4 begins. | |
514 | for nbits in 1, 8, 8*254, 8*255, 8*256, 8*257: | |
515 | nbase = 1L << nbits | |
516 | for npos in nbase-1, nbase, nbase+1: | |
517 | for n in npos, -npos: | |
518 | pickle = self.dumps(n, proto) | |
519 | got = self.loads(pickle) | |
520 | self.assertEqual(n, got) | |
521 | # Try a monster. This is quadratic-time in protos 0 & 1, so don't | |
522 | # bother with those. | |
523 | nbase = long("deadbeeffeedface", 16) | |
524 | nbase += nbase << 1000000 | |
525 | for n in nbase, -nbase: | |
526 | p = self.dumps(n, 2) | |
527 | got = self.loads(p) | |
528 | self.assertEqual(n, got) | |
529 | ||
530 | def test_reduce(self): | |
531 | pass | |
532 | ||
533 | def test_getinitargs(self): | |
534 | pass | |
535 | ||
536 | def test_metaclass(self): | |
537 | a = use_metaclass() | |
538 | for proto in protocols: | |
539 | s = self.dumps(a, proto) | |
540 | b = self.loads(s) | |
541 | self.assertEqual(a.__class__, b.__class__) | |
542 | ||
543 | def test_structseq(self): | |
544 | import time | |
545 | import os | |
546 | ||
547 | t = time.localtime() | |
548 | for proto in protocols: | |
549 | s = self.dumps(t, proto) | |
550 | u = self.loads(s) | |
551 | self.assertEqual(t, u) | |
552 | if hasattr(os, "stat"): | |
553 | t = os.stat(os.curdir) | |
554 | s = self.dumps(t, proto) | |
555 | u = self.loads(s) | |
556 | self.assertEqual(t, u) | |
557 | if hasattr(os, "statvfs"): | |
558 | t = os.statvfs(os.curdir) | |
559 | s = self.dumps(t, proto) | |
560 | u = self.loads(s) | |
561 | self.assertEqual(t, u) | |
562 | ||
563 | # Tests for protocol 2 | |
564 | ||
565 | def test_proto(self): | |
566 | build_none = pickle.NONE + pickle.STOP | |
567 | for proto in protocols: | |
568 | expected = build_none | |
569 | if proto >= 2: | |
570 | expected = pickle.PROTO + chr(proto) + expected | |
571 | p = self.dumps(None, proto) | |
572 | self.assertEqual(p, expected) | |
573 | ||
574 | oob = protocols[-1] + 1 # a future protocol | |
575 | badpickle = pickle.PROTO + chr(oob) + build_none | |
576 | try: | |
577 | self.loads(badpickle) | |
578 | except ValueError, detail: | |
579 | self.failUnless(str(detail).startswith( | |
580 | "unsupported pickle protocol")) | |
581 | else: | |
582 | self.fail("expected bad protocol number to raise ValueError") | |
583 | ||
584 | def test_long1(self): | |
585 | x = 12345678910111213141516178920L | |
586 | for proto in protocols: | |
587 | s = self.dumps(x, proto) | |
588 | y = self.loads(s) | |
589 | self.assertEqual(x, y) | |
590 | self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2) | |
591 | ||
592 | def test_long4(self): | |
593 | x = 12345678910111213141516178920L << (256*8) | |
594 | for proto in protocols: | |
595 | s = self.dumps(x, proto) | |
596 | y = self.loads(s) | |
597 | self.assertEqual(x, y) | |
598 | self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2) | |
599 | ||
600 | def test_short_tuples(self): | |
601 | # Map (proto, len(tuple)) to expected opcode. | |
602 | expected_opcode = {(0, 0): pickle.TUPLE, | |
603 | (0, 1): pickle.TUPLE, | |
604 | (0, 2): pickle.TUPLE, | |
605 | (0, 3): pickle.TUPLE, | |
606 | (0, 4): pickle.TUPLE, | |
607 | ||
608 | (1, 0): pickle.EMPTY_TUPLE, | |
609 | (1, 1): pickle.TUPLE, | |
610 | (1, 2): pickle.TUPLE, | |
611 | (1, 3): pickle.TUPLE, | |
612 | (1, 4): pickle.TUPLE, | |
613 | ||
614 | (2, 0): pickle.EMPTY_TUPLE, | |
615 | (2, 1): pickle.TUPLE1, | |
616 | (2, 2): pickle.TUPLE2, | |
617 | (2, 3): pickle.TUPLE3, | |
618 | (2, 4): pickle.TUPLE, | |
619 | } | |
620 | a = () | |
621 | b = (1,) | |
622 | c = (1, 2) | |
623 | d = (1, 2, 3) | |
624 | e = (1, 2, 3, 4) | |
625 | for proto in protocols: | |
626 | for x in a, b, c, d, e: | |
627 | s = self.dumps(x, proto) | |
628 | y = self.loads(s) | |
629 | self.assertEqual(x, y, (proto, x, s, y)) | |
630 | expected = expected_opcode[proto, len(x)] | |
631 | self.assertEqual(opcode_in_pickle(expected, s), True) | |
632 | ||
633 | def test_singletons(self): | |
634 | # Map (proto, singleton) to expected opcode. | |
635 | expected_opcode = {(0, None): pickle.NONE, | |
636 | (1, None): pickle.NONE, | |
637 | (2, None): pickle.NONE, | |
638 | ||
639 | (0, True): pickle.INT, | |
640 | (1, True): pickle.INT, | |
641 | (2, True): pickle.NEWTRUE, | |
642 | ||
643 | (0, False): pickle.INT, | |
644 | (1, False): pickle.INT, | |
645 | (2, False): pickle.NEWFALSE, | |
646 | } | |
647 | for proto in protocols: | |
648 | for x in None, False, True: | |
649 | s = self.dumps(x, proto) | |
650 | y = self.loads(s) | |
651 | self.assert_(x is y, (proto, x, s, y)) | |
652 | expected = expected_opcode[proto, x] | |
653 | self.assertEqual(opcode_in_pickle(expected, s), True) | |
654 | ||
655 | def test_newobj_tuple(self): | |
656 | x = MyTuple([1, 2, 3]) | |
657 | x.foo = 42 | |
658 | x.bar = "hello" | |
659 | for proto in protocols: | |
660 | s = self.dumps(x, proto) | |
661 | y = self.loads(s) | |
662 | self.assertEqual(tuple(x), tuple(y)) | |
663 | self.assertEqual(x.__dict__, y.__dict__) | |
664 | ||
665 | def test_newobj_list(self): | |
666 | x = MyList([1, 2, 3]) | |
667 | x.foo = 42 | |
668 | x.bar = "hello" | |
669 | for proto in protocols: | |
670 | s = self.dumps(x, proto) | |
671 | y = self.loads(s) | |
672 | self.assertEqual(list(x), list(y)) | |
673 | self.assertEqual(x.__dict__, y.__dict__) | |
674 | ||
675 | def test_newobj_generic(self): | |
676 | for proto in protocols: | |
677 | for C in myclasses: | |
678 | B = C.__base__ | |
679 | x = C(C.sample) | |
680 | x.foo = 42 | |
681 | s = self.dumps(x, proto) | |
682 | y = self.loads(s) | |
683 | detail = (proto, C, B, x, y, type(y)) | |
684 | self.assertEqual(B(x), B(y), detail) | |
685 | self.assertEqual(x.__dict__, y.__dict__, detail) | |
686 | ||
687 | # Register a type with copy_reg, with extension code extcode. Pickle | |
688 | # an object of that type. Check that the resulting pickle uses opcode | |
689 | # (EXT[124]) under proto 2, and not in proto 1. | |
690 | ||
691 | def produce_global_ext(self, extcode, opcode): | |
692 | e = ExtensionSaver(extcode) | |
693 | try: | |
694 | copy_reg.add_extension(__name__, "MyList", extcode) | |
695 | x = MyList([1, 2, 3]) | |
696 | x.foo = 42 | |
697 | x.bar = "hello" | |
698 | ||
699 | # Dump using protocol 1 for comparison. | |
700 | s1 = self.dumps(x, 1) | |
701 | self.assert_(__name__ in s1) | |
702 | self.assert_("MyList" in s1) | |
703 | self.assertEqual(opcode_in_pickle(opcode, s1), False) | |
704 | ||
705 | y = self.loads(s1) | |
706 | self.assertEqual(list(x), list(y)) | |
707 | self.assertEqual(x.__dict__, y.__dict__) | |
708 | ||
709 | # Dump using protocol 2 for test. | |
710 | s2 = self.dumps(x, 2) | |
711 | self.assert_(__name__ not in s2) | |
712 | self.assert_("MyList" not in s2) | |
713 | self.assertEqual(opcode_in_pickle(opcode, s2), True) | |
714 | ||
715 | y = self.loads(s2) | |
716 | self.assertEqual(list(x), list(y)) | |
717 | self.assertEqual(x.__dict__, y.__dict__) | |
718 | ||
719 | finally: | |
720 | e.restore() | |
721 | ||
722 | def test_global_ext1(self): | |
723 | self.produce_global_ext(0x00000001, pickle.EXT1) # smallest EXT1 code | |
724 | self.produce_global_ext(0x000000ff, pickle.EXT1) # largest EXT1 code | |
725 | ||
726 | def test_global_ext2(self): | |
727 | self.produce_global_ext(0x00000100, pickle.EXT2) # smallest EXT2 code | |
728 | self.produce_global_ext(0x0000ffff, pickle.EXT2) # largest EXT2 code | |
729 | self.produce_global_ext(0x0000abcd, pickle.EXT2) # check endianness | |
730 | ||
731 | def test_global_ext4(self): | |
732 | self.produce_global_ext(0x00010000, pickle.EXT4) # smallest EXT4 code | |
733 | self.produce_global_ext(0x7fffffff, pickle.EXT4) # largest EXT4 code | |
734 | self.produce_global_ext(0x12abcdef, pickle.EXT4) # check endianness | |
735 | ||
736 | def test_list_chunking(self): | |
737 | n = 10 # too small to chunk | |
738 | x = range(n) | |
739 | for proto in protocols: | |
740 | s = self.dumps(x, proto) | |
741 | y = self.loads(s) | |
742 | self.assertEqual(x, y) | |
743 | num_appends = count_opcode(pickle.APPENDS, s) | |
744 | self.assertEqual(num_appends, proto > 0) | |
745 | ||
746 | n = 2500 # expect at least two chunks when proto > 0 | |
747 | x = range(n) | |
748 | for proto in protocols: | |
749 | s = self.dumps(x, proto) | |
750 | y = self.loads(s) | |
751 | self.assertEqual(x, y) | |
752 | num_appends = count_opcode(pickle.APPENDS, s) | |
753 | if proto == 0: | |
754 | self.assertEqual(num_appends, 0) | |
755 | else: | |
756 | self.failUnless(num_appends >= 2) | |
757 | ||
758 | def test_dict_chunking(self): | |
759 | n = 10 # too small to chunk | |
760 | x = dict.fromkeys(range(n)) | |
761 | for proto in protocols: | |
762 | s = self.dumps(x, proto) | |
763 | y = self.loads(s) | |
764 | self.assertEqual(x, y) | |
765 | num_setitems = count_opcode(pickle.SETITEMS, s) | |
766 | self.assertEqual(num_setitems, proto > 0) | |
767 | ||
768 | n = 2500 # expect at least two chunks when proto > 0 | |
769 | x = dict.fromkeys(range(n)) | |
770 | for proto in protocols: | |
771 | s = self.dumps(x, proto) | |
772 | y = self.loads(s) | |
773 | self.assertEqual(x, y) | |
774 | num_setitems = count_opcode(pickle.SETITEMS, s) | |
775 | if proto == 0: | |
776 | self.assertEqual(num_setitems, 0) | |
777 | else: | |
778 | self.failUnless(num_setitems >= 2) | |
779 | ||
780 | def test_simple_newobj(self): | |
781 | x = object.__new__(SimpleNewObj) # avoid __init__ | |
782 | x.abc = 666 | |
783 | for proto in protocols: | |
784 | s = self.dumps(x, proto) | |
785 | self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2) | |
786 | y = self.loads(s) # will raise TypeError if __init__ called | |
787 | self.assertEqual(y.abc, 666) | |
788 | self.assertEqual(x.__dict__, y.__dict__) | |
789 | ||
790 | def test_newobj_list_slots(self): | |
791 | x = SlotList([1, 2, 3]) | |
792 | x.foo = 42 | |
793 | x.bar = "hello" | |
794 | s = self.dumps(x, 2) | |
795 | y = self.loads(s) | |
796 | self.assertEqual(list(x), list(y)) | |
797 | self.assertEqual(x.__dict__, y.__dict__) | |
798 | self.assertEqual(x.foo, y.foo) | |
799 | self.assertEqual(x.bar, y.bar) | |
800 | ||
801 | def test_reduce_overrides_default_reduce_ex(self): | |
802 | for proto in 0, 1, 2: | |
803 | x = REX_one() | |
804 | self.assertEqual(x._reduce_called, 0) | |
805 | s = self.dumps(x, proto) | |
806 | self.assertEqual(x._reduce_called, 1) | |
807 | y = self.loads(s) | |
808 | self.assertEqual(y._reduce_called, 0) | |
809 | ||
810 | def test_reduce_ex_called(self): | |
811 | for proto in 0, 1, 2: | |
812 | x = REX_two() | |
813 | self.assertEqual(x._proto, None) | |
814 | s = self.dumps(x, proto) | |
815 | self.assertEqual(x._proto, proto) | |
816 | y = self.loads(s) | |
817 | self.assertEqual(y._proto, None) | |
818 | ||
819 | def test_reduce_ex_overrides_reduce(self): | |
820 | for proto in 0, 1, 2: | |
821 | x = REX_three() | |
822 | self.assertEqual(x._proto, None) | |
823 | s = self.dumps(x, proto) | |
824 | self.assertEqual(x._proto, proto) | |
825 | y = self.loads(s) | |
826 | self.assertEqual(y._proto, None) | |
827 | ||
828 | # Test classes for reduce_ex | |
829 | ||
class REX_one(object):
    # Defines __reduce__ only; __reduce_ex__ is inherited from object.
    # The class-level flag is shadowed on the instance once __reduce__
    # has run.
    _reduce_called = 0

    def __reduce__(self):
        self._reduce_called = 1
        return (REX_one, ())
836 | ||
class REX_two(object):
    # Defines __reduce_ex__ only; __reduce__ is inherited from object.
    # Records the protocol number it was called with.
    _proto = None

    def __reduce_ex__(self, proto):
        self._proto = proto
        return (REX_two, ())
843 | ||
class REX_three(object):
    # Defines both hooks: __reduce_ex__ records the protocol it was
    # called with, while __reduce__ fails the test if it is ever called.
    _proto = None

    def __reduce_ex__(self, proto):
        self._proto = proto
        # Rebuild as REX_three (this used to name REX_two, so a round
        # trip produced an object of the wrong class).
        return REX_three, ()

    def __reduce__(self):
        raise TestFailed("This __reduce__ shouldn't be called")
851 | ||
852 | # Test classes for newobj | |
853 | ||
854 | class MyInt(int): | |
855 | sample = 1 | |
856 | ||
857 | class MyLong(long): | |
858 | sample = 1L | |
859 | ||
860 | class MyFloat(float): | |
861 | sample = 1.0 | |
862 | ||
863 | class MyComplex(complex): | |
864 | sample = 1.0 + 0.0j | |
865 | ||
866 | class MyStr(str): | |
867 | sample = "hello" | |
868 | ||
869 | class MyUnicode(unicode): | |
870 | sample = u"hello \u1234" | |
871 | ||
872 | class MyTuple(tuple): | |
873 | sample = (1, 2, 3) | |
874 | ||
875 | class MyList(list): | |
876 | sample = [1, 2, 3] | |
877 | ||
878 | class MyDict(dict): | |
879 | sample = {"a": 1, "b": 2} | |
880 | ||
881 | myclasses = [MyInt, MyLong, MyFloat, | |
882 | MyComplex, | |
883 | MyStr, MyUnicode, | |
884 | MyTuple, MyList, MyDict] | |
885 | ||
886 | ||
# MyList subclass that additionally declares a `foo` slot.
class SlotList(MyList):
    __slots__ = ["foo"]
889 | ||
# New-style class whose __init__ always raises, so tests can detect an
# unpickler that calls __init__ when it should not.
class SimpleNewObj(object):
    def __init__(self, a, b, c):
        message = "SimpleNewObj.__init__() didn't expect to get called"
        raise TypeError(message)
894 | ||
class AbstractPickleModuleTests(unittest.TestCase):
    # Module-level API checks.  These tests read the pickle implementation
    # under test from self.module.

    def test_dump_closed_file(self):
        # Dumping to an already-closed file must raise ValueError.
        import os
        f = open(TESTFN, "w")
        try:
            f.close()
            self.assertRaises(ValueError, self.module.dump, 123, f)
        finally:
            os.remove(TESTFN)

    def test_load_closed_file(self):
        # Loading from an already-closed file must raise ValueError.
        # (This used to call dump(), duplicating the test above and
        # leaving load() untested.)
        import os
        f = open(TESTFN, "w")
        try:
            f.close()
            self.assertRaises(ValueError, self.module.load, f)
        finally:
            os.remove(TESTFN)

    def test_highest_protocol(self):
        # Of course this needs to be changed when HIGHEST_PROTOCOL changes.
        self.assertEqual(self.module.HIGHEST_PROTOCOL, 2)

    def test_callapi(self):
        # The protocol argument must be accepted both positionally and by
        # keyword; -1 is passed as the protocol in every call.
        from cStringIO import StringIO
        f = StringIO()
        self.module.dump(123, f, -1)
        self.module.dump(123, file=f, protocol=-1)
        self.module.dumps(123, -1)
        self.module.dumps(123, protocol=-1)
        self.module.Pickler(f, -1)
        self.module.Pickler(f, protocol=-1)
929 | ||
930 | class AbstractPersistentPicklerTests(unittest.TestCase): | |
931 | ||
932 | # This class defines persistent_id() and persistent_load() | |
933 | # functions that should be used by the pickler. All even integers | |
934 | # are pickled using persistent ids. | |
935 | ||
936 | def persistent_id(self, object): | |
937 | if isinstance(object, int) and object % 2 == 0: | |
938 | self.id_count += 1 | |
939 | return str(object) | |
940 | else: | |
941 | return None | |
942 | ||
943 | def persistent_load(self, oid): | |
944 | self.load_count += 1 | |
945 | object = int(oid) | |
946 | assert object % 2 == 0 | |
947 | return object | |
948 | ||
949 | def test_persistence(self): | |
950 | self.id_count = 0 | |
951 | self.load_count = 0 | |
952 | L = range(10) | |
953 | self.assertEqual(self.loads(self.dumps(L)), L) | |
954 | self.assertEqual(self.id_count, 5) | |
955 | self.assertEqual(self.load_count, 5) | |
956 | ||
957 | def test_bin_persistence(self): | |
958 | self.id_count = 0 | |
959 | self.load_count = 0 | |
960 | L = range(10) | |
961 | self.assertEqual(self.loads(self.dumps(L, 1)), L) | |
962 | self.assertEqual(self.id_count, 5) | |
963 | self.assertEqual(self.load_count, 5) |