diff -r fa097a336079 Lib/unittest/case.py --- a/Lib/unittest/case.py Wed Jun 24 12:51:55 2015 -0400 +++ b/Lib/unittest/case.py Thu Jul 02 22:17:54 2015 +0700 @@ -333,7 +333,13 @@ .format(logging.getLevelName(self.level), self.logger.name)) -class TestCase(object): +class TestCaseMeta(type): + def __new__(cls, class_name, parents, attrs): + attrs['_cleanupsClass'] = [] + return type.__new__(cls, class_name, parents, attrs) + + +class TestCase(metaclass=TestCaseMeta): """A class whose instances are single test cases. By default, the test code itself should be placed in a method named @@ -380,6 +386,7 @@ _classSetupFailed = False + def __init__(self, methodName='runTest'): """Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does @@ -435,6 +442,15 @@ Cleanup items are called even if setUp fails (unlike tearDown).""" self._cleanups.append((function, args, kwargs)) + @classmethod + def addCleanupClass(cls, function, *args, **kwargs): + """Add a function, with arguments, to be called when the test class is + completed. Functions added are called on a LIFO basis and are + called after tearDownClass on test failure or success. + + Cleanup items are called even if setUpClass fails (unlike tearDownClass).""" + cls._cleanupsClass.append((function, args, kwargs)) + def setUp(self): "Hook method for setting up the test fixture before exercising it." pass diff -r fa097a336079 Lib/unittest/suite.py --- a/Lib/unittest/suite.py Wed Jun 24 12:51:55 2015 -0400 +++ b/Lib/unittest/suite.py Thu Jul 02 22:17:54 2015 +0700 @@ -110,6 +110,7 @@ if _isnotsuite(test): self._tearDownPreviousClass(test, result) + self._doCleanupClass(test, result) self._handleModuleFixture(test, result) self._handleClassSetUp(test, result) result._previousTestClass = test.__class__ @@ -128,6 +129,7 @@ if topLevel: self._tearDownPreviousClass(None, result) + self._doCleanupClass(None, result) self._handleModuleTearDown(result) result._testRunEntered = False return result @@ -266,6 +268,31 @@ finally: _call_if_exists(result, '_restoreStdout') + def _doCleanupClass(self, test, result): + previousClass = getattr(result, '_previousTestClass', None) + currentClass = test.__class__ + if currentClass == previousClass: + return + if getattr(result, '_moduleSetUpFailed', False): + return + if getattr(previousClass, "__unittest_skip__", False): + return + + cleanupsClass = getattr(previousClass, '_cleanupsClass', None) + while cleanupsClass: + _call_if_exists(result, '_setupStdout') + try: + function, args, kwargs = cleanupsClass.pop() + function(*args, **kwargs) + except Exception as e: + if isinstance(result, _DebugResult): + raise + className = util.strclass(previousClass) + errorName = 'doCleanupClass (%s)' % className + self._addClassOrModuleLevelException(result, e, errorName) + finally: + _call_if_exists(result, '_restoreStdout') + class _ErrorHolder(object): """ diff -r fa097a336079 Lib/unittest/test/test_setups.py --- a/Lib/unittest/test/test_setups.py Wed Jun 24 12:51:55 2015 -0400 +++ b/Lib/unittest/test/test_setups.py Thu Jul 02 22:17:54 2015 +0700 @@ -166,6 +166,7 @@ class Test(unittest.TestCase): classSetUp = False tornDown = False + cleanedUp = False @classmethod def setUpClass(cls): Test.classSetUp = True @@ -175,10 +176,15 @@ def test_one(self): pass + def cleanup(): + Test.cleanedUp = True + Test.addCleanupClass(cleanup) + Test = unittest.skip("hop")(Test) self.runTests(Test) self.assertFalse(Test.classSetUp) self.assertFalse(Test.tornDown) + self.assertFalse(Test.cleanedUp) def test_setup_teardown_order_with_pathological_suite(self): results = [] @@ -223,6 +229,11 @@ def testTwo(self): results.append('Test2.testTwo') + def Test2cleanup(): + results.append('cleanup 2') + + Test2.addCleanupClass(Test2cleanup) + class Test3(unittest.TestCase): @classmethod def setUpClass(cls): @@ -257,7 +268,7 @@ ['Module1.setUpModule', 'setup 1', 'Test1.testOne', 'Test1.testTwo', 'teardown 1', 'setup 2', 'Test2.testOne', 'Test2.testTwo', - 'teardown 2', 'Module1.tearDownModule', + 'teardown 2', 'cleanup 2', 'Module1.tearDownModule', 'Module2.setUpModule', 'setup 3', 'Test3.testOne', 'Test3.testTwo', 'teardown 3', 'Module2.tearDownModule']) @@ -295,6 +306,7 @@ Module.moduleTornDown += 1 class Test(unittest.TestCase): + classCleanedUp = False classSetUp = False classTornDown = False @classmethod @@ -313,6 +325,10 @@ pass def test_two(self): pass + + def cleanup(): + Test.classCleanedUp = True + Test.addCleanupClass(cleanup) Test.__module__ = 'Module' Test2.__module__ = 'Module' sys.modules['Module'] = Module @@ -323,6 +339,7 @@ self.assertEqual(result.testsRun, 0) self.assertFalse(Test.classSetUp) self.assertFalse(Test.classTornDown) + self.assertFalse(Test.classCleanedUp) self.assertEqual(len(result.errors), 1) error, _ = result.errors[0] self.assertEqual(str(error), 'setUpModule (Module)') @@ -368,6 +385,7 @@ raise TypeError('foo') class Test(unittest.TestCase): + classCleanedUp = False classSetUp = False classTornDown = False @classmethod @@ -390,11 +408,16 @@ Test2.__module__ = 'Module' sys.modules['Module'] = Module + def cleanup(): + Test.classCleanedUp = True + Test.addCleanupClass(cleanup) + result = self.runTests(Test, Test2) self.assertEqual(Module.moduleTornDown, 1) self.assertEqual(result.testsRun, 4) self.assertTrue(Test.classSetUp) self.assertTrue(Test.classTornDown) + self.assertTrue(Test.classCleanedUp) self.assertEqual(len(result.errors), 1) error, _ = result.errors[0] self.assertEqual(str(error), 'tearDownModule (Module)') @@ -463,9 +486,13 @@ Test.__module__ = 'Module' sys.modules['Module'] = Module + def cleanup(): + ordering.append('addCleanupClass') + Test.addCleanupClass(cleanup) + suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test) suite.debug() - expectedOrder = ['setUpModule', 'setUpClass', 'test_something', 'tearDownClass', 'tearDownModule'] + expectedOrder = ['setUpModule', 'setUpClass', 'test_something', 'tearDownClass', 'addCleanupClass', 'tearDownModule'] self.assertEqual(ordering, expectedOrder) def test_suite_debug_propagates_exceptions(self): @@ -476,32 +503,88 @@ raise Exception('setUpModule') @staticmethod def tearDownModule(): - if phase == 1: + if phase == 2: raise Exception('tearDownModule') class Test(unittest.TestCase): @classmethod def setUpClass(cls): - if phase == 2: + if phase == 3: raise Exception('setUpClass') @classmethod def tearDownClass(cls): - if phase == 3: + if phase == 4: raise Exception('tearDownClass') def test_something(self): - if phase == 4: + if phase == 5: raise Exception('test_something') Test.__module__ = 'Module' sys.modules['Module'] = Module - messages = ('setUpModule', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something') + def cleanup(): + if phase == 1: + raise Exception('addCleanupClass') + Test.addCleanupClass(cleanup) + + messages = ('setUpModule', 'addCleanupClass', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something') for phase, msg in enumerate(messages): _suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test) suite = unittest.TestSuite([_suite]) with self.assertRaisesRegex(Exception, msg): suite.debug() + def test_class_cleanup(self): + class Test(unittest.TestCase): + cleanedUp = 0 + def test_one(self): + pass + def test_two(self): + pass + + def cleanup(): + Test.cleanedUp += 1 + + Test.addCleanupClass(cleanup) + self.runTests(Test) + self.assertEqual(Test.cleanedUp, 1) + + def test_class_cleanup_error_in_setup_class(self): + class Test(unittest.TestCase): + cleanedUp = 0 + @classmethod + def setUpClass(cls): + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + def cleanup(): + Test.cleanedUp += 1 + + Test.addCleanupClass(cleanup) + self.runTests(Test) + self.assertEqual(Test.cleanedUp, 1) + + def test_class_cleanup_error_in_teardown_class(self): + class Test(unittest.TestCase): + cleanedUp = 0 + @classmethod + def tearDownClass(cls): + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + def cleanup(): + Test.cleanedUp += 1 + + Test.addCleanupClass(cleanup) + self.runTests(Test) + self.assertEqual(Test.cleanedUp, 1) + if __name__ == '__main__': unittest.main()