Index: Lib/unittest.py =================================================================== --- Lib/unittest.py (revision 51654) +++ Lib/unittest.py (working copy) @@ -25,7 +25,7 @@ Further information is available in the bundled documentation, and from - http://pyunit.sourceforge.net/ + http://docs.python.org/lib/module-unittest.html Copyright (c) 1999-2003 Steve Purcell This module is free software, and you may redistribute it and/or modify @@ -107,7 +107,7 @@ self.failures = [] self.errors = [] self.testsRun = 0 - self.shouldStop = 0 + self.shouldStop = False def startTest(self, test): "Called when the given test is about to be run" @@ -234,6 +234,18 @@ def id(self): return "%s.%s" % (_strclass(self.__class__), self._testMethodName) + + def __eq__(self, other): + if type(self) is not type(other): + return False + + return self._testMethodName == other._testMethodName + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash(str(hash(type(self))) + str(hash(self._testMethodName))) def __str__(self): return "%s (%s)" % (self._testMethodName, _strclass(self.__class__)) @@ -291,10 +303,7 @@ minimised; usually the top level of the traceback frame is not needed. """ - exctype, excvalue, tb = sys.exc_info() - if sys.platform[:4] == 'java': ## tracebacks look different in Jython - return (exctype, excvalue, tb) - return (exctype, excvalue, tb) + return sys.exc_info() def fail(self, msg=None): """Fail immediately, with the given message.""" @@ -400,6 +409,14 @@ return "<%s tests=%s>" % (_strclass(self.__class__), self._tests) __str__ = __repr__ + + def __eq__(self, other): + if type(self) != type(other): + return False + return self._tests == other._tests + + def __ne__(self, other): + return not self == other def __iter__(self): return iter(self._tests) @@ -411,6 +428,10 @@ return cases def addTest(self, test): + if not isinstance(test, (TestCase, TestSuite)): + raise TypeError( + "%s must be an instance of TestCase or TestSuite" % test) + self._tests.append(test) def addTests(self, tests): @@ -436,7 +457,7 @@ """A test case that wraps a test function. This is useful for slipping pre-existing test functions into the - PyUnit framework. Optionally, set-up and tidy-up functions can be + unittest framework. Optionally, set-up and tidy-up functions can be supplied. As with TestCase, the tidy-up ('tearDown') function will always be called if the set-up ('setUp') function ran successfully. """ @@ -474,15 +495,32 @@ doc = self.__testFunc.__doc__ return doc and doc.split("\n")[0].strip() or None + def __eq__(self, other): + if type(self) is not type(other): + return False + + return self.__setUpFunc == other.__setUpFunc and \ + self.__tearDownFunc == other.__tearDownFunc and \ + self.__testFunc == other.__testFunc and \ + self.__description == other.__description + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash(''.join(str(hash(x)) for x in [type(self) + ,self.__setUpFunc + ,self.__tearDownFunc + ,self.__testFunc + ,self.__description])) - ############################################################################## # Locating and loading tests ############################################################################## class TestLoader: """This class is responsible for loading tests according to various - criteria and returning them wrapped in a Test + criteria and returning them wrapped in a TestSuite """ testMethodPrefix = 'test' sortTestMethodsUsing = cmp @@ -537,17 +575,21 @@ issubclass(obj, TestCase)): return self.loadTestsFromTestCase(obj) elif type(obj) == types.UnboundMethodType: - return parent(obj.__name__) + return TestSuite([parent(obj.__name__)]) elif isinstance(obj, TestSuite): return obj elif callable(obj): test = obj() - if not isinstance(test, (TestCase, TestSuite)): - raise ValueError, \ + if isinstance(test, TestSuite): + return test + elif isinstance(test, TestCase): + return TestSuite([test]) + else: + raise TypeError, \ "calling %s returned %s, not a test" % (obj,test) return test else: - raise ValueError, "don't know how to make test from: %s" % obj + raise TypeError, "don't know how to make test from: %s" % obj def loadTestsFromNames(self, names, module=None): """Return a suite of all tests cases found using the given sequence @@ -562,10 +604,6 @@ def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname)) testFnNames = filter(isTestMethod, dir(testCaseClass)) - for baseclass in testCaseClass.__bases__: - for testFnName in self.getTestCaseNames(baseclass): - if testFnName not in testFnNames: # handle overridden methods - testFnNames.append(testFnName) if self.sortTestMethodsUsing: testFnNames.sort(self.sortTestMethodsUsing) return testFnNames