Index: Lib/unittest.py =================================================================== --- Lib/unittest.py (revision 70533) +++ Lib/unittest.py (working copy) @@ -53,6 +53,7 @@ import traceback import os import types +import functools ############################################################################## # Exported classes and functions @@ -84,6 +85,79 @@ def _strclass(cls): return "%s.%s" % (cls.__module__, cls.__name__) + +class SkipTest(Exception): + """ + Raise this exception in a test to skip it. + + Usually you can use TestResult.skip() or one of the skipping decorators + instead of raising this directly. + """ + pass + +class _ExpectedFailure(Exception): + """ + Raise this when a test is expected to fail. + + This is an implementation detail. + """ + + def __init__(self, exc_info): + super(_ExpectedFailure, self).__init__() + self.exc_info = exc_info + +class _UnexpectedSuccess(Exception): + """ + The test was supposed to fail, but it didn't! + """ + pass + +def _id(obj): + return obj + +def skip(reason): + """ + Unconditionally skip a test. + """ + def decorator(test_item): + if isinstance(test_item, type) and issubclass(test_item, TestCase): + test_item.__unittest_skip__ = True + test_item.__unittest_skip_why__ = reason + return test_item + @functools.wraps(test_item) + def skip_wrapper(*args, **kwargs): + raise SkipTest(reason) + return skip_wrapper + return decorator + +def skipIf(condition, reason): + """ + Skip a test if the condition is true. + """ + if condition: + return skip(reason) + return _id + +def skipUnless(condition, reason): + """ + Skip a test unless the condition is true. + """ + if not condition: + return skip(reason) + return _id + + +def expectedFailure(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + func(*args, **kwargs) + except Exception: + raise _ExpectedFailure(sys.exc_info()) + raise _UnexpectedSuccess + return wrapper + + __unittest = 1 class TestResult(object): @@ -101,6 +175,9 @@ self.failures = [] self.errors = [] self.testsRun = 0 + self.skipped = [] + self.expected_failures = [] + self.unexpected_successes = [] self.shouldStop = False def startTest(self, test): @@ -126,6 +203,19 @@ "Called when a test has completed successfully" pass + def addSkip(self, test, reason): + """Called when a test is skipped.""" + self.skipped.append((test, reason)) + + def addExpectedFailure(self, test, err): + """Called when an expected failure/error occured.""" + self.expected_failures.append( + (test, self._exc_info_to_string(err, test))) + + def addUnexpectedSuccess(self, test): + """Called when a test was expected to fail, but succeed.""" + self.unexpected_successes.append(test) + def wasSuccessful(self): "Tells whether or not this result was a success" return len(self.failures) == len(self.errors) == 0 @@ -274,25 +364,36 @@ try: try: self.setUp() + except SkipTest as e: + result.addSkip(self, str(e)) + return except Exception: result.addError(self, self._exc_info()) return - ok = False + success = False try: testMethod() - ok = True except self.failureException: result.addFailure(self, self._exc_info()) + except _ExpectedFailure as e: + result.addExpectedFailure(self, e.exc_info) + except _UnexpectedSuccess: + result.addUnexpectedSuccess(self) + except SkipTest as e: + result.addSkip(self, str(e)) except Exception: result.addError(self, self._exc_info()) + else: + success = True try: self.tearDown() except Exception: result.addError(self, self._exc_info()) - ok = False - if ok: result.addSuccess(self) + success = False + if success: + result.addSuccess(self) finally: result.stopTest(self) @@ -312,6 +413,10 @@ """ return sys.exc_info() + def skip(self, reason): + """Skip this test.""" + raise SkipTest(reason) + def fail(self, msg=None): """Fail immediately, with the given message.""" raise self.failureException(msg) @@ -419,8 +524,8 @@ __str__ = __repr__ def __eq__(self, other): - if type(self) is not type(other): - return False + if not isinstance(other, self.__class__): + return NotImplemented return self._tests == other._tests def __ne__(self, other): @@ -469,6 +574,37 @@ for test in self._tests: test.debug() +class ClassTestSuite(TestSuite): + """ + Suite of tests derived from a single TestCase class. + """ + + def __init__(self, tests, class_collected_from): + super(ClassTestSuite, self).__init__(tests) + self.collected_from = class_collected_from + + def id(self): + module = getattr(self.collected_from, "__module__", None) + if module is not None: + return "{0}.{1}".format(module, self.collected_from.__name__) + return self.collected_from.__name__ + + def run(self, result): + if getattr(self.collected_from, "__unittest_skip__", False): + # ClassTestSuite result pretends to be a TestCase enough to be + # reported. + result.startTest(self) + try: + result.addSkip(self, self.collected_from.__unittest_skip_why__) + finally: + result.stopTest(self) + else: + result = super(ClassTestSuite, self).run(result) + return result + + shortDescription = id + + class FunctionTestCase(TestCase): """A test case that wraps a test function. @@ -540,6 +676,7 @@ testMethodPrefix = 'test' sortTestMethodsUsing = cmp suiteClass = TestSuite + classSuiteClass = ClassTestSuite def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" @@ -548,7 +685,9 @@ testCaseNames = self.getTestCaseNames(testCaseClass) if not testCaseNames and hasattr(testCaseClass, 'runTest'): testCaseNames = ['runTest'] - return self.suiteClass(map(testCaseClass, testCaseNames)) + suite = self.classSuiteClass(map(testCaseClass, testCaseNames), + testCaseClass) + return suite def loadTestsFromModule(self, module): """Return a suite of all tests cases contained in the given module""" @@ -719,6 +858,30 @@ self.stream.write('F') self.stream.flush() + def addSkip(self, test, reason): + TestResult.addSkip(self, test, reason) + if self.showAll: + self.stream.writeln("skipped {0!r}".format(reason)) + elif self.dots: + self.stream.write("s") + self.stream.flush() + + def addExpectedFailure(self, test, err): + TestResult.addExpectedFailure(self, test, err) + if self.showAll: + self.stream.writeln("expected failure") + elif self.dots: + self.stream.write(".") + self.stream.flush() + + def addUnexpectedSuccess(self, test): + TestResult.addUnexpectedSuccess(self, test) + if self.showAll: + self.stream.writeln("unexpected success") + elif self.dots: + self.stream.write(".") + self.stream.flush() + def printErrors(self): if self.dots or self.showAll: self.stream.writeln() @@ -760,17 +923,28 @@ self.stream.writeln("Ran %d test%s in %.3fs" % (run, run != 1 and "s" or "", timeTaken)) self.stream.writeln() + results = map(len, (result.expected_failures, + result.unexpected_successes, + result.skipped)) + expected_fails, unexpected_successes, skipped = results + infos = [] if not result.wasSuccessful(): - self.stream.write("FAILED (") + self.stream.write("FAILED") failed, errored = map(len, (result.failures, result.errors)) if failed: - self.stream.write("failures=%d" % failed) + infos.append("failures=%d" % failed) if errored: - if failed: self.stream.write(", ") - self.stream.write("errors=%d" % errored) - self.stream.writeln(")") + infos.append("errors=%d" % errored) else: - self.stream.writeln("OK") + self.stream.write("OK") + if skipped: + infos.append("skipped=%d" % skipped) + if expected_fails: + infos.append("expected failures=%d" % expected_fails) + if unexpected_successes: + infos.append("unexpected successes=%d" % unexpected_successes) + if infos: + self.stream.writeln(" (%s)" % (", ".join(infos),)) return result @@ -824,9 +998,9 @@ def parseArgs(self, argv): import getopt + long_opts = ['help','verbose','quiet'] try: - options, args = getopt.getopt(argv[1:], 'hHvq', - ['help','verbose','quiet']) + options, args = getopt.getopt(argv[1:], 'hHvq', long_opts) for opt, value in options: if opt in ('-h','-H','--help'): self.usageExit() Index: Lib/test/test_unittest.py =================================================================== --- Lib/test/test_unittest.py (revision 70533) +++ Lib/test/test_unittest.py (working copy) @@ -31,10 +31,27 @@ self._events.append('addFailure') super(LoggingResult, self).addFailure(*args) + def addSuccess(self, *args): + self._events.append('addSuccess') + super(LoggingResult, self).addSuccess(*args) + def addError(self, *args): self._events.append('addError') super(LoggingResult, self).addError(*args) + def addSkip(self, *args): + self._events.append('addSkip') + super(LoggingResult, self).addSkip(*args) + + def addExpectedFailure(self, *args): + self._events.append('addExpectedFailure') + super(LoggingResult, self).addExpectedFailure(*args) + + def addUnexpectedSuccess(self, *args): + self._events.append('addUnexpectedSuccess') + super(LoggingResult, self).addUnexpectedSuccess(*args) + + class TestEquality(object): # Check for a valid __eq__ implementation def test_eq(self): @@ -72,6 +89,13 @@ self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e)) +# List subclass we can add attributes to. +class MyClassSuite(list): + + def __init__(self, tests, klass): + super(MyClassSuite, self).__init__(tests) + + ################################################################ ### /Support code @@ -1223,7 +1247,7 @@ tests = [Foo('test_1'), Foo('test_2')] loader = unittest.TestLoader() - loader.suiteClass = list + loader.classSuiteClass = MyClassSuite self.assertEqual(loader.loadTestsFromTestCase(Foo), tests) # It is implicit in the documentation for TestLoader.suiteClass that @@ -1236,7 +1260,7 @@ def foo_bar(self): pass m.Foo = Foo - tests = [[Foo('test_1'), Foo('test_2')]] + tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)] loader = unittest.TestLoader() loader.suiteClass = list @@ -1255,7 +1279,7 @@ tests = [Foo('test_1'), Foo('test_2')] loader = unittest.TestLoader() - loader.suiteClass = list + loader.classSuiteClass = MyClassSuite self.assertEqual(loader.loadTestsFromName('Foo', m), tests) # It is implicit in the documentation for TestLoader.suiteClass that @@ -1268,7 +1292,7 @@ def foo_bar(self): pass m.Foo = Foo - tests = [[Foo('test_1'), Foo('test_2')]] + tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)] loader = unittest.TestLoader() loader.suiteClass = list @@ -2261,9 +2285,103 @@ # Make run() find a result object on its own Foo('test').run() - expected = ['startTest', 'test', 'stopTest'] + expected = ['startTest', 'test', 'addSuccess', 'stopTest'] self.assertEqual(events, expected) + +class Test_TestSkipping(TestCase): + + def test_skipping(self): + class Foo(unittest.TestCase): + def test_skip_me(self): + self.skip("skip") + events = [] + result = LoggingResult(events) + test = Foo("test_skip_me") + test.run(result) + self.assertEqual(events, ['startTest', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "skip")]) + + # Try letting setUp skip the test now. + class Foo(unittest.TestCase): + def setUp(self): + self.skip("testing") + def test_nothing(self): pass + events = [] + result = LoggingResult(events) + test = Foo("test_nothing") + test.run(result) + self.assertEqual(events, ['startTest', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "testing")]) + self.assertEqual(result.testsRun, 1) + + def test_skipping_decorators(self): + op_table = ((unittest.skipUnless, False, True), + (unittest.skipIf, True, False)) + for deco, do_skip, dont_skip in op_table: + class Foo(unittest.TestCase): + @deco(do_skip, "testing") + def test_skip(self): pass + + @deco(dont_skip, "testing") + def test_dont_skip(self): pass + test_do_skip = Foo("test_skip") + test_dont_skip = Foo("test_dont_skip") + suite = unittest.ClassTestSuite([test_do_skip, test_dont_skip], Foo) + events = [] + result = LoggingResult(events) + suite.run(result) + self.assertEqual(len(result.skipped), 1) + expected = ['startTest', 'addSkip', 'stopTest', + 'startTest', 'addSuccess', 'stopTest'] + self.assertEqual(events, expected) + self.assertEqual(result.testsRun, 2) + self.assertEqual(result.skipped, [(test_do_skip, "testing")]) + self.assertTrue(result.wasSuccessful()) + + def test_skip_class(self): + @unittest.skip("testing") + class Foo(unittest.TestCase): + def test_1(self): + record.append(1) + record = [] + result = unittest.TestResult() + suite = unittest.ClassTestSuite([Foo("test_1")], Foo) + suite.run(result) + self.assertEqual(result.skipped, [(suite, "testing")]) + self.assertEqual(record, []) + + def test_expected_failure(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + self.fail("help me!") + events = [] + result = LoggingResult(events) + test = Foo("test_die") + test.run(result) + self.assertEqual(events, + ['startTest', 'addExpectedFailure', 'stopTest']) + self.assertEqual(result.expected_failures[0][0], test) + self.assertTrue(result.wasSuccessful()) + + def test_unexpected_success(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + pass + events = [] + result = LoggingResult(events) + test = Foo("test_die") + test.run(result) + self.assertEqual(events, + ['startTest', 'addUnexpectedSuccess', 'stopTest']) + self.assertFalse(result.failures) + self.assertEqual(result.unexpected_successes, [test]) + self.assertTrue(result.wasSuccessful()) + + + class Test_Assertions(TestCase): def test_AlmostEqual(self): self.failUnlessAlmostEqual(1.00000001, 1.0) @@ -2328,7 +2446,7 @@ def test_main(): test_support.run_unittest(Test_TestCase, Test_TestLoader, Test_TestSuite, Test_TestResult, Test_FunctionTestCase, - Test_Assertions) + Test_TestSkipping, Test_Assertions) if __name__ == "__main__": test_main()