Commit | Line | Data |
---|---|---|
920dae64 AT |
1 | """Helper to provide extensibility for pickle/cPickle. |
2 | ||
3 | This is only useful to add pickle support for extension types defined in | |
4 | C, not for instances of user-defined classes. | |
5 | """ | |
6 | ||
7 | from types import ClassType as _ClassType | |
8 | ||
# Names exported by a star-import of this module.  The extension
# registries and the underscore-prefixed pickling helpers are not listed.
__all__ = [
    "pickle",
    "constructor",
    "add_extension",
    "remove_extension",
    "clear_extension_cache",
]

# Maps a type to the reduction function registered for it via pickle().
dispatch_table = {}
13 | ||
def pickle(ob_type, pickle_function, constructor_ob=None):
    """Register *pickle_function* as the reduction function for *ob_type*.

    The function is stored in ``dispatch_table`` keyed by *ob_type*.

    Raises TypeError if *ob_type* is a classic class (an instance of
    ``types.ClassType``), if *pickle_function* is not callable, or if
    *constructor_ob* is given and is not callable.  Every argument is
    validated before ``dispatch_table`` is touched, so a failed call
    leaves the table unchanged.
    """
    if type(ob_type) is _ClassType:
        raise TypeError("copy_reg is not intended for use with classes")

    if not callable(pickle_function):
        raise TypeError("reduction functions must be callable")

    # constructor_ob is a vestige of the old "safe for unpickling"
    # mechanism; there is no reason for callers to pass it any more, and
    # it is only checked for callability.  Check it *before* registering
    # so that a bad constructor cannot leave a half-done registration.
    if constructor_ob is not None:
        constructor(constructor_ob)

    dispatch_table[ob_type] = pickle_function
26 | ||
def constructor(object):
    """Validate that *object* is callable, raising TypeError if not.

    Nothing is recorded anywhere; on success this simply returns None.
    """
    if callable(object):
        return None
    raise TypeError("constructors must be callable")
30 | ||
# Example: provide pickling support for complex numbers.

try:
    complex
except NameError:
    # This interpreter has no complex type, so there is nothing to register.
    pass
else:

    def pickle_complex(c):
        """Reduce complex number *c* to ``(complex, (real, imag))``."""
        parts = (c.real, c.imag)
        return (complex, parts)

    pickle(complex, pickle_complex, complex)
43 | ||
# Support for pickling new-style objects

def _reconstructor(cls, base, state):
    """Rebuild an instance of *cls* from the args produced by ``_reduce_ex``.

    If *base* is ``object``, a bare instance is allocated and *state* is
    ignored.  Otherwise the instance is created with
    ``base.__new__(cls, state)`` and then initialised with
    ``base.__init__(instance, state)``.
    """
    if base is not object:
        instance = base.__new__(cls, state)
        base.__init__(instance, state)
        return instance
    return object.__new__(cls)
53 | ||
# Mirrors CPython's Py_TPFLAGS_HEAPTYPE bit in a type's __flags__; it is
# set on dynamically allocated types (e.g. those made by a class statement).
_HEAPTYPE = 1<<9

# Python code for object.__reduce_ex__ for protocols 0 and 1

def _reduce_ex(self, proto):
    """Reduce *self* for pickle protocols 0 and 1.

    Walks the MRO of ``self.__class__`` to the first non-heap type
    (``base``) and captures ``base(self)`` as the base state (``None`` when
    ``base`` is ``object``).  Returns ``(_reconstructor, args)``, or
    ``(_reconstructor, args, state_dict)`` when there is instance state to
    restore, where ``args`` is ``(cls, base, base_state)``.

    Raises TypeError if the object's own class is a non-heap type other
    than ``object``, or if the class defines ``__slots__`` without
    ``__getstate__``.
    """
    # Only protocols 0 and 1 are handled here; this is an assert, so it
    # is not checked under -O.
    assert proto < 2
    cls = self.__class__
    for base in cls.__mro__:
        # Stop at the first type that is not a heap type.
        if hasattr(base, '__flags__') and not (base.__flags__ & _HEAPTYPE):
            break
    else:
        base = object  # not really reachable: object ends every new-style MRO
    if base is object:
        state = None
    else:
        if base is cls:
            raise TypeError("can't pickle %s objects" % base.__name__)
        state = base(self)
    args = (cls, base, state)
    try:
        getstate = self.__getstate__
    except AttributeError:
        if getattr(self, "__slots__", None):
            raise TypeError("a class that defines __slots__ without "
                            "defining __getstate__ cannot be pickled")
        try:
            state_dict = self.__dict__
        except AttributeError:
            state_dict = None
    else:
        state_dict = getstate()
    if state_dict:
        return _reconstructor, args, state_dict
    else:
        return _reconstructor, args
88 | ||
# Helper for __reduce_ex__ protocol 2

def __newobj__(cls, *args):
    """Allocate a new instance via ``cls.__new__(cls, *args)``.

    ``__init__`` is not called.
    """
    allocate = cls.__new__
    return allocate(cls, *args)
93 | ||
def _slotnames(cls):
    """Return a list of slot names for a given class.

    This needs to find slots defined by the class and its bases, so we
    can't simply return the __slots__ attribute.  We must walk down
    the Method Resolution Order and concatenate the __slots__ of each
    class found there.  (This assumes classes don't modify their
    __slots__ attribute to misrepresent their slots after the class is
    defined.)

    The returned names are the attribute names the slots actually occupy:
    a ``__slots__`` given as a single string counts as one slot, private
    names such as ``__x`` are reported in their mangled form
    (``_Class__x``), and the special ``__dict__`` / ``__weakref__`` entries
    are skipped.  The result is cached as ``cls.__slotnames__`` when the
    class accepts the attribute.
    """
    # Get the value from a cache in the class if possible
    names = cls.__dict__.get("__slotnames__")
    if names is not None:
        return names

    # Not cached -- calculate the value
    names = []
    if hasattr(cls, "__slots__"):
        # Slots found -- gather slot names from all base classes
        for c in cls.__mro__:
            if "__slots__" not in c.__dict__:
                continue
            slots = c.__dict__["__slots__"]
            # A single slot may be given as a bare string; iterating it
            # directly would yield one bogus "slot" per character.
            if isinstance(slots, (str, type(u""))):
                slots = (slots,)
            for name in slots:
                if name in ("__dict__", "__weakref__"):
                    # Special descriptors, not data slots.
                    continue
                if name.startswith("__") and not name.endswith("__"):
                    # Private names are stored mangled as _<Class>__x, with
                    # the class name's leading underscores stripped; a class
                    # name made only of underscores gets no mangling.
                    stripped = c.__name__.lstrip("_")
                    if stripped:
                        name = "_%s%s" % (stripped, name)
                names.append(name)

    # Cache the outcome in the class if at all possible
    try:
        cls.__slotnames__ = names
    except Exception:
        pass  # Best effort: e.g. built-in types refuse new attributes

    return names
129 | ||
# A registry of extension codes.  This is an ad-hoc compression mechanism:
# whenever a global reference to <module>, <name> is about to be pickled,
# the (<module>, <name>) tuple is looked up here to see whether an
# extension code is registered for it.  Extension codes are universal, so
# the meaning of a pickle does not depend on context.  (Some codes are also
# reserved for local use and don't have this restriction.)  Codes are
# positive ints; 0 is reserved.
#
# Don't ever rebind the three names below: cPickle grabs a reference to
# them when it's initialized, and won't see a rebinding.
_extension_registry = dict()  # (module, name) key -> code
_inverted_registry = dict()   # code -> (module, name) key
_extension_cache = dict()     # code -> object
144 | ||
def add_extension(module, name, code):
    """Register an extension code for the global ``(module, name)``.

    *code* is coerced with ``int()`` and must lie in 1..0x7fffffff.
    Registering the same (key, code) pair again is a no-op.

    Raises ValueError if *code* is out of range, if ``(module, name)`` is
    already registered under another code, or if *code* is already in use
    for another key.
    """
    code = int(code)
    if not 1 <= code <= 0x7fffffff:
        raise ValueError("code out of range")
    key = (module, name)
    if (_extension_registry.get(key) == code and
        _inverted_registry.get(code) == key):
        return  # Redundant registrations are benign
    if key in _extension_registry:
        raise ValueError("key %s is already registered with code %s" %
                         (key, _extension_registry[key]))
    if code in _inverted_registry:
        raise ValueError("code %s is already in use for key %s" %
                         (code, _inverted_registry[code]))
    # Keep the forward and inverse maps in step.
    _extension_registry[key] = code
    _inverted_registry[code] = key
162 | ||
def remove_extension(module, name, code):
    """Unregister the extension code for ``(module, name)``.  For testing only.

    Raises ValueError unless *code* is currently registered for exactly
    that key.  Any cached object for *code* is discarded too.
    """
    key = (module, name)
    mismatch = (_extension_registry.get(key) != code or
                _inverted_registry.get(code) != key)
    if mismatch:
        raise ValueError("key %s is not registered with code %s" %
                         (key, code))
    del _extension_registry[key]
    del _inverted_registry[code]
    _extension_cache.pop(code, None)
174 | ||
def clear_extension_cache():
    """Empty the code -> object extension cache in place.

    The dict is cleared rather than rebound, so any holder of a reference
    to it observes the change.
    """
    cache = _extension_cache
    cache.clear()
177 | ||
178 | # Standard extension code assignments | |
179 | ||
180 | # Reserved ranges | |
181 | ||
182 | # First Last Count Purpose | |
183 | # 1 127 127 Reserved for Python standard library | |
184 | # 128 191 64 Reserved for Zope | |
185 | # 192 239 48 Reserved for 3rd parties | |
186 | # 240 255 16 Reserved for private use (will never be assigned) | |
187 | # 256 Inf Inf Reserved for future assignment | |
188 | ||
189 | # Extension codes are assigned by the Python Software Foundation. |