Commit | Line | Data |
---|---|---|
920dae64 AT |
1 | import unittest, os |
2 | from test import test_support | |
3 | ||
4 | import warnings | |
5 | warnings.filterwarnings( | |
6 | "ignore", | |
7 | category=DeprecationWarning, | |
8 | message=".*complex divmod.*are deprecated" | |
9 | ) | |
10 | ||
11 | from random import random | |
12 | ||
13 | # These tests ensure that complex math does the right thing | |
14 | ||
15 | class ComplexTest(unittest.TestCase): | |
16 | ||
17 | def assertAlmostEqual(self, a, b): | |
18 | if isinstance(a, complex): | |
19 | if isinstance(b, complex): | |
20 | unittest.TestCase.assertAlmostEqual(self, a.real, b.real) | |
21 | unittest.TestCase.assertAlmostEqual(self, a.imag, b.imag) | |
22 | else: | |
23 | unittest.TestCase.assertAlmostEqual(self, a.real, b) | |
24 | unittest.TestCase.assertAlmostEqual(self, a.imag, 0.) | |
25 | else: | |
26 | if isinstance(b, complex): | |
27 | unittest.TestCase.assertAlmostEqual(self, a, b.real) | |
28 | unittest.TestCase.assertAlmostEqual(self, 0., b.imag) | |
29 | else: | |
30 | unittest.TestCase.assertAlmostEqual(self, a, b) | |
31 | ||
32 | def assertCloseAbs(self, x, y, eps=1e-9): | |
33 | """Return true iff floats x and y "are close\"""" | |
34 | # put the one with larger magnitude second | |
35 | if abs(x) > abs(y): | |
36 | x, y = y, x | |
37 | if y == 0: | |
38 | return abs(x) < eps | |
39 | if x == 0: | |
40 | return abs(y) < eps | |
41 | # check that relative difference < eps | |
42 | self.assert_(abs((x-y)/y) < eps) | |
43 | ||
44 | def assertClose(self, x, y, eps=1e-9): | |
45 | """Return true iff complexes x and y "are close\"""" | |
46 | self.assertCloseAbs(x.real, y.real, eps) | |
47 | self.assertCloseAbs(x.imag, y.imag, eps) | |
48 | ||
49 | def assertIs(self, a, b): | |
50 | self.assert_(a is b) | |
51 | ||
52 | def check_div(self, x, y): | |
53 | """Compute complex z=x*y, and check that z/x==y and z/y==x.""" | |
54 | z = x * y | |
55 | if x != 0: | |
56 | q = z / x | |
57 | self.assertClose(q, y) | |
58 | q = z.__div__(x) | |
59 | self.assertClose(q, y) | |
60 | q = z.__truediv__(x) | |
61 | self.assertClose(q, y) | |
62 | if y != 0: | |
63 | q = z / y | |
64 | self.assertClose(q, x) | |
65 | q = z.__div__(y) | |
66 | self.assertClose(q, x) | |
67 | q = z.__truediv__(y) | |
68 | self.assertClose(q, x) | |
69 | ||
70 | def test_div(self): | |
71 | simple_real = [float(i) for i in xrange(-5, 6)] | |
72 | simple_complex = [complex(x, y) for x in simple_real for y in simple_real] | |
73 | for x in simple_complex: | |
74 | for y in simple_complex: | |
75 | self.check_div(x, y) | |
76 | ||
77 | # A naive complex division algorithm (such as in 2.0) is very prone to | |
78 | # nonsense errors for these (overflows and underflows). | |
79 | self.check_div(complex(1e200, 1e200), 1+0j) | |
80 | self.check_div(complex(1e-200, 1e-200), 1+0j) | |
81 | ||
82 | # Just for fun. | |
83 | for i in xrange(100): | |
84 | self.check_div(complex(random(), random()), | |
85 | complex(random(), random())) | |
86 | ||
87 | self.assertRaises(ZeroDivisionError, complex.__div__, 1+1j, 0+0j) | |
88 | # FIXME: The following currently crashes on Alpha | |
89 | # self.assertRaises(OverflowError, pow, 1e200+1j, 1e200+1j) | |
90 | ||
91 | def test_truediv(self): | |
92 | self.assertAlmostEqual(complex.__truediv__(2+0j, 1+1j), 1-1j) | |
93 | self.assertRaises(ZeroDivisionError, complex.__truediv__, 1+1j, 0+0j) | |
94 | ||
95 | def test_floordiv(self): | |
96 | self.assertAlmostEqual(complex.__floordiv__(3+0j, 1.5+0j), 2) | |
97 | self.assertRaises(ZeroDivisionError, complex.__floordiv__, 3+0j, 0+0j) | |
98 | ||
99 | def test_coerce(self): | |
100 | self.assertRaises(OverflowError, complex.__coerce__, 1+1j, 1L<<10000) | |
101 | ||
102 | def test_richcompare(self): | |
103 | self.assertRaises(OverflowError, complex.__eq__, 1+1j, 1L<<10000) | |
104 | self.assertEqual(complex.__lt__(1+1j, None), NotImplemented) | |
105 | self.assertIs(complex.__eq__(1+1j, 1+1j), True) | |
106 | self.assertIs(complex.__eq__(1+1j, 2+2j), False) | |
107 | self.assertIs(complex.__ne__(1+1j, 1+1j), False) | |
108 | self.assertIs(complex.__ne__(1+1j, 2+2j), True) | |
109 | self.assertRaises(TypeError, complex.__lt__, 1+1j, 2+2j) | |
110 | self.assertRaises(TypeError, complex.__le__, 1+1j, 2+2j) | |
111 | self.assertRaises(TypeError, complex.__gt__, 1+1j, 2+2j) | |
112 | self.assertRaises(TypeError, complex.__ge__, 1+1j, 2+2j) | |
113 | ||
114 | def test_mod(self): | |
115 | self.assertRaises(ZeroDivisionError, (1+1j).__mod__, 0+0j) | |
116 | ||
117 | a = 3.33+4.43j | |
118 | try: | |
119 | a % 0 | |
120 | except ZeroDivisionError: | |
121 | pass | |
122 | else: | |
123 | self.fail("modulo parama can't be 0") | |
124 | ||
125 | def test_divmod(self): | |
126 | self.assertRaises(ZeroDivisionError, divmod, 1+1j, 0+0j) | |
127 | ||
128 | def test_pow(self): | |
129 | self.assertAlmostEqual(pow(1+1j, 0+0j), 1.0) | |
130 | self.assertAlmostEqual(pow(0+0j, 2+0j), 0.0) | |
131 | self.assertRaises(ZeroDivisionError, pow, 0+0j, 1j) | |
132 | self.assertAlmostEqual(pow(1j, -1), 1/1j) | |
133 | self.assertAlmostEqual(pow(1j, 200), 1) | |
134 | self.assertRaises(ValueError, pow, 1+1j, 1+1j, 1+1j) | |
135 | ||
136 | a = 3.33+4.43j | |
137 | self.assertEqual(a ** 0j, 1) | |
138 | self.assertEqual(a ** 0.+0.j, 1) | |
139 | ||
140 | self.assertEqual(3j ** 0j, 1) | |
141 | self.assertEqual(3j ** 0, 1) | |
142 | ||
143 | try: | |
144 | 0j ** a | |
145 | except ZeroDivisionError: | |
146 | pass | |
147 | else: | |
148 | self.fail("should fail 0.0 to negative or complex power") | |
149 | ||
150 | try: | |
151 | 0j ** (3-2j) | |
152 | except ZeroDivisionError: | |
153 | pass | |
154 | else: | |
155 | self.fail("should fail 0.0 to negative or complex power") | |
156 | ||
157 | # The following is used to exercise certain code paths | |
158 | self.assertEqual(a ** 105, a ** 105) | |
159 | self.assertEqual(a ** -105, a ** -105) | |
160 | self.assertEqual(a ** -30, a ** -30) | |
161 | ||
162 | self.assertEqual(0.0j ** 0, 1) | |
163 | ||
164 | b = 5.1+2.3j | |
165 | self.assertRaises(ValueError, pow, a, b, 0) | |
166 | ||
167 | def test_boolcontext(self): | |
168 | for i in xrange(100): | |
169 | self.assert_(complex(random() + 1e-6, random() + 1e-6)) | |
170 | self.assert_(not complex(0.0, 0.0)) | |
171 | ||
172 | def test_conjugate(self): | |
173 | self.assertClose(complex(5.3, 9.8).conjugate(), 5.3-9.8j) | |
174 | ||
175 | def test_constructor(self): | |
176 | class OS: | |
177 | def __init__(self, value): self.value = value | |
178 | def __complex__(self): return self.value | |
179 | class NS(object): | |
180 | def __init__(self, value): self.value = value | |
181 | def __complex__(self): return self.value | |
182 | self.assertEqual(complex(OS(1+10j)), 1+10j) | |
183 | self.assertEqual(complex(NS(1+10j)), 1+10j) | |
184 | self.assertRaises(TypeError, complex, OS(None)) | |
185 | self.assertRaises(TypeError, complex, NS(None)) | |
186 | ||
187 | self.assertAlmostEqual(complex("1+10j"), 1+10j) | |
188 | self.assertAlmostEqual(complex(10), 10+0j) | |
189 | self.assertAlmostEqual(complex(10.0), 10+0j) | |
190 | self.assertAlmostEqual(complex(10L), 10+0j) | |
191 | self.assertAlmostEqual(complex(10+0j), 10+0j) | |
192 | self.assertAlmostEqual(complex(1,10), 1+10j) | |
193 | self.assertAlmostEqual(complex(1,10L), 1+10j) | |
194 | self.assertAlmostEqual(complex(1,10.0), 1+10j) | |
195 | self.assertAlmostEqual(complex(1L,10), 1+10j) | |
196 | self.assertAlmostEqual(complex(1L,10L), 1+10j) | |
197 | self.assertAlmostEqual(complex(1L,10.0), 1+10j) | |
198 | self.assertAlmostEqual(complex(1.0,10), 1+10j) | |
199 | self.assertAlmostEqual(complex(1.0,10L), 1+10j) | |
200 | self.assertAlmostEqual(complex(1.0,10.0), 1+10j) | |
201 | self.assertAlmostEqual(complex(3.14+0j), 3.14+0j) | |
202 | self.assertAlmostEqual(complex(3.14), 3.14+0j) | |
203 | self.assertAlmostEqual(complex(314), 314.0+0j) | |
204 | self.assertAlmostEqual(complex(314L), 314.0+0j) | |
205 | self.assertAlmostEqual(complex(3.14+0j, 0j), 3.14+0j) | |
206 | self.assertAlmostEqual(complex(3.14, 0.0), 3.14+0j) | |
207 | self.assertAlmostEqual(complex(314, 0), 314.0+0j) | |
208 | self.assertAlmostEqual(complex(314L, 0L), 314.0+0j) | |
209 | self.assertAlmostEqual(complex(0j, 3.14j), -3.14+0j) | |
210 | self.assertAlmostEqual(complex(0.0, 3.14j), -3.14+0j) | |
211 | self.assertAlmostEqual(complex(0j, 3.14), 3.14j) | |
212 | self.assertAlmostEqual(complex(0.0, 3.14), 3.14j) | |
213 | self.assertAlmostEqual(complex("1"), 1+0j) | |
214 | self.assertAlmostEqual(complex("1j"), 1j) | |
215 | self.assertAlmostEqual(complex(), 0) | |
216 | self.assertAlmostEqual(complex("-1"), -1) | |
217 | self.assertAlmostEqual(complex("+1"), +1) | |
218 | ||
219 | class complex2(complex): pass | |
220 | self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j) | |
221 | self.assertAlmostEqual(complex(real=17, imag=23), 17+23j) | |
222 | self.assertAlmostEqual(complex(real=17+23j), 17+23j) | |
223 | self.assertAlmostEqual(complex(real=17+23j, imag=23), 17+46j) | |
224 | self.assertAlmostEqual(complex(real=1+2j, imag=3+4j), -3+5j) | |
225 | ||
226 | c = 3.14 + 1j | |
227 | self.assert_(complex(c) is c) | |
228 | del c | |
229 | ||
230 | self.assertRaises(TypeError, complex, "1", "1") | |
231 | self.assertRaises(TypeError, complex, 1, "1") | |
232 | ||
233 | self.assertEqual(complex(" 3.14+J "), 3.14+1j) | |
234 | if test_support.have_unicode: | |
235 | self.assertEqual(complex(unicode(" 3.14+J ")), 3.14+1j) | |
236 | ||
237 | # SF bug 543840: complex(string) accepts strings with \0 | |
238 | # Fixed in 2.3. | |
239 | self.assertRaises(ValueError, complex, '1+1j\0j') | |
240 | ||
241 | self.assertRaises(TypeError, int, 5+3j) | |
242 | self.assertRaises(TypeError, long, 5+3j) | |
243 | self.assertRaises(TypeError, float, 5+3j) | |
244 | self.assertRaises(ValueError, complex, "") | |
245 | self.assertRaises(TypeError, complex, None) | |
246 | self.assertRaises(ValueError, complex, "\0") | |
247 | self.assertRaises(TypeError, complex, "1", "2") | |
248 | self.assertRaises(TypeError, complex, "1", 42) | |
249 | self.assertRaises(TypeError, complex, 1, "2") | |
250 | self.assertRaises(ValueError, complex, "1+") | |
251 | self.assertRaises(ValueError, complex, "1+1j+1j") | |
252 | self.assertRaises(ValueError, complex, "--") | |
253 | if test_support.have_unicode: | |
254 | self.assertRaises(ValueError, complex, unicode("1"*500)) | |
255 | self.assertRaises(ValueError, complex, unicode("x")) | |
256 | ||
257 | class EvilExc(Exception): | |
258 | pass | |
259 | ||
260 | class evilcomplex: | |
261 | def __complex__(self): | |
262 | raise EvilExc | |
263 | ||
264 | self.assertRaises(EvilExc, complex, evilcomplex()) | |
265 | ||
266 | class float2: | |
267 | def __init__(self, value): | |
268 | self.value = value | |
269 | def __float__(self): | |
270 | return self.value | |
271 | ||
272 | self.assertAlmostEqual(complex(float2(42.)), 42) | |
273 | self.assertAlmostEqual(complex(real=float2(17.), imag=float2(23.)), 17+23j) | |
274 | self.assertRaises(TypeError, complex, float2(None)) | |
275 | ||
276 | def test_hash(self): | |
277 | for x in xrange(-30, 30): | |
278 | self.assertEqual(hash(x), hash(complex(x, 0))) | |
279 | x /= 3.0 # now check against floating point | |
280 | self.assertEqual(hash(x), hash(complex(x, 0.))) | |
281 | ||
282 | def test_abs(self): | |
283 | nums = [complex(x/3., y/7.) for x in xrange(-9,9) for y in xrange(-9,9)] | |
284 | for num in nums: | |
285 | self.assertAlmostEqual((num.real**2 + num.imag**2) ** 0.5, abs(num)) | |
286 | ||
287 | def test_repr(self): | |
288 | self.assertEqual(repr(1+6j), '(1+6j)') | |
289 | self.assertEqual(repr(1-6j), '(1-6j)') | |
290 | ||
291 | self.assertNotEqual(repr(-(1+0j)), '(-1+-0j)') | |
292 | ||
293 | def test_neg(self): | |
294 | self.assertEqual(-(1+6j), -1-6j) | |
295 | ||
296 | def test_file(self): | |
297 | a = 3.33+4.43j | |
298 | b = 5.1+2.3j | |
299 | ||
300 | fo = None | |
301 | try: | |
302 | fo = open(test_support.TESTFN, "wb") | |
303 | print >>fo, a, b | |
304 | fo.close() | |
305 | fo = open(test_support.TESTFN, "rb") | |
306 | self.assertEqual(fo.read(), "%s %s\n" % (a, b)) | |
307 | finally: | |
308 | if (fo is not None) and (not fo.closed): | |
309 | fo.close() | |
310 | try: | |
311 | os.remove(test_support.TESTFN) | |
312 | except (OSError, IOError): | |
313 | pass | |
314 | ||
def test_main():
    """Run all ComplexTest cases through the regression-test runner."""
    case = ComplexTest
    test_support.run_unittest(case)


if __name__ == "__main__":
    test_main()