Commit | Line | Data |
---|---|---|
920dae64 AT |
1 | import unittest |
2 | ||
3 | from test import test_support | |
4 | ||
class G:
    'Sequence using __getitem__'

    # Only the legacy sequence protocol is provided: there is no __iter__,
    # so iteration calls __getitem__ with 0, 1, 2, ... until the wrapped
    # sequence signals the end by raising.

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        wrapped = self.seqn
        return wrapped[i]
11 | ||
class I:
    'Sequence using iterator protocol'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0  # index of the next element to hand out

    def __iter__(self):
        # The object is its own (single-pass) iterator.
        return self

    def next(self):
        # Python 2 style iterator step over the wrapped sequence.
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        value = self.seqn[pos]
        self.i = pos + 1
        return value
24 | ||
class Ig:
    'Sequence using iterator protocol defined with a generator'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0  # set for parity with the other helpers; never read here

    def __iter__(self):
        # Generator function: every iter() call starts a fresh pass.
        for item in self.seqn:
            yield item
33 | ||
class X:
    'Missing __getitem__ and __iter__'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    # Only next() is defined. With neither __iter__ nor __getitem__ the
    # object is not iterable, which is exactly what test_noniterable needs.
    def next(self):
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        value = self.seqn[pos]
        self.i = pos + 1
        return value
44 | ||
class E:
    'Test propagation of exceptions'

    def __init__(self, seqn):
        # Stored only to keep the constructor uniform; next() ignores them.
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self

    def next(self):
        # Every advance fails with ZeroDivisionError; nothing is returned.
        3 // 0
54 | ||
class N:
    'Iterator missing next()'

    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        # Claims to be its own iterator but defines no next(); consuming it
        # is expected to fail.
        return self
62 | ||
63 | class EnumerateTestCase(unittest.TestCase): | |
64 | ||
65 | enum = enumerate | |
66 | seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')] | |
67 | ||
68 | def test_basicfunction(self): | |
69 | self.assertEqual(type(self.enum(self.seq)), self.enum) | |
70 | e = self.enum(self.seq) | |
71 | self.assertEqual(iter(e), e) | |
72 | self.assertEqual(list(self.enum(self.seq)), self.res) | |
73 | self.enum.__doc__ | |
74 | ||
75 | def test_getitemseqn(self): | |
76 | self.assertEqual(list(self.enum(G(self.seq))), self.res) | |
77 | e = self.enum(G('')) | |
78 | self.assertRaises(StopIteration, e.next) | |
79 | ||
80 | def test_iteratorseqn(self): | |
81 | self.assertEqual(list(self.enum(I(self.seq))), self.res) | |
82 | e = self.enum(I('')) | |
83 | self.assertRaises(StopIteration, e.next) | |
84 | ||
85 | def test_iteratorgenerator(self): | |
86 | self.assertEqual(list(self.enum(Ig(self.seq))), self.res) | |
87 | e = self.enum(Ig('')) | |
88 | self.assertRaises(StopIteration, e.next) | |
89 | ||
90 | def test_noniterable(self): | |
91 | self.assertRaises(TypeError, self.enum, X(self.seq)) | |
92 | ||
93 | def test_illformediterable(self): | |
94 | self.assertRaises(TypeError, list, self.enum(N(self.seq))) | |
95 | ||
96 | def test_exception_propagation(self): | |
97 | self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq))) | |
98 | ||
99 | def test_argumentcheck(self): | |
100 | self.assertRaises(TypeError, self.enum) # no arguments | |
101 | self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable) | |
102 | self.assertRaises(TypeError, self.enum, 'abc', 2) # too many arguments | |
103 | ||
104 | def test_tuple_reuse(self): | |
105 | # Tests an implementation detail where tuple is reused | |
106 | # whenever nothing else holds a reference to it | |
107 | self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq)) | |
108 | self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq))) | |
109 | ||
class MyEnum(enumerate):
    # Bare subclass with no overrides; the enumerate tests are rerun against
    # it to check that subclassing preserves the builtin's behaviour.
    pass
112 | ||
class SubclassTestCase(EnumerateTestCase):
    # Rerun every inherited enumerate test against the MyEnum subclass.
    enum = MyEnum
116 | ||
class TestEmpty(EnumerateTestCase):
    # Edge case: enumerating an empty input yields nothing.
    seq = ''
    res = []
120 | ||
class TestBig(EnumerateTestCase):
    # Rerun the enumerate tests on a larger input (9995 even integers).
    # Both attributes are shared by every test method, so they are
    # materialised as lists: a lazy range/zip object (as on Python 3) would
    # not compare equal to a list and, for zip, could be consumed only once.
    seq = list(range(10, 20000, 2))
    # Indices 0..len(seq)-1 paired with the items; same values the original
    # zip(range(20000), seq) produced, without the magic upper bound.
    res = list(zip(range(len(seq)), seq))
125 | ||
126 | class TestReversed(unittest.TestCase): | |
127 | ||
128 | def test_simple(self): | |
129 | class A: | |
130 | def __getitem__(self, i): | |
131 | if i < 5: | |
132 | return str(i) | |
133 | raise StopIteration | |
134 | def __len__(self): | |
135 | return 5 | |
136 | for data in 'abc', range(5), tuple(enumerate('abc')), A(), xrange(1,17,5): | |
137 | self.assertEqual(list(data)[::-1], list(reversed(data))) | |
138 | self.assertRaises(TypeError, reversed, {}) | |
139 | ||
140 | def test_xrange_optimization(self): | |
141 | x = xrange(1) | |
142 | self.assertEqual(type(reversed(x)), type(iter(x))) | |
143 | ||
144 | def test_len(self): | |
145 | # This is an implementation detail, not an interface requirement | |
146 | for s in ('hello', tuple('hello'), list('hello'), xrange(5)): | |
147 | self.assertEqual(len(reversed(s)), len(s)) | |
148 | r = reversed(s) | |
149 | list(r) | |
150 | self.assertEqual(len(r), 0) | |
151 | class SeqWithWeirdLen: | |
152 | called = False | |
153 | def __len__(self): | |
154 | if not self.called: | |
155 | self.called = True | |
156 | return 10 | |
157 | raise ZeroDivisionError | |
158 | def __getitem__(self, index): | |
159 | return index | |
160 | r = reversed(SeqWithWeirdLen()) | |
161 | self.assertRaises(ZeroDivisionError, len, r) | |
162 | ||
163 | ||
164 | def test_gc(self): | |
165 | class Seq: | |
166 | def __len__(self): | |
167 | return 10 | |
168 | def __getitem__(self, index): | |
169 | return index | |
170 | s = Seq() | |
171 | r = reversed(s) | |
172 | s.r = r | |
173 | ||
174 | def test_args(self): | |
175 | self.assertRaises(TypeError, reversed) | |
176 | self.assertRaises(TypeError, reversed, [], 'extra') | |
177 | ||
def test_main(verbose=None):
    """Run the enumerate and reversed test suites.

    When *verbose* is true and sys.gettotalrefcount exists (typically only
    on debug builds), the suites are rerun five more times. The total
    reference count after each run is printed; a steadily rising series
    hints at a reference leak.
    """
    testclasses = (EnumerateTestCase, SubclassTestCase, TestEmpty, TestBig,
                   TestReversed)
    test_support.run_unittest(*testclasses)

    # verify reference counting
    import sys
    if not verbose or not hasattr(sys, "gettotalrefcount"):
        return
    counts = [None] * 5
    for slot in xrange(len(counts)):
        test_support.run_unittest(*testclasses)
        counts[slot] = sys.gettotalrefcount()
    print(counts)
191 | ||
if __name__ == "__main__":
    # Script entry point: run in verbose mode, which also enables the
    # refcount loop where the interpreter supports it.
    test_main(verbose=True)