Commit | Line | Data |
---|---|---|
920dae64 AT |
1 | """Unittests for heapq.""" |
2 | ||
3 | from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest | |
4 | import random | |
5 | import unittest | |
6 | from test import test_support | |
7 | import sys | |
8 | ||
9 | ||
def heapiter(heap):
    """Destructively iterate over *heap*, yielding its items smallest-first.

    Each step removes the current minimum with heappop; iteration ends as
    soon as heappop raises IndexError because the heap is empty.
    """
    pop = heappop
    while True:
        try:
            smallest = pop(heap)
        except IndexError:
            return
        yield smallest
17 | ||
18 | class TestHeap(unittest.TestCase): | |
19 | ||
20 | def test_push_pop(self): | |
21 | # 1) Push 256 random numbers and pop them off, verifying all's OK. | |
22 | heap = [] | |
23 | data = [] | |
24 | self.check_invariant(heap) | |
25 | for i in range(256): | |
26 | item = random.random() | |
27 | data.append(item) | |
28 | heappush(heap, item) | |
29 | self.check_invariant(heap) | |
30 | results = [] | |
31 | while heap: | |
32 | item = heappop(heap) | |
33 | self.check_invariant(heap) | |
34 | results.append(item) | |
35 | data_sorted = data[:] | |
36 | data_sorted.sort() | |
37 | self.assertEqual(data_sorted, results) | |
38 | # 2) Check that the invariant holds for a sorted array | |
39 | self.check_invariant(results) | |
40 | ||
41 | self.assertRaises(TypeError, heappush, []) | |
42 | try: | |
43 | self.assertRaises(TypeError, heappush, None, None) | |
44 | self.assertRaises(TypeError, heappop, None) | |
45 | except AttributeError: | |
46 | pass | |
47 | ||
48 | def check_invariant(self, heap): | |
49 | # Check the heap invariant. | |
50 | for pos, item in enumerate(heap): | |
51 | if pos: # pos 0 has no parent | |
52 | parentpos = (pos-1) >> 1 | |
53 | self.assert_(heap[parentpos] <= item) | |
54 | ||
55 | def test_heapify(self): | |
56 | for size in range(30): | |
57 | heap = [random.random() for dummy in range(size)] | |
58 | heapify(heap) | |
59 | self.check_invariant(heap) | |
60 | ||
61 | self.assertRaises(TypeError, heapify, None) | |
62 | ||
63 | def test_naive_nbest(self): | |
64 | data = [random.randrange(2000) for i in range(1000)] | |
65 | heap = [] | |
66 | for item in data: | |
67 | heappush(heap, item) | |
68 | if len(heap) > 10: | |
69 | heappop(heap) | |
70 | heap.sort() | |
71 | self.assertEqual(heap, sorted(data)[-10:]) | |
72 | ||
73 | def test_nbest(self): | |
74 | # Less-naive "N-best" algorithm, much faster (if len(data) is big | |
75 | # enough <wink>) than sorting all of data. However, if we had a max | |
76 | # heap instead of a min heap, it could go faster still via | |
77 | # heapify'ing all of data (linear time), then doing 10 heappops | |
78 | # (10 log-time steps). | |
79 | data = [random.randrange(2000) for i in range(1000)] | |
80 | heap = data[:10] | |
81 | heapify(heap) | |
82 | for item in data[10:]: | |
83 | if item > heap[0]: # this gets rarer the longer we run | |
84 | heapreplace(heap, item) | |
85 | self.assertEqual(list(heapiter(heap)), sorted(data)[-10:]) | |
86 | ||
87 | self.assertRaises(TypeError, heapreplace, None) | |
88 | self.assertRaises(TypeError, heapreplace, None, None) | |
89 | self.assertRaises(IndexError, heapreplace, [], None) | |
90 | ||
91 | def test_heapsort(self): | |
92 | # Exercise everything with repeated heapsort checks | |
93 | for trial in xrange(100): | |
94 | size = random.randrange(50) | |
95 | data = [random.randrange(25) for i in range(size)] | |
96 | if trial & 1: # Half of the time, use heapify | |
97 | heap = data[:] | |
98 | heapify(heap) | |
99 | else: # The rest of the time, use heappush | |
100 | heap = [] | |
101 | for item in data: | |
102 | heappush(heap, item) | |
103 | heap_sorted = [heappop(heap) for i in range(size)] | |
104 | self.assertEqual(heap_sorted, sorted(data)) | |
105 | ||
106 | def test_nsmallest(self): | |
107 | data = [random.randrange(2000) for i in range(1000)] | |
108 | for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): | |
109 | self.assertEqual(nsmallest(n, data), sorted(data)[:n]) | |
110 | ||
111 | def test_largest(self): | |
112 | data = [random.randrange(2000) for i in range(1000)] | |
113 | for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): | |
114 | self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n]) | |
115 | ||
116 | ||
117 | #============================================================================== | |
118 | ||
class LenOnly:
    "Stand-in sequence: len() works (always 10) but indexing does not."
    def __len__(self):
        length = 10
        return length
123 | ||
class GetOnly:
    "Stand-in sequence: indexing works (always 10) but len() does not."
    def __getitem__(self, ndx):
        # Every index, of any kind, maps to the same constant.
        constant = 10
        return constant
128 | ||
class CmpErr:
    "Element whose three-way comparison (__cmp__) always raises ZeroDivisionError"
    def __cmp__(self, other):
        raise ZeroDivisionError()
133 | ||
def R(seqn):
    'Plain generator function re-yielding each element of seqn'
    for element in iter(seqn):
        yield element
138 | ||
class G:
    'Wrapper exposing seqn only through __getitem__ indexing'
    def __init__(self, seqn):
        self.seqn = seqn
    def __getitem__(self, i):
        element = self.seqn[i]
        return element
145 | ||
class I:
    'Self-iterator over seqn built from __iter__ plus an explicit next() method'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self
    def next(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            # Advance only after the lookup has succeeded.
            self.i = pos + 1
            return value
        raise StopIteration
158 | ||
class Ig:
    'Iterable whose __iter__ is a generator walking over seqn'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        source = self.seqn
        for element in source:
            yield element
167 | ||
class X:
    'Has a next() method but neither __iter__ nor __getitem__, so not iterable'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def next(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            self.i = pos + 1
            return value
        raise StopIteration
178 | ||
class N:
    'Claims to be its own iterator (__iter__ returns self) yet defines no next()'
    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn
    def __iter__(self):
        return self
186 | ||
class E:
    'Iterator whose next() raises ZeroDivisionError, to test exception propagation'
    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn
    def __iter__(self):
        return self
    def next(self):
        # The division always fails, so nothing is ever returned.
        return 3 // 0
196 | ||
class S:
    'Iterator that is exhausted from the start (next() always raises StopIteration)'
    def __init__(self, seqn):
        # The input sequence is intentionally ignored and not stored.
        pass
    def __iter__(self):
        return self
    def next(self):
        raise StopIteration
205 | ||
206 | from itertools import chain, imap | |
def L(seqn):
    'Test multiple tiers of iterators'
    # Stack the __getitem__ wrapper, the generator-based iterable and a plain
    # generator, then pass the result through imap (identity) and chain.
    tiers = R(Ig(G(seqn)))
    return chain(imap(lambda x: x, tiers))
210 | ||
211 | class TestErrorHandling(unittest.TestCase): | |
212 | ||
213 | def test_non_sequence(self): | |
214 | for f in (heapify, heappop): | |
215 | self.assertRaises(TypeError, f, 10) | |
216 | for f in (heappush, heapreplace, nlargest, nsmallest): | |
217 | self.assertRaises(TypeError, f, 10, 10) | |
218 | ||
219 | def test_len_only(self): | |
220 | for f in (heapify, heappop): | |
221 | self.assertRaises(TypeError, f, LenOnly()) | |
222 | for f in (heappush, heapreplace): | |
223 | self.assertRaises(TypeError, f, LenOnly(), 10) | |
224 | for f in (nlargest, nsmallest): | |
225 | self.assertRaises(TypeError, f, 2, LenOnly()) | |
226 | ||
227 | def test_get_only(self): | |
228 | for f in (heapify, heappop): | |
229 | self.assertRaises(TypeError, f, GetOnly()) | |
230 | for f in (heappush, heapreplace): | |
231 | self.assertRaises(TypeError, f, GetOnly(), 10) | |
232 | for f in (nlargest, nsmallest): | |
233 | self.assertRaises(TypeError, f, 2, GetOnly()) | |
234 | ||
235 | def test_get_only(self): | |
236 | seq = [CmpErr(), CmpErr(), CmpErr()] | |
237 | for f in (heapify, heappop): | |
238 | self.assertRaises(ZeroDivisionError, f, seq) | |
239 | for f in (heappush, heapreplace): | |
240 | self.assertRaises(ZeroDivisionError, f, seq, 10) | |
241 | for f in (nlargest, nsmallest): | |
242 | self.assertRaises(ZeroDivisionError, f, 2, seq) | |
243 | ||
244 | def test_arg_parsing(self): | |
245 | for f in (heapify, heappop, heappush, heapreplace, nlargest, nsmallest): | |
246 | self.assertRaises(TypeError, f, 10) | |
247 | ||
248 | def test_iterable_args(self): | |
249 | for f in (nlargest, nsmallest): | |
250 | for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): | |
251 | for g in (G, I, Ig, L, R): | |
252 | self.assertEqual(f(2, g(s)), f(2,s)) | |
253 | self.assertEqual(f(2, S(s)), []) | |
254 | self.assertRaises(TypeError, f, 2, X(s)) | |
255 | self.assertRaises(TypeError, f, 2, N(s)) | |
256 | self.assertRaises(ZeroDivisionError, f, 2, E(s)) | |
257 | ||
258 | #============================================================================== | |
259 | ||
260 | ||
def test_main(verbose=None):
    from types import BuiltinFunctionType

    suites = [TestHeap]
    # TestErrorHandling is only added when heapify is a builtin function
    # (presumably the C-accelerated heapq); otherwise only TestHeap runs.
    if isinstance(heapify, BuiltinFunctionType):
        suites.append(TestErrorHandling)
    test_support.run_unittest(*suites)

    # Optional refcount report: needs verbose mode and a build that provides
    # sys.gettotalrefcount.  The loop below keeps the original shape, since
    # its allocations influence the totals that get printed.
    if not verbose or not hasattr(sys, "gettotalrefcount"):
        return
    import gc
    counts = [None] * 5
    for run in xrange(len(counts)):
        test_support.run_unittest(*suites)
        gc.collect()
        counts[run] = sys.gettotalrefcount()
    print(counts)
278 | ||
if __name__ == "__main__":
    # Direct execution passes verbose=True, which in test_main enables the
    # optional refcount report on builds that provide sys.gettotalrefcount.
    test_main(verbose=True)