# HG changeset patch # Parent eb392fcc6177d30e32f8e0c9a02dd8697800d796 # User Julien Pagès add a TestCase.patch method diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst --- a/Doc/library/unittest.rst +++ b/Doc/library/unittest.rst @@ -753,16 +753,37 @@ Test cases A test case can contain any number of subtest declarations, and they can be arbitrarily nested. See :ref:`subtests` for more information. .. versionadded:: 3.4 + .. method:: patch(*args, **kargs) + + Convenience method that calls :func:`unittest.mock.patch` and returns the + mock instance used as the patch. + + Example of use:: + + import platform + + class TestCase(unittest.TestCase): + def setUp(self): + self.patch('platform.system', return_value='PatchedLinux') + + def test_patched(self): + self.assertEqual(platform.system(), 'PatchedLinux') + + Note that the patch will automatically be removed with :meth:`doCleanups`. + + .. versionadded:: 3.5 + + .. method:: debug() Run the test without collecting the result. This allows exceptions raised by the test to be propagated to the caller, and can be used to support running tests under a debugger. .. _assert-methods: diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -628,16 +628,24 @@ class TestCase(object): """Run the test without collecting errors in a TestResult""" self.setUp() getattr(self, self._testMethodName)() self.tearDown() while self._cleanups: function, args, kwargs = self._cleanups.pop(-1) function(*args, **kwargs) + def patch(self, *args, **kwargs): + # lazy import + from unittest.mock import patch + p = patch(*args, **kwargs) + result = p.start() + self.addCleanup(p.stop) + return result + def skipTest(self, reason): """Skip this test.""" raise SkipTest(reason) def fail(self, msg=None): """Fail immediately, with the given message.""" raise self.failureException(msg) diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -1564,11 +1564,27 @@ test case def test2(self): raise MyException() for method_name in ('test1', 'test2'): testcase = TestCase(method_name) testcase.run() self.assertEqual(MyException.ninstance, 0) + @staticmethod + def to_be_patched(): + return 1 + + def test_patch(self): + patch_name = __name__ + '.' + self.__class__.__name__ + '.to_be_patched' + class TestCase(unittest.TestCase): + def setUp(self): + self.patch(patch_name, return_value=2) + + testcase = TestCase() + self.assertEqual(self.to_be_patched(), 1) + testcase.setUp() + self.assertEqual(self.to_be_patched(), 2) + testcase.doCleanups() + self.assertEqual(self.to_be_patched(), 1) if __name__ == "__main__": unittest.main()