"""Unittests for heapq."""
import random
import sys
import unittest
from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest
from test import test_support
def heapiter(heap):
    """Yield the items of *heap* smallest-first, consuming it.

    Every item is removed with heappop(), so *heap* is empty once the
    generator is exhausted.
    """
    while heap:
        yield heappop(heap)


class TestHeap(unittest.TestCase):
    """Behavioural tests for the heapq functions operating on plain lists."""

    def test_push_pop(self):
        # 1) Push 256 random numbers and pop them off, verifying all's OK.
        heap = []
        data = []
        self.check_invariant(heap)
        for _ in range(256):
            item = random.random()
            data.append(item)
            heappush(heap, item)
            self.check_invariant(heap)
        results = []
        while heap:
            item = heappop(heap)
            self.check_invariant(heap)
            results.append(item)
        data_sorted = sorted(data)
        self.assertEqual(data_sorted, results)
        # 2) Check that the invariant holds for a sorted array
        self.check_invariant(results)

        # Wrong argument counts / non-list heaps must be rejected.
        self.assertRaises(TypeError, heappush, [])
        self.assertRaises(TypeError, heappush, None, None)
        self.assertRaises(TypeError, heappop, None)

    def check_invariant(self, heap):
        """Assert that *heap* satisfies the min-heap invariant.

        For every index ``pos > 0`` the parent at ``(pos - 1) >> 1`` must
        compare ``<=`` the item at ``pos``.
        """
        for pos, item in enumerate(heap):
            if pos:  # pos 0 has no parent
                parentpos = (pos - 1) >> 1
                self.assertTrue(heap[parentpos] <= item)

    def test_heapify(self):
        # Sizes 0 and 1 are included on purpose as degenerate heaps.
        for size in range(30):
            heap = [random.random() for dummy in range(size)]
            heapify(heap)
            self.check_invariant(heap)

        self.assertRaises(TypeError, heapify, None)

    def test_naive_nbest(self):
        # Keep a heap of at most 10 items; popping the minimum whenever it
        # grows past 10 leaves the 10 largest values seen.
        data = [random.randrange(2000) for i in range(1000)]
        heap = []
        for item in data:
            heappush(heap, item)
            if len(heap) > 10:
                heappop(heap)
        heap.sort()
        self.assertEqual(heap, sorted(data)[-10:])

    def test_nbest(self):
        # Less-naive "N-best" algorithm, much faster (if len(data) is big
        # enough <wink>) than sorting all of data.  However, if we had a max
        # heap instead of a min heap, it could go faster still via
        # heapify'ing all of data (linear time), then doing 10 heappops
        # (10 log-time steps).
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        heapify(heap)
        for item in data[10:]:
            if item > heap[0]:  # this gets rarer the longer we run
                heapreplace(heap, item)
        self.assertEqual(list(heapiter(heap)), sorted(data)[-10:])

        self.assertRaises(TypeError, heapreplace, None)
        self.assertRaises(TypeError, heapreplace, None, None)
        self.assertRaises(IndexError, heapreplace, [], None)

    def test_heapsort(self):
        # Exercise everything with repeated heapsort checks
        for trial in range(100):
            size = random.randrange(50)
            data = [random.randrange(25) for i in range(size)]
            if trial & 1:  # Half of the time, use heapify
                heap = data[:]
                heapify(heap)
            else:  # The rest of the time, use heappush
                heap = []
                for item in data:
                    heappush(heap, item)
            heap_sorted = [heappop(heap) for i in range(size)]
            self.assertEqual(heap_sorted, sorted(data))

    def test_nsmallest(self):
        data = [random.randrange(2000) for i in range(1000)]
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertEqual(nsmallest(n, data), sorted(data)[:n])

    def test_nlargest(self):
        data = [random.randrange(2000) for i in range(1000)]
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n])
#==============================================================================

class LenOnly:
    "Dummy sequence class defining __len__ but not __getitem__."
    def __len__(self):
        return 10


class GetOnly:
    "Dummy sequence class defining __getitem__ but not __len__."
    # Answers every index and never raises IndexError, so iter() over an
    # instance (legacy sequence protocol) never terminates on its own.
    def __getitem__(self, ndx):
        return 10


class CmpErr:
    "Dummy element that always raises an error during comparison"
    def __cmp__(self, other):
        raise ZeroDivisionError
    # Python 2 falls back to __cmp__, but Python 3 (and any rich-compare
    # code path) only consults these methods, so route them all to the
    # same failure.
    __eq__ = __ne__ = __lt__ = __le__ = __gt__ = __ge__ = __cmp__
class G:
    'Sequence using __getitem__'
    # Iterable only through the legacy protocol: iter() calls __getitem__
    # with 0, 1, 2, ... until the wrapped sequence raises IndexError.
    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        return self.seqn[i]
class I:
    'Sequence using iterator protocol'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0  # index of the next item to hand out

    def __iter__(self):
        return self

    def next(self):
        if self.i >= len(self.seqn):
            raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v
    __next__ = next  # Python 3 spelling of the iterator method
class Ig:
    'Sequence using iterator protocol defined with a generator'
    # __iter__ is a generator function, so every iter() call returns a
    # fresh iterator and instances can be iterated repeatedly.
    def __init__(self, seqn):
        self.seqn = seqn

    def __iter__(self):
        for val in self.seqn:
            yield val
class X:
    'Missing __getitem__ and __iter__'
    # Has an iterator's next() but no __iter__/__getitem__, so iter()
    # must reject it with TypeError.
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def next(self):
        if self.i >= len(self.seqn):
            raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v
    __next__ = next  # Python 3 spelling; still not iterable without __iter__
class N:
    'Iterator missing next()'
    # __iter__ returns self, but self has no next()/__next__, so iter()
    # rejects the result with TypeError.
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self
class E:
    'Test propagation of exceptions'
    # A well-formed iterator whose first next() fails, so callers must let
    # the ZeroDivisionError propagate rather than mask it.
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        raise ZeroDivisionError
    __next__ = next  # Python 3 spelling of the iterator method
class S:
    'Test immediate stop'
    # Iterator that is exhausted from the start, whatever it was given.
    def __init__(self, seqn):
        pass

    def __iter__(self):
        return self

    def next(self):
        raise StopIteration
    __next__ = next  # Python 3 spelling of the iterator method
from itertools import chain
try:
    from itertools import imap
except ImportError:  # Python 3 removed imap; its builtin map is already lazy
    imap = map


def R(seqn):
    'Regular generator'
    for i in seqn:
        yield i


def L(seqn):
    'Test multiple tiers of iterators'
    # Stack several iteration styles (G -> Ig -> R -> imap -> chain); the
    # result should still yield seqn's items in their original order.
    return chain(imap(lambda x: x, R(Ig(G(seqn)))))
class TestErrorHandling(unittest.TestCase):
    """Argument-type, protocol and exception-propagation checks for heapq."""

    def test_non_sequence(self):
        for f in (heapify, heappop):
            self.assertRaises(TypeError, f, 10)
        for f in (heappush, heapreplace, nlargest, nsmallest):
            self.assertRaises(TypeError, f, 10, 10)

    def test_len_only(self):
        for f in (heapify, heappop):
            self.assertRaises(TypeError, f, LenOnly())
        for f in (heappush, heapreplace):
            self.assertRaises(TypeError, f, LenOnly(), 10)
        for f in (nlargest, nsmallest):
            self.assertRaises(TypeError, f, 2, LenOnly())

    def test_get_only(self):
        for f in (heapify, heappop):
            self.assertRaises(TypeError, f, GetOnly())
        for f in (heappush, heapreplace):
            self.assertRaises(TypeError, f, GetOnly(), 10)
        # nlargest()/nsmallest() are deliberately not checked: they only
        # need iter(), and iter() accepts any object defining __getitem__
        # (the legacy sequence protocol).  They therefore never raise
        # TypeError here, and if __getitem__ never raises IndexError the
        # call simply never terminates.

    def test_cmp_err(self):
        seq = [CmpErr(), CmpErr(), CmpErr()]
        for f in (heapify, heappop):
            self.assertRaises(ZeroDivisionError, f, seq)
        for f in (heappush, heapreplace):
            self.assertRaises(ZeroDivisionError, f, seq, 10)
        for f in (nlargest, nsmallest):
            self.assertRaises(ZeroDivisionError, f, 2, seq)

    def test_arg_parsing(self):
        for f in (heapify, heappop, heappush, heapreplace, nlargest, nsmallest):
            self.assertRaises(TypeError, f, 10)

    def test_iterable_args(self):
        for f in (nlargest, nsmallest):
            for s in ("123", "", range(1000), ('do', 1.2), xrange(2000, 2200, 5)):
                # Every well-behaved iteration style must give the same result.
                for g in (G, I, Ig, L, R):
                    self.assertEqual(f(2, g(s)), f(2, s))
                self.assertEqual(f(2, S(s)), [])
                self.assertRaises(TypeError, f, 2, X(s))
                self.assertRaises(TypeError, f, 2, N(s))
                self.assertRaises(ZeroDivisionError, f, 2, E(s))
#==============================================================================
def test_main(verbose=None):
    """Run the heapq test classes through test_support.run_unittest.

    TestErrorHandling is added only when heapify is a builtin function.
    Presumably this is because its expected exception types are those of
    the C implementation; confirm before relying on it.

    If *verbose* is true and the interpreter provides sys.gettotalrefcount
    (not every build does, hence the hasattr check), the suite is re-run
    five more times.  The total reference count after each run is printed,
    so a count that keeps growing points at a reference leak.
    """
    from types import BuiltinFunctionType

    test_classes = [TestHeap]
    if isinstance(heapify, BuiltinFunctionType):
        test_classes.append(TestErrorHandling)
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            test_support.run_unittest(*test_classes)
            gc.collect()  # drop cyclic garbage so only real leaks remain
            counts.append(sys.gettotalrefcount())
        print(counts)
if __name__ == "__main__":
    # Run directly: verbose mode also enables the ref-count check when the
    # interpreter supports it.
    test_main(verbose=True)