diff -r dd2639508dfe Lib/test/test_heapq.py --- a/Lib/test/test_heapq.py Sat May 07 15:19:34 2011 -0700 +++ b/Lib/test/test_heapq.py Sun May 08 21:35:16 2011 +0300 @@ -1,16 +1,36 @@ """Unittests for heapq.""" +import sys import random -import unittest + from test import test_support -import sys +from types import FunctionType, BuiltinFunctionType +from unittest import TestCase, skipUnless # We do a bit of trickery here to be able to test both the C implementation # and the Python implementation of the module. -import heapq as c_heapq +c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq']) py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq']) -class TestHeap(unittest.TestCase): + +# used later to make sure that the tests are testing the correct functions +# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest, when +# _heapq is imported, so check them there +func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', + 'heapreplace', '_nlargest', '_nsmallest'] + +class TestModules(TestCase): + def test_py_functions(self): + for fname in func_names: + self.assertIsInstance(getattr(py_heapq, fname), FunctionType) + + @skipUnless(c_heapq, 'requires _heapq') + def test_c_functions(self): + for fname in func_names: + self.assertIsInstance(getattr(c_heapq, fname), BuiltinFunctionType) + + +class TestHeap(TestCase): module = None def test_push_pop(self): @@ -175,16 +195,12 @@ self.assertEqual(self.module.nlargest(n, data, key=f), sorted(data, key=f, reverse=True)[:n]) + class TestHeapPython(TestHeap): module = py_heapq - # As an early adopter, we sanity check the - # test_support.import_fresh_module utility function - def test_pure_python(self): - self.assertFalse(sys.modules['heapq'] is self.module) - self.assertTrue(hasattr(self.module.heapify, 'func_code')) - +@skipUnless(c_heapq, 'requires _heapq') class TestHeapC(TestHeap): module = c_heapq @@ -304,7 +320,7 @@ 'Test multiple tiers of iterators' return chain(imap(lambda x:x, R(Ig(G(seqn))))) -class TestErrorHandling(unittest.TestCase): +class TestErrorHandling(TestCase): def test_non_sequence(self): for f in (self.module.heapify, self.module.heappop): @@ -349,9 +365,12 @@ self.assertRaises(TypeError, f, 2, N(s)) self.assertRaises(ZeroDivisionError, f, 2, E(s)) + class TestErrorHandling_Python(TestErrorHandling): module = py_heapq + +@skipUnless(c_heapq, 'requires _heapq') class TestErrorHandling_C(TestErrorHandling): module = c_heapq @@ -360,8 +379,8 @@ def test_main(verbose=None): - test_classes = [TestHeapPython, TestHeapC, TestErrorHandling_Python, - TestErrorHandling_C] + test_classes = [TestModules, TestHeapPython, TestHeapC, + TestErrorHandling_Python, TestErrorHandling_C] test_support.run_unittest(*test_classes) # verify reference counting diff -r dd2639508dfe Lib/test/test_support.py --- a/Lib/test/test_support.py Sat May 07 15:19:34 2011 -0700 +++ b/Lib/test/test_support.py Sun May 08 21:35:16 2011 +0300 @@ -83,12 +83,14 @@ def _save_and_remove_module(name, orig_modules): """Helper function to save and remove a module from sys.modules - Return value is True if the module was in sys.modules and - False otherwise.""" + Return True if the module was in sys.modules, False otherwise. + Might raise ImportError if the module can't be imported.""" saved = True try: orig_modules[name] = sys.modules[name] except KeyError: + # try to import the module and raise an error if it can't be imported + __import__(name) saved = False else: del sys.modules[name] @@ -98,8 +100,7 @@ def _save_and_block_module(name, orig_modules): """Helper function to save and block a module in sys.modules - Return value is True if the module was in sys.modules and - False otherwise.""" + Return True if the module was in sys.modules, False otherwise.""" saved = True try: orig_modules[name] = sys.modules[name] @@ -136,6 +137,8 @@ if not _save_and_block_module(blocked_name, orig_modules): names_to_remove.append(blocked_name) fresh_module = importlib.import_module(name) + except ImportError: + fresh_module = None finally: for orig_name, module in orig_modules.items(): sys.modules[orig_name] = module