diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 4e09ca5..e240bf4 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -1006,6 +1006,60 @@ class BuiltinTest(unittest.TestCase): self.assertRaises(TypeError, range, 0.0, 0.0, 1) self.assertRaises(TypeError, range, 0.0, 0.0, 1.0) + self.assertEqual(range(3).count(-1), 0) + self.assertEqual(range(3).count(0), 1) + self.assertEqual(range(3).count(1), 1) + self.assertEqual(range(3).count(2), 1) + self.assertEqual(range(3).count(3), 0) + + self.assertEqual(range(10**20).count(1), 1) + self.assertEqual(range(10**20).count(10**20), 0) + self.assertEqual(range(3).index(1), 1) + self.assertEqual(range(1, 2**100, 2).count(2**87), 0) + self.assertEqual(range(1, 2**100, 2).count(2**87+1), 1) + + self.assertEqual(range(1, 10, 3).index(4), 1) + self.assertEqual(range(1, -10, -3).index(-5), 2) + + self.assertEqual(range(10**20).index(1), 1) + self.assertEqual(range(10**20).index(10**20 - 1), 10**20 - 1) + + self.assertRaises(ValueError, range(1, 2**100, 2).index, 2**87) + self.assertEqual(range(1, 2**100, 2).index(2**87+1), 2**86) + + class AlwaysEqual(object): + def __eq__(self, other): + return True + always_equal = AlwaysEqual() + self.assertEqual(range(10).count(always_equal), 10) + self.assertEqual(range(10).index(always_equal), 0) + + def test_range_index(self): + u = range(2) + self.assertEqual(u.index(0), 0) + self.assertEqual(u.index(1), 1) + self.assertRaises(ValueError, u.index, 2) + + u = range(-2, 3) + self.assertEqual(u.count(0), 1) + self.assertEqual(u.index(0), 2) + self.assertRaises(TypeError, u.index) + + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False + + a = range(4) + self.assertRaises(BadExc, a.index, BadCmp()) + + a = range(-2, 3) + self.assertEqual(a.index(0), 2) + def test_input(self): self.write_testfile() fp = open(TESTFN, 'r') diff --git a/Objects/rangeobject.c b/Objects/rangeobject.c index e36469c..9650386 100644 --- a/Objects/rangeobject.c +++ b/Objects/rangeobject.c @@ -273,60 +273,135 @@ range_reduce(rangeobject *r, PyObject *args) r->start, r->stop, r->step); } +/* Assumes (PyLong_CheckExact(ob) || PyBool_Check(ob)) */ +static int +range_contains_long(rangeobject *r, PyObject *ob) +{ + int cmp1, cmp2, cmp3; + PyObject *tmp1 = NULL; + PyObject *tmp2 = NULL; + PyObject *zero = NULL; + int result = -1; + + zero = PyLong_FromLong(0); + if (zero == NULL) /* MemoryError in int(0) */ + goto end; + + /* Check if the value can possibly be in the range. */ + + cmp1 = PyObject_RichCompareBool(r->step, zero, Py_GT); + if (cmp1 == -1) + goto end; + if (cmp1 == 1) { /* positive steps: start <= ob < stop */ + cmp2 = PyObject_RichCompareBool(r->start, ob, Py_LE); + cmp3 = PyObject_RichCompareBool(ob, r->stop, Py_LT); + } + else { /* negative steps: stop < ob <= start */ + cmp2 = PyObject_RichCompareBool(ob, r->start, Py_LE); + cmp3 = PyObject_RichCompareBool(r->stop, ob, Py_LT); + } + + if (cmp2 == -1 || cmp3 == -1) /* TypeError */ + goto end; + if (cmp2 == 0 || cmp3 == 0) { /* ob outside of range */ + result = 0; + goto end; + } + + /* Check that the stride does not invalidate ob's membership. */ + tmp1 = PyNumber_Subtract(ob, r->start); + if (tmp1 == NULL) + goto end; + tmp2 = PyNumber_Remainder(tmp1, r->step); + if (tmp2 == NULL) + goto end; + /* result = (int(ob) - start % step) == 0 */ + result = PyObject_RichCompareBool(tmp2, zero, Py_EQ); + end: + Py_XDECREF(tmp1); + Py_XDECREF(tmp2); + Py_XDECREF(zero); + return result; +} + static int range_contains(rangeobject *r, PyObject *ob) { - if (PyLong_CheckExact(ob) || PyBool_Check(ob)) { - int cmp1, cmp2, cmp3; - PyObject *tmp1 = NULL; - PyObject *tmp2 = NULL; - PyObject *zero = NULL; - int result = -1; + if (PyLong_CheckExact(ob) || PyBool_Check(ob)) + return range_contains_long(r, ob); - zero = PyLong_FromLong(0); - if (zero == NULL) /* MemoryError in int(0) */ - goto end; - - /* Check if the value can possibly be in the range. */ - - cmp1 = PyObject_RichCompareBool(r->step, zero, Py_GT); - if (cmp1 == -1) - goto end; - if (cmp1 == 1) { /* positive steps: start <= ob < stop */ - cmp2 = PyObject_RichCompareBool(r->start, ob, Py_LE); - cmp3 = PyObject_RichCompareBool(ob, r->stop, Py_LT); - } - else { /* negative steps: stop < ob <= start */ - cmp2 = PyObject_RichCompareBool(ob, r->start, Py_LE); - cmp3 = PyObject_RichCompareBool(r->stop, ob, Py_LT); - } - - if (cmp2 == -1 || cmp3 == -1) /* TypeError */ - goto end; - if (cmp2 == 0 || cmp3 == 0) { /* ob outside of range */ - result = 0; - goto end; - } - - /* Check that the stride does not invalidate ob's membership. */ - tmp1 = PyNumber_Subtract(ob, r->start); - if (tmp1 == NULL) - goto end; - tmp2 = PyNumber_Remainder(tmp1, r->step); - if (tmp2 == NULL) - goto end; - /* result = (int(ob) - start % step) == 0 */ - result = PyObject_RichCompareBool(tmp2, zero, Py_EQ); - end: - Py_XDECREF(tmp1); - Py_XDECREF(tmp2); - Py_XDECREF(zero); - return result; - } - /* Fall back to iterative search. */ return (int)_PySequence_IterSearch((PyObject*)r, ob, PY_ITERSEARCH_CONTAINS); } +static PyObject * +range_count(rangeobject *r, PyObject *ob) +{ + if (PyLong_CheckExact(ob) || PyBool_Check(ob)) { + if (range_contains_long(r, ob)) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; + } else { + Py_ssize_t count; + count = _PySequence_IterSearch((PyObject*)r, ob, PY_ITERSEARCH_COUNT); + if (count == -1) + return NULL; + return PyLong_FromSsize_t(count); + } +} + +static PyObject * +range_index(rangeobject *r, PyObject *ob) +{ + PyObject *idx, *tmp; + int contains; + PyObject *format_tuple, *err_string; + static PyObject *err_format = NULL; + + if (!PyLong_CheckExact(ob) && !PyBool_Check(ob)) { + Py_ssize_t index; + index = _PySequence_IterSearch((PyObject*)r, ob, PY_ITERSEARCH_INDEX); + if (index == -1) + return NULL; + return PyLong_FromSsize_t(index); + } + + contains = range_contains_long(r, ob); + if (contains == -1) + return NULL; + + if (!contains) + goto value_error; + + tmp = PyNumber_Subtract(ob, r->start); + if (tmp == NULL) + return NULL; + + /* idx = (ob - r.start) // r.step */ + idx = PyNumber_FloorDivide(tmp, r->step); + Py_DECREF(tmp); + return idx; + +value_error: + + /* object is not in the range */ + if (err_format == NULL) { + err_format = PyUnicode_FromString("%r is not in range"); + if (err_format == NULL) + return NULL; + } + format_tuple = PyTuple_Pack(1, ob); + if (format_tuple == NULL) + return NULL; + err_string = PyUnicode_Format(err_format, format_tuple); + Py_DECREF(format_tuple); + if (err_string == NULL) + return NULL; + PyErr_SetObject(PyExc_ValueError, err_string); + Py_DECREF(err_string); + return NULL; +} + static PySequenceMethods range_as_sequence = { (lenfunc)range_length, /* sq_length */ 0, /* sq_concat */ @@ -344,10 +419,18 @@ static PyObject * range_reverse(PyObject *seq); PyDoc_STRVAR(reverse_doc, "Returns a reverse iterator."); +PyDoc_STRVAR(count_doc, +"rangeobject.count(value) -> integer -- return number of occurrences of value"); + +PyDoc_STRVAR(index_doc, +"rangeobject.index(value, [start, [stop]]) -> integer -- return index of value.\n" +"Raises ValueError if the value is not present."); + static PyMethodDef range_methods[] = { - {"__reversed__", (PyCFunction)range_reverse, METH_NOARGS, - reverse_doc}, - {"__reduce__", (PyCFunction)range_reduce, METH_VARARGS}, + {"__reversed__", (PyCFunction)range_reverse, METH_NOARGS, reverse_doc}, + {"__reduce__", (PyCFunction)range_reduce, METH_VARARGS}, + {"count", (PyCFunction)range_count, METH_O, count_doc}, + {"index", (PyCFunction)range_index, METH_O, index_doc}, {NULL, NULL} /* sentinel */ };