diff --git a/Lib/test/support.py b/Lib/test/support.py --- a/Lib/test/support.py +++ b/Lib/test/support.py @@ -71,7 +71,8 @@ "TestHandler", "Matcher", "can_symlink", "skip_unless_symlink", "skip_unless_xattr", "import_fresh_module", "requires_zlib", "PIPE_MAX_SIZE", "failfast", "anticipate_failure", "run_with_tz", - "requires_bz2", "requires_lzma" + "requires_bz2", "requires_lzma", "exported_module", + "DualImplementationTests", ] class Error(Exception): @@ -198,6 +199,138 @@ return fresh_module +@contextlib.contextmanager +def exported_module(name, module): + """Context manager to temporarily swap in a module to sys.modules.""" + if name in sys.modules: + original = sys.modules[name] + sys.modules[name] = module + yield + sys.modules[name] = original + else: + sys.modules[name] = module + yield + del sys.modules[name] + + +class DualImplementationTests: + """A helper for the PEP 399 boilerplate. + + PEP 399 calls for accelerated modules to pass the same tests as + their corresponding pure Python modules. + + Any test methods that rely on this helper need to refer to the module + by "self.". This is the class attribute to which the pure + Python or accelerated module is bound. + + Example: + + from test.support import DualImplementationTests + + dual_impl_tests = MultiImplementationTests('heapq', '_heapq') + + class ExampleTest: + + def test_example(self): + self.assertTrue(hasattr(self.heapq, 'heapify')) + + PyExampleTest, CExampleTest = dual_impl_tests.create_test_cases(ExampleTest) + + (See issue #17037.) + + """ + + # Using keyword-only arguments you could go as far as allowing for + # customizing the attribute name to set, class name prefixes, etc. + def __init__(self, module_name, *accelerated_names): + self.module_name = module_name + self.accelerated_names = accelerated_names + self.py_module = import_fresh_module(module_name, + blocked=accelerated_names) + self.accelerated_module = import_fresh_module(module_name, + fresh=accelerated_names) + + def create_test_cases(self, test_class): + """Return (PyTestCase, AcceleratedTestCase) based on the test class. + + This generates two subclasses from scratch. In accordance with PEP + 399, test_class should not be a subclass of unittest.TestCase. + + """ + if issubclass(test_class, unittest.TestCase): + raise TypeError("base class cannot subclass unittest.TestCase") + + class PyTestCase(test_class, unittest.TestCase): + pass + PyTestCase.__name__ = 'Py' + test_class.__name__ + setattr(PyTestCase, self.module_name, self.py_module) + + if self.accelerated_module is None: + AcceleratedTestCase = None + else: + class AcceleratedTestCase(test_class, unittest.TestCase): + pass + AcceleratedTestCase.__name__ = 'Accelerated' + test_class.__name__ + setattr(AcceleratedTestCase, self.module_name, self.accelerated_module) + + return PyTestCase, AcceleratedTestCase + + def accelerated_only(self, test_item): + """A decorator for tests specific to the accelerated module. + + May also be used for classes or for specific test methods in a class. + + This method is roughly equivalent to: + + @unittest.skipUnless(dual_impl_tests.accelerated_module, 'requires ...') + + """ + msg = "specific to accelerated module: {}".format( + self.accelerated_names) + if isinstance(test_item, type): + if self.accelerated_module: + setattr(test_item, self.module_name, self.accelerated_module) + return test_item + else: + raise unittest.SkipTest(msg) + else: + @functools.wraps(test_item) + def skip_wrapper(self_, *args, **kwargs): + if getattr(self_, self.module_name) is self.py_module: + raise unittest.SkipTest(msg) + else: + return test_item(self_, *args, **kwargs) + return skip_wrapper + + def pure_python_only(self, test_item): + """Like accelerated_only, but for python-only tests.""" + msg = "specific to pure-Python module: {}".format(self.module_name) + if isinstance(test_item, type): + if self.py_module: + setattr(test_item, self.module_name, self.py_module) + return test_item + else: + raise unittest.SkipTest(msg) + else: + @functools.wraps(test_item) + def skip_wrapper(self_, *args, **kwargs): + if getattr(self_, self.module_name) is not self.py_module: + raise unittest.SkipTest(msg) + else: + return test_item(self_, *args, **kwargs) + return skip_wrapper + + def with_module_exported(self, f): + """A method decorator for ensuring the module is set in sys.modules.""" + # XXX make this a context manager instead? + @functools.wraps(f) + def func_wrapper(self_, *args, **kwargs): + module = getattr(self_, self.module_name) + with exported_module(self.module_name, module): + return f(self_, *args, **kwargs) + return func_wrapper + + def get_attribute(obj, name): """Get an attribute, raising SkipTest if AttributeError is raised.""" try: diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -26,6 +26,29 @@ def test_import_fresh_module(self): support.import_fresh_module("ftplib") + def test_exported_module(self): + assert "ham" not in sys.modules + module = "spam" + + self.assertNotIn("ham", sys.modules) + with support.exported_module("ham", module): + self.assertIs(sys.modules["ham"], module) + self.assertNotIn("ham", sys.modules) + + def test_exported_module_already_exists(self): + assert "ham" not in sys.modules + module = "spam" + existing = "eggs" + + sys.modules["ham"] = existing + try: + self.assertIs(sys.modules["ham"], existing) + with support.exported_module("ham", module): + self.assertIs(sys.modules["ham"], module) + self.assertIs(sys.modules["ham"], existing) + finally: + del sys.modules["ham"] + def test_get_attribute(self): self.assertEqual(support.get_attribute(self, "test_get_attribute"), self.test_get_attribute) @@ -190,9 +213,175 @@ # can_symlink # skip_unless_symlink +class TestDualImplementationTests(unittest.TestCase): + + PURE = "\n".join([ + "try:", + " from _spam import eggs", + "except ImportError:", + " eggs = 'hard-boiled'" + ]) + ACCELERATED = "eggs = 'microwaved'" + HAM = "eggs = 'fried'" + + @classmethod + def setUpClass(cls): + assert "spam" not in sys.modules + assert "_spam" not in sys.modules + assert "ham" not in sys.modules + cls.dirpath = tempfile.mkdtemp() + try: + with open(os.path.join(cls.dirpath, "spam.py"), "w") as modfile: + modfile.write(cls.PURE) + with open(os.path.join(cls.dirpath, "_spam.py"), "w") as modfile: + modfile.write(cls.ACCELERATED) + with open(os.path.join(cls.dirpath, "ham.py"), "w") as modfile: + modfile.write(cls.HAM) + cls.sys_path_cm = support.DirsOnSysPath(cls.dirpath) + except Exception: + support.rmtree(cls.dirpath) + raise + + @classmethod + def tearDownClass(cls): + try: + cls.sys_path_cm.__exit__() + finally: + support.rmtree(cls.dirpath) + + def test_init(self): + tests = support.DualImplementationTests("spam", "_spam") + + self.assertEqual(tests.module_name, "spam") + self.assertEqual(tests.accelerated_names, ("_spam",)) + self.assertEqual(tests.py_module.eggs, "hard-boiled") + self.assertEqual(tests.accelerated_module.eggs, "microwaved") + + def test_create_test_cases(self): + tests = support.DualImplementationTests("spam", "_spam") + + class ExampleTest: + def test_example(self): + self.assertTrue(hasattr(self.spam, 'eggs')) + + PyExampleTest, AccExampleTest = tests.create_test_cases(ExampleTest) + + self.assertTrue(issubclass(PyExampleTest, ExampleTest)) + self.assertTrue(issubclass(AccExampleTest, ExampleTest)) + self.assertEqual(PyExampleTest.spam.eggs, "hard-boiled") + self.assertEqual(AccExampleTest.spam.eggs, "microwaved") + + def test_create_test_cases_no_accelerator(self): + tests = support.DualImplementationTests("ham", "_ham") + + class ExampleTest: + def test_example(self): + self.assertTrue(hasattr(self.ham, 'eggs')) + + PyExampleTest, AccExampleTest = tests.create_test_cases(ExampleTest) + + self.assertTrue(issubclass(PyExampleTest, ExampleTest)) + self.assertIsNone(AccExampleTest) + self.assertEqual(PyExampleTest.ham.eggs, "fried") + + def test_class_decorators(self): + tests = support.DualImplementationTests("spam", "_spam") + + @tests.pure_python_only + class PyExampleTest: + def test_example(self): + self.assertEqual(hasattr(self.spam.eggs, "hard-boiled")) + + @tests.accelerated_only + class AccExampleTest: + def test_example(self): + self.assertEqual(hasattr(self.spam.eggs, "microwaved")) + + self.assertEqual(PyExampleTest.spam.eggs, "hard-boiled") + self.assertEqual(AccExampleTest.spam.eggs, "microwaved") + + def test_class_decorators_no_accelerator(self): + tests = support.DualImplementationTests("ham", "_ham") + + @tests.pure_python_only + class PyExampleTest: + def test_example(self): + self.assertEqual(hasattr(self.ham.eggs, "fried")) + + with self.assertRaises(unittest.SkipTest): + @tests.accelerated_only + class AccExampleTest: + def test_example(self): + self.assertEqual(hasattr(self.ham.eggs, "???")) + + self.assertEqual(PyExampleTest.ham.eggs, "fried") + + def test_accelerated_only_method(self): + tests = support.DualImplementationTests("spam", "_spam") + + class ExampleTest: + def test_example(self): + self.assertTrue(hasattr(self.spam, 'eggs')) + + @tests.accelerated_only + def test_oh_so_special(self): + self.assertTrue(hasattr(self.spam, 'eggs')) + + PyExampleTest, AccExampleTest = tests.create_test_cases(ExampleTest) + + pytest = PyExampleTest() + acctest = AccExampleTest() + + pytest.test_example() + acctest.test_example() + acctest.test_oh_so_special() + with self.assertRaises(unittest.SkipTest): + pytest.test_oh_so_special() + + def test_pure_python_only(self): + tests = support.DualImplementationTests("spam", "_spam") + + class ExampleTest: + def test_example(self): + self.assertTrue(hasattr(self.spam, 'eggs')) + + @tests.pure_python_only + def test_oh_so_special(self): + self.assertTrue(hasattr(self.spam, 'eggs')) + + PyExampleTest, AccExampleTest = tests.create_test_cases(ExampleTest) + + pytest = PyExampleTest() + acctest = AccExampleTest() + + pytest.test_example() + acctest.test_example() + pytest.test_oh_so_special() + with self.assertRaises(unittest.SkipTest): + acctest.test_oh_so_special() + + def test_with_module_exported(self): + tests = support.DualImplementationTests("spam", "_spam") + + class ExampleTest: + @tests.with_module_exported + def test_example(self): + self.assertIs(sys.modules["spam"], self.spam) + + PyExampleTest, AccExampleTest = tests.create_test_cases(ExampleTest) + pytest = PyExampleTest() + + original = sys.modules['spam'] + sys.modules['spam'] = "belgian waffles" + try: + pytest.test_example() + self.assertEqual(sys.modules['spam'], "belgian waffles") + finally: + sys.modules['spam'] = original + def test_main(): - tests = [TestSupport] + tests = [TestSupport, TestDualImplementationTests] support.run_unittest(*tests) if __name__ == '__main__':