Commit | Line | Data |
---|---|---|
920dae64 AT |
1 | import sys |
2 | import imp | |
3 | import os | |
4 | import unittest | |
5 | from test import test_support | |
6 | ||
7 | ||
# Source of the synthetic modules served by the test importers: two
# functions that report the executing module's own __name__ and __file__.
test_src = """\
def get_name():
    return __name__
def get_file():
    return __file__
"""

# The same module plus a marker attribute; doTestImports swaps this in to
# check that reload() runs the code the importer currently serves.
reload_src = test_src + "reloaded = True\n"

# Pre-compiled code objects handed out by TestImporter.load_module.
test_co, reload_co = (compile(src, "<???>", "exec")
                      for src in (test_src, reload_src))

# Fake sys.path entry; TestImporter refuses to be built for any other path.
test_path = "!!!_test_!!!"
24 | ||
class ImportTracker:
    """Meta-path importer that only records which imports were attempted.

    It never claims a module, so every import falls through to the next
    importer; the recorded names are available afterwards in `imports`.
    """

    def __init__(self):
        # Full module names in the order find_module was asked about them.
        self.imports = []

    def find_module(self, fullname, path=None):
        """Remember fullname and decline to handle it (always returns None)."""
        seen = self.imports
        seen.append(fullname)
        # None means "not mine": the import machinery keeps searching.
        return None
32 | ||
33 | ||
class TestImporter:
    """Importer and loader for the canned modules listed in `modules`.

    Subclasses choose the __path__ given to package modules by overriding
    _get__path__.
    """

    # Full module name -> (is_package, code object to execute).  Tests may
    # replace an entry (e.g. 'reloadmodule') to change what the next load or
    # reload executes.
    modules = {
        "hooktestmodule": (False, test_co),
        "hooktestpackage": (True, test_co),
        "hooktestpackage.sub": (True, test_co),
        "hooktestpackage.sub.subber": (False, test_co),
        "reloadmodule": (False, test_co),
    }

    def __init__(self, path=test_path):
        if path != test_path:
            # If our class is on sys.path_hooks, we must raise
            # ImportError for any path item that we can't handle.
            raise ImportError
        self.path = path

    def _get__path__(self):
        """Return the __path__ list for a package module; override me."""
        raise NotImplementedError

    def find_module(self, fullname, path=None):
        """Return self if fullname is one of the canned modules, else None."""
        if fullname in self.modules:
            return self
        return None

    def load_module(self, fullname):
        """Execute the canned code for fullname and return the module.

        A module already present in sys.modules is reused (so reload()
        re-executes into the same object).  A module created here is removed
        from sys.modules again if initialising it raises, as the PEP 302
        loader protocol requires; the exception is re-raised.
        """
        ispkg, code = self.modules[fullname]
        mod = sys.modules.get(fullname)
        created = mod is None
        if created:
            # Only build a new module when there is no existing one.
            mod = imp.new_module(fullname)
            sys.modules[fullname] = mod
        try:
            mod.__file__ = "<%s>" % self.__class__.__name__
            mod.__loader__ = self
            if ispkg:
                mod.__path__ = self._get__path__()
            # Tuple form of exec: valid in Python 2 and Python 3.
            exec(code, mod.__dict__)
        except BaseException:
            # Do not leave a half-initialised module behind.  A module that
            # existed before (reload) is left in place.
            if created:
                sys.modules.pop(fullname, None)
            raise
        return mod
69 | ||
70 | ||
class MetaImporter(TestImporter):
    """TestImporter variant used directly on sys.meta_path."""

    def _get__path__(self):
        # Package modules loaded this way get an empty __path__ list.
        return list()
74 | ||
class PathImporter(TestImporter):
    """TestImporter variant installed as a sys.path_hooks entry."""

    def _get__path__(self):
        # A package's __path__ is the single path entry this importer was
        # created for, so its submodules are looked up through that entry.
        return list((self.path,))
78 | ||
79 | ||
class ImportBlocker:
    """Meta-path hook that makes the named modules unimportable.

    Place an ImportBlocker instance on sys.meta_path and you can be sure
    the modules you specified can't be imported, even if they are builtins.
    """

    def __init__(self, *namestoblock):
        # Used as a set: only key membership is consulted.
        self.namestoblock = dict.fromkeys(namestoblock)

    def find_module(self, fullname, path=None):
        """Claim fullname if it is blocked so that load_module can refuse it."""
        if fullname in self.namestoblock:
            return self
        return None

    def load_module(self, fullname):
        """Always refuse to load: raise ImportError."""
        # Call form instead of the Python-2-only ``raise E, "msg"`` statement.
        raise ImportError("I dare you")
92 | ||
93 | ||
class ImpWrapper:
    """PEP 302 importer implemented on top of imp.find_module.

    Built without a path, it only handles top-level module names (meta_path
    use).  Built with a directory, it searches that directory (path_hooks
    use).
    """

    def __init__(self, path=None):
        # Path hooks must reject entries they cannot handle with ImportError.
        if path is None or os.path.isdir(path):
            self.path = path
        else:
            raise ImportError

    def find_module(self, fullname, path=None):
        """Return an ImpLoader for fullname, or None if it cannot be found.

        The caller-supplied path is ignored: the search uses self.path, or
        imp's default search when self.path is None.
        """
        basename = fullname.split(".")[-1]
        if self.path is None:
            # Without a directory of our own, only top-level names are ours.
            if basename != fullname:
                return None
            search_path = None
        else:
            search_path = [self.path]
        try:
            fileobj, filename, description = imp.find_module(basename,
                                                             search_path)
        except ImportError:
            return None
        return ImpLoader(fileobj, filename, description)
114 | ||
115 | ||
class ImpLoader:
    """PEP 302 loader wrapping the triple returned by imp.find_module."""

    def __init__(self, file, filename, stuff):
        # file may be None/false, which is why load_module checks it before
        # closing; filename and stuff are passed through to imp.load_module.
        self.file = file
        self.filename = filename
        self.stuff = stuff

    def load_module(self, fullname):
        """Load fullname via imp.load_module and tag it with this loader.

        The file handed over by imp.find_module is closed even when loading
        the module raises; the exception propagates to the caller.
        """
        try:
            mod = imp.load_module(fullname, self.file, self.filename,
                                  self.stuff)
        finally:
            # Previously the file leaked if load_module raised.
            if self.file:
                self.file.close()
        mod.__loader__ = self  # for introspection
        return mod
129 | ||
130 | ||
class ImportHooksBaseTestCase(unittest.TestCase):
    """Base class that undoes each test's changes to the import state.

    setUp snapshots sys.path, sys.meta_path, sys.path_hooks and sys.modules
    and puts an ImportTracker at the front of sys.meta_path.  tearDown
    restores all of them and clears sys.path_importer_cache.
    """

    def setUp(self):
        self.path = sys.path[:]
        self.meta_path = sys.meta_path[:]
        self.path_hooks = sys.path_hooks[:]
        # Shallow copy: records which module object each name mapped to, so
        # modules a test deletes or replaces can be put back afterwards.
        self.modules_before = sys.modules.copy()
        sys.path_importer_cache.clear()
        self.tracker = ImportTracker()
        sys.meta_path.insert(0, self.tracker)

    def tearDown(self):
        sys.path[:] = self.path
        sys.meta_path[:] = self.meta_path
        sys.path_hooks[:] = self.path_hooks
        sys.path_importer_cache.clear()
        # Drop everything the test imported, which includes every new name
        # self.tracker recorded.  Previously only tracked names were deleted,
        # and pre-existing modules a test removed were never restored.
        for name in list(sys.modules):
            if name not in self.modules_before:
                del sys.modules[name]
        # Put back the original module objects, including ones the test
        # deleted or re-imported.
        sys.modules.update(self.modules_before)
149 | ||
150 | ||
class ImportHooksTestCase(ImportHooksBaseTestCase):
    """Exercise meta_path importers, path hooks, blockers and imp wrappers."""

    def doTestImports(self, importer=None):
        """Import the canned test modules and check names, loaders and reload.

        If importer is given, every imported module must report it as its
        __loader__.
        """
        import hooktestmodule
        import hooktestpackage
        import hooktestpackage.sub
        import hooktestpackage.sub.subber
        self.assertEqual(hooktestmodule.get_name(),
                         "hooktestmodule")
        self.assertEqual(hooktestpackage.get_name(),
                         "hooktestpackage")
        self.assertEqual(hooktestpackage.sub.get_name(),
                         "hooktestpackage.sub")
        self.assertEqual(hooktestpackage.sub.subber.get_name(),
                         "hooktestpackage.sub.subber")
        if importer is not None:
            self.assertEqual(hooktestmodule.__loader__, importer)
            self.assertEqual(hooktestpackage.__loader__, importer)
            self.assertEqual(hooktestpackage.sub.__loader__, importer)
            self.assertEqual(hooktestpackage.sub.subber.__loader__, importer)

        # TestImporter.modules is shared class state.  Restore the entry
        # afterwards so the change does not leak into later tests.
        saved_entry = TestImporter.modules['reloadmodule']
        try:
            TestImporter.modules['reloadmodule'] = (False, test_co)
            import reloadmodule
            self.assertFalse(hasattr(reloadmodule, 'reloaded'))

            # reload() must run the code the importer serves now.
            TestImporter.modules['reloadmodule'] = (False, reload_co)
            reload(reloadmodule)
            self.assertTrue(hasattr(reloadmodule, 'reloaded'))
        finally:
            TestImporter.modules['reloadmodule'] = saved_entry

    def testMetaPath(self):
        i = MetaImporter()
        sys.meta_path.append(i)
        self.doTestImports(i)

    def testPathHook(self):
        # PathImporter accepts only test_path, so only that entry uses it.
        sys.path_hooks.append(PathImporter)
        sys.path.append(test_path)
        self.doTestImports()

    def testBlocker(self):
        mname = "exceptions"  # an arbitrary harmless builtin module
        # Forget any cached copy so the import really reaches sys.meta_path.
        sys.modules.pop(mname, None)
        sys.meta_path.append(ImportBlocker(mname))
        try:
            __import__(mname)
        except ImportError:
            pass
        else:
            self.fail("'%s' was not supposed to be importable" % mname)

    def testImpWrapper(self):
        i = ImpWrapper()
        sys.meta_path.append(i)
        sys.path_hooks.append(ImpWrapper)
        mnames = ("colorsys", "urlparse", "distutils.core", "compiler.misc")
        for mname in mnames:
            parent = mname.split(".")[0]
            # Remove the top-level package and its submodules only.  The old
            # plain startswith(parent) also hit unrelated modules that merely
            # share the prefix.  list() allows deleting while iterating.
            for n in list(sys.modules):
                if n == parent or n.startswith(parent + "."):
                    del sys.modules[n]
        for mname in mnames:
            # A non-empty fromlist makes __import__ return the submodule.
            m = __import__(mname, globals(), locals(), ["__dummy__"])
            # ImpLoader.load_module sets __loader__.  Its presence shows a
            # PEP 302 loader, not the builtin machinery, handled the import.
            self.assertTrue(hasattr(m, "__loader__"),
                            "%s was not loaded by an import hook" % mname)
215 | ||
def test_main():
    """Run ImportHooksTestCase through test_support's regression runner."""
    cases = [ImportHooksTestCase]
    test_support.run_unittest(*cases)


if __name__ == "__main__":
    # Running this file directly executes the same suite.
    test_main()