Commit | Line | Data |
---|---|---|
86530b38 AT |
1 | # Tests for rich comparisons |
2 | ||
3 | import unittest | |
4 | from test import test_support | |
5 | ||
6 | import operator | |
7 | ||
class Number:
    """Wrapper around a value that supports only the six rich comparisons.

    Every rich comparison method compares the wrapped value ``self.x`` with
    *other* using the matching operator and returns that result unchanged.
    ``__cmp__`` raises ``test_support.TestFailed`` so that a test notices if
    the interpreter calls it instead of the rich comparison methods.
    """

    def __init__(self, x):
        # x: the wrapped value, stored as-is
        self.x = x

    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other

    def __cmp__(self, other):
        # Call form of raise: same behavior as the old "raise E, msg" form.
        raise test_support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x, )
36 | ||
class Vector:
    """Sequence wrapper whose rich comparisons return item-wise results.

    Comparing a Vector with another Vector, or with any object that supports
    ``len()`` and iteration, returns a new Vector holding the result of
    comparing the items pairwise.  Operands of different length raise
    ValueError.  Hashing, truth testing (``__nonzero__``) and ``__cmp__``
    all raise, so a test notices if any of them is used.
    """

    def __init__(self, data):
        # data: the underlying sequence; stored without copying
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __hash__(self):
        raise TypeError("Vectors cannot be hashed")

    def __nonzero__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise test_support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data, )

    def __lt__(self, other):
        return Vector([a < b for a, b in zip(self.data, self.__cast(other))])

    def __le__(self, other):
        return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])

    def __eq__(self, other):
        return Vector([a == b for a, b in zip(self.data, self.__cast(other))])

    def __ne__(self, other):
        return Vector([a != b for a, b in zip(self.data, self.__cast(other))])

    def __gt__(self, other):
        return Vector([a > b for a, b in zip(self.data, self.__cast(other))])

    def __ge__(self, other):
        return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])

    def __cast(self, other):
        """Return the plain sequence to compare against.

        A Vector operand is unwrapped to its ``data``; anything else is
        returned as given.  Raises ValueError when the lengths differ.
        """
        if isinstance(other, Vector):
            other = other.data
        if len(self.data) != len(other):
            raise ValueError("Cannot compare vectors of different length")
        return other
87 | ||
# Maps each comparison name to three callables that perform it: a lambda
# using the operator syntax, the operator-module function, and the
# operator-module function under its double-underscore name.
opmap = {
    "lt": (lambda a, b: a < b, operator.lt, operator.__lt__),
    "le": (lambda a, b: a <= b, operator.le, operator.__le__),
    "eq": (lambda a, b: a == b, operator.eq, operator.__eq__),
    "ne": (lambda a, b: a != b, operator.ne, operator.__ne__),
    "gt": (lambda a, b: a > b, operator.gt, operator.__gt__),
    "ge": (lambda a, b: a >= b, operator.ge, operator.__ge__),
}
96 | ||
class VectorTest(unittest.TestCase):
    """Comparisons whose result is a Vector of item-wise outcomes."""

    def checkfail(self, error, opname, *args):
        """Assert that every callable for *opname* raises *error* on *args*."""
        for compare in opmap[opname]:
            self.assertRaises(error, compare, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert that every callable for *opname* maps (a, b) to *expres*.

        The result is checked for length and then item by item.
        """
        for compare in opmap[opname]:
            outcome = compare(a, b)
            # can't use assertEqual(outcome, expres) here
            self.assertEqual(len(outcome), len(expres))
            # results are bool, so we can use "is" for each item
            for index in xrange(len(outcome)):
                self.assert_(outcome[index] is expres[index])

    def test_mixed(self):
        # check that comparisons involving Vector objects
        # which return rich results (i.e. Vectors with itemwise
        # comparison results) work
        short = Vector(range(2))
        longer = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, short, longer)

        a = range(5)
        b = 5 * [2]
        # expected item-wise outcome of comparing range(5) with five 2s
        expected = (
            ("lt", [True, True, False, False, False]),
            ("le", [True, True, True, False, False]),
            ("eq", [False, False, True, False, False]),
            ("ne", [True, True, False, True, True]),
            ("gt", [False, False, False, True, True]),
            ("ge", [False, False, True, True, True]),
        )
        # try mixed arguments (but not (a, b) as that won't return a bool vector)
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for lhs, rhs in args:
            for opname, expres in expected:
                self.checkequal(opname, lhs, rhs, expres)

            # NOTE(review): nesting assumed to be once per argument pair -
            # confirm against the canonical file.
            for ops in opmap.itervalues():
                for compare in ops:
                    # calls __nonzero__, which should fail
                    self.assertRaises(TypeError, bool, compare(lhs, rhs))
138 | ||
class NumberTest(unittest.TestCase):
    """Comparisons between Number wrappers and plain ints."""

    def test_basic(self):
        # Check that comparisons involving Number objects
        # give the same results as comparing the
        # corresponding ints
        kinds = (int, Number)
        for a in xrange(3):
            for b in xrange(3):
                for typea in kinds:
                    for typeb in kinds:
                        if typea is int and typeb is int:
                            # the combination int, int is useless
                            continue
                        ta = typea(a)
                        tb = typeb(b)
                        for ops in opmap.itervalues():
                            for compare in ops:
                                realoutcome = compare(a, b)
                                testoutcome = compare(ta, tb)
                                self.assertEqual(realoutcome, testoutcome)

    def checkvalue(self, opname, a, b, expres):
        """Assert that *opname* on every int/Number mix of (a, b) is *expres*.

        A result carrying an ``x`` attribute is replaced by that attribute
        before the identity check.
        """
        for typea in (int, Number):
            for typeb in (int, Number):
                ta = typea(a)
                tb = typeb(b)
                for compare in opmap[opname]:
                    outcome = compare(ta, tb)
                    outcome = getattr(outcome, "x", outcome)
                    self.assert_(outcome is expres)

    def test_values(self):
        # check all operators and all comparison results
        opnames = ("lt", "le", "eq", "ne", "gt", "ge")
        cases = (
            (0, 0, (False, True, True, False, False, True)),
            (0, 1, (True, True, False, True, False, False)),
            (1, 0, (False, False, False, True, True, True)),
        )
        for a, b, outcomes in cases:
            for opname, expres in zip(opnames, outcomes):
                self.checkvalue(opname, a, b, expres)
191 | ||
class MiscTest(unittest.TestCase):
    """Assorted corner cases of comparison and truth testing."""

    def test_misbehavin(self):
        # Misb answers <, > and == with 0.  The other rich comparisons must
        # not be reached, and cmp() is expected to raise the RuntimeError
        # coming from __cmp__.
        class Misb:
            def __lt__(self, other): return 0
            def __gt__(self, other): return 0
            def __eq__(self, other): return 0
            # TestFailed is only reachable through test_support; a bare
            # "TestFailed" would be a NameError here.
            def __le__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __ge__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __ne__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __cmp__(self, other): raise RuntimeError("expected")
        a = Misb()
        b = Misb()
        self.assertEqual(a<b, 0)
        self.assertEqual(a==b, 0)
        self.assertEqual(a>b, 0)
        self.assertRaises(RuntimeError, cmp, a, b)

    def test_not(self):
        # Check that exceptions in __nonzero__ are properly
        # propagated by the not operator
        class Exc:
            pass
        class Bad:
            def __nonzero__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully
        from UserList import UserList
        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assert_(not (a == b))
        self.assert_(a != b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)
        a.append(17)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        # distinct first items let the comparison finish without recursing
        a.insert(0, 11)
        b.insert(0, 12)
        self.assert_(not (a == b))
        self.assert_(a != b)
        self.assert_(a < b)
257 | ||
class DictTest(unittest.TestCase):
    """Equality and ordering of dicts with complex keys and values."""

    def test_dicts(self):
        # Verify that __eq__ and __ne__ work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__. Complex
        # numbers are a fine example of that.
        import random
        imag1a = {}
        for _ in range(50):
            value = random.randrange(100) * 1j
            key = random.randrange(100) * 1j
            imag1a[key] = value
        # rebuild the same mapping with a shuffled insertion order
        shuffled = imag1a.items()
        random.shuffle(shuffled)
        imag1b = {}
        for key, value in shuffled:
            imag1b[key] = value
        # imag2 differs from imag1b only in the value of the last key inserted
        imag2 = imag1b.copy()
        imag2[key] = value + 1.0
        self.assert_(imag1a == imag1a)
        self.assert_(imag1a == imag1b)
        self.assert_(imag2 == imag2)
        self.assert_(imag1a != imag2)
        for opname in ("lt", "le", "gt", "ge"):
            for compare in opmap[opname]:
                self.assertRaises(TypeError, compare, imag1a, imag2)
282 | ||
class ListTest(unittest.TestCase):
    """Rich comparisons between lists."""

    def assertIs(self, a, b):
        """Assert that *a* and *b* are the very same object."""
        identical = a is b
        self.assert_(identical)

    def test_coverage(self):
        # exercise all comparisons for lists
        x = [42]
        same_list = (
            (x < x, False),
            (x <= x, True),
            (x == x, True),
            (x != x, False),
            (x > x, False),
            (x >= x, True),
        )
        for actual, wanted in same_list:
            self.assertIs(actual, wanted)
        y = [42, 42]
        shorter_vs_longer = (
            (x < y, True),
            (x <= y, True),
            (x == y, False),
            (x != y, True),
            (x > y, False),
            (x >= y, False),
        )
        for actual, wanted in shorter_vs_longer:
            self.assertIs(actual, wanted)

    def test_badentry(self):
        # make sure that exceptions for item comparison are properly
        # propagated in list comparisons
        class Exc:
            pass
        class Bad:
            def __eq__(self, other):
                raise Exc

        left = [Bad()]
        right = [Bad()]

        for compare in opmap["eq"]:
            self.assertRaises(Exc, compare, left, right)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare()
        class Good:
            def __lt__(self, other):
                return True

        left = [Good()]
        right = [Good()]

        for compare in opmap["lt"]:
            self.assertIs(compare(left, right), True)
332 | ||
def test_main():
    """Run every test case class of this module via test_support."""
    cases = (VectorTest, NumberTest, MiscTest, DictTest, ListTest)
    test_support.run_unittest(*cases)


if __name__ == "__main__":
    test_main()