diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 7a3ed42..145fd0b 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -3458,24 +3458,92 @@ dictviews_sub(PyObject* self, PyObject *other) return result; } +static int +dictitems_contains(_PyDictViewObject *dv, PyObject *obj); + PyObject* _PyDictView_Intersect(PyObject* self, PyObject *other) { - PyObject *result = PySet_New(self); - PyObject *tmp; - _Py_IDENTIFIER(intersection_update); + PyObject *result; + PyObject *it; + PyObject *key; + Py_ssize_t len_self; + int rv; + int (*dict_contains)(_PyDictViewObject *, PyObject *); + + + /* Python interpreter swaps parameters when dict view + is on right side of & */ + if (!PyDictViewSet_Check(self)) { + PyObject *tmp = other; + other = self; + self = tmp; + } + + len_self = dictview_len((_PyDictViewObject *)self); + + + /* if other is a set and self is smaller than other, + reuse set intersection logic */ + if (PyAnySet_Check(other) && len_self <= PyObject_Size(other)) { + _Py_IDENTIFIER(intersection); + + return _PyObject_CallMethodIdObjArgs(other, &PyId_intersection, self, NULL); + } + + /* if other is another dict view, and it is bigger than self, + swap them */ + if (PyDictViewSet_Check(other)) { + Py_ssize_t len_other = dictview_len((_PyDictViewObject *)other); + if (len_other > len_self) { + PyObject *tmp = other; + other = self; + self = tmp; + } + } + /* at this point, self should be bigger than other */ + result = PySet_New(NULL); if (result == NULL) return NULL; - tmp = _PyObject_CallMethodIdObjArgs(result, &PyId_intersection_update, other, NULL); - if (tmp == NULL) { + it = PyObject_GetIter(other); + if (it == NULL) { Py_DECREF(result); return NULL; } - Py_DECREF(tmp); + if (PyDictKeys_Check(self)) { + dict_contains = dictkeys_contains; + } + /* else PyDictItems_Check(self) */ + else { + dict_contains = dictitems_contains; + } + + while ((key = PyIter_Next(it)) != NULL) { + rv = dict_contains((_PyDictViewObject *)self, key); + if (rv < 0) + goto error; + if (rv) { + if (PySet_Add(result, key)) { + goto error; + } + } + Py_DECREF(key); + } + Py_DECREF(it); + if (PyErr_Occurred()) { + Py_DECREF(result); + return NULL; + } return result; + + error: + Py_DECREF(it); + Py_DECREF(result); + Py_DECREF(key); + return NULL; } static PyObject*