diff --git a/Lib/test/test_warnings.py b/Lib/test/test_warnings.py --- a/Lib/test/test_warnings.py +++ b/Lib/test/test_warnings.py @@ -365,16 +365,45 @@ class WarnTests(BaseTest): """Warning with a bad format string for __str__.""" def __str__(self): return ("A bad formatted string %(err)" % {"err" : "there is no %(err)s"}) with self.assertRaises(ValueError): self.module.warn(BadStrWarning()) + def test_warning_classes(self): + class MyWarningClass(Warning): + pass + + class NonWarningSubclass: + pass + + msg_regex = 'category must be a Warning subclass, not (.*)' + + # passing a non-subclass of Warning should raise a TypeError + with self.assertRaisesRegex(TypeError, msg_regex): + self.module.warn('bad warning category', '') + + with self.assertRaisesRegex(TypeError, msg_regex): + self.module.warn('bad warning category', NonWarningSubclass) + + # check that warning instances also raise a TypeError + with self.assertRaisesRegex(TypeError, msg_regex): + self.module.warn('bad warning category', MyWarningClass()) + + with self.assertWarnsRegex(MyWarningClass, 'good warning category'): + self.module.warn('good warning category', MyWarningClass) + + with self.assertWarnsRegex(UserWarning, 'good warning category'): + self.module.warn('good warning category', None) + + with self.assertWarns(MyWarningClass) as cm: + self.module.warn('good warning category', MyWarningClass) + self.assertIsInstance(cm.warning, Warning) class CWarnTests(WarnTests, unittest.TestCase): module = c_warnings # As an early adopter, we sanity check the # test.support.import_fresh_module utility function def test_accelerated(self): self.assertFalse(original_warnings is self.module) diff --git a/Lib/warnings.py b/Lib/warnings.py --- a/Lib/warnings.py +++ b/Lib/warnings.py @@ -157,17 +157,21 @@ def _getcategory(category): def warn(message, category=None, stacklevel=1): """Issue a warning, or maybe ignore it or raise an exception.""" # Check if message is already a Warning object if isinstance(message, Warning): category = message.__class__ # Check category argument if category is None: category = UserWarning - assert issubclass(category, Warning) + # user-error means that category could be either a class + # or some invalid object instance + if not (isinstance(category, type) and issubclass(category, Warning)): + raise TypeError('category must be a Warning subclass, ' + 'not {!r}'.format(category)) # Get context information try: caller = sys._getframe(stacklevel) except ValueError: globals = sys.__dict__ lineno = 1 else: globals = caller.f_globals diff --git a/Python/_warnings.c b/Python/_warnings.c --- a/Python/_warnings.c +++ b/Python/_warnings.c @@ -614,26 +614,26 @@ get_category(PyObject *message, PyObject /* Get category. */ rc = PyObject_IsInstance(message, PyExc_Warning); if (rc == -1) return NULL; if (rc == 1) category = (PyObject*)message->ob_type; - else if (category == NULL) + else if (category == NULL || category == Py_None) category = PyExc_UserWarning; /* Validate category. */ rc = PyObject_IsSubclass(category, PyExc_Warning); - if (rc == -1) - return NULL; - if (rc == 0) { - PyErr_SetString(PyExc_ValueError, - "category is not a subclass of Warning"); + /* either not a subclass or an error from PyObject_IsSubclass */ + if (rc == -1 || rc == 0) { + PyErr_Format(PyExc_TypeError, + "category must be a Warning subclass, not %R", + category); return NULL; } return category; } static PyObject * do_warn(PyObject *message, PyObject *category, Py_ssize_t stack_level)