diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -47,7 +47,7 @@ __all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', - 'expectedFailure', 'TextTestResult', 'installHandler', + 'expectedFailure', 'testBaseClass', 'TextTestResult', 'installHandler', 'registerResult', 'removeResult', 'removeHandler'] # Expose obsolete functions for backwards compatibility @@ -57,7 +57,7 @@ from .result import TestResult from .case import (TestCase, FunctionTestCase, SkipTest, skip, skipIf, - skipUnless, expectedFailure) + skipUnless, expectedFailure, testBaseClass) from .suite import BaseTestSuite, TestSuite from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames, findTestCases) diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -120,6 +120,12 @@ return test_item +def testBaseClass(base_class): + base_class.__testbase__ = True + return base_class + + + class _BaseTestCaseContext: def __init__(self, test_case): diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -47,7 +47,6 @@ return path[:-9] return os.path.splitext(path)[0] - class TestLoader(object): """ This class is responsible for loading tests according to various criteria @@ -60,10 +59,14 @@ def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" + if issubclass(testCaseClass, suite.TestSuite): raise TypeError("Test cases should not be derived from " "TestSuite. Maybe you meant to derive from " "TestCase?") + if testCaseClass.__dict__.get('__testbase__', False): + raise TypeError("Base test class should not be run.") + testCaseNames = self.getTestCaseNames(testCaseClass) if not testCaseNames and hasattr(testCaseClass, 'runTest'): testCaseNames = ['runTest'] @@ -75,7 +78,8 @@ tests = [] for name in dir(module): obj = getattr(module, name) - if isinstance(obj, type) and issubclass(obj, case.TestCase): + + if isinstance(obj, type) and issubclass(obj, case.TestCase) and not obj.__dict__.get('__testbase__', False): tests.append(self.loadTestsFromTestCase(obj)) load_tests = getattr(module, 'load_tests', None) diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -1569,6 +1569,17 @@ testcase.run() self.assertEqual(MyException.ninstance, 0) + def test_base_class_should_not_run(self): + # Issue #14534: Add method to mark unittest.TestCases as "do not run" + import unittest.test.testmock.baseclassdecorator + + result = unittest.TestResult() + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(unittest.test.testmock.baseclassdecorator) + + suite.run(result) + self.assertEqual(result.testsRun, 1) if __name__ == "__main__": unittest.main() diff --git a/Lib/unittest/test/testmock/baseclassdecorator.py b/Lib/unittest/test/testmock/baseclassdecorator.py new file mode 100644 --- /dev/null +++ b/Lib/unittest/test/testmock/baseclassdecorator.py @@ -0,0 +1,20 @@ +import unittest + +@unittest.testBaseClass +class TestBaseDecoratorBase(unittest.TestCase): + def setUp(self): + self.is_base_class = True + + def test_base_class_should_not_run(self): + """ + This test will fail if run by the base class but will pass + if run by the child class. + """ + self.assertFalse(self.is_base_class) + +class TestBaseDecoratorChild(TestBaseDecoratorBase): + def setUp(self): + self.is_base_class = False + +if __name__ == "__main__": + unittest.main()