Skip to content

Commit 998cf1f

Browse files
fgreggrhettinger
authored andcommitted
bpo-27575: port set intersection logic into dictview intersection (GH-7696)
1 parent c3ea41e commit 998cf1f

3 files changed

Lines changed: 93 additions & 4 deletions

File tree

Lib/test/test_dictviews.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,27 @@ def test_keys_set_operations(self):
9292
d1 = {'a': 1, 'b': 2}
9393
d2 = {'b': 3, 'c': 2}
9494
d3 = {'d': 4, 'e': 5}
95+
d4 = {'d': 4}
96+
97+
class CustomSet(set):
98+
def intersection(self, other):
99+
return CustomSet(super().intersection(other))
100+
95101
self.assertEqual(d1.keys() & d1.keys(), {'a', 'b'})
96102
self.assertEqual(d1.keys() & d2.keys(), {'b'})
97103
self.assertEqual(d1.keys() & d3.keys(), set())
98104
self.assertEqual(d1.keys() & set(d1.keys()), {'a', 'b'})
99105
self.assertEqual(d1.keys() & set(d2.keys()), {'b'})
100106
self.assertEqual(d1.keys() & set(d3.keys()), set())
101107
self.assertEqual(d1.keys() & tuple(d1.keys()), {'a', 'b'})
108+
self.assertEqual(d3.keys() & d4.keys(), {'d'})
109+
self.assertEqual(d4.keys() & d3.keys(), {'d'})
110+
self.assertEqual(d4.keys() & set(d3.keys()), {'d'})
111+
self.assertIsInstance(d4.keys() & frozenset(d3.keys()), set)
112+
self.assertIsInstance(frozenset(d3.keys()) & d4.keys(), set)
113+
self.assertIs(type(d4.keys() & CustomSet(d3.keys())), set)
114+
self.assertIs(type(d1.keys() & []), set)
115+
self.assertIs(type([] & d1.keys()), set)
102116

103117
self.assertEqual(d1.keys() | d1.keys(), {'a', 'b'})
104118
self.assertEqual(d1.keys() | d2.keys(), {'a', 'b', 'c'})
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Improve speed of dictview intersection by directly using set intersection
2+
logic. Patch by David Su.

Objects/dictobject.c

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4169,24 +4169,97 @@ dictviews_sub(PyObject* self, PyObject *other)
41694169
return result;
41704170
}
41714171

4172-
PyObject*
4172+
static int
4173+
dictitems_contains(_PyDictViewObject *dv, PyObject *obj);
4174+
4175+
PyObject *
41734176
_PyDictView_Intersect(PyObject* self, PyObject *other)
41744177
{
4175-
PyObject *result = PySet_New(self);
4178+
PyObject *result;
4179+
PyObject *it;
4180+
PyObject *key;
4181+
Py_ssize_t len_self;
4182+
int rv;
4183+
int (*dict_contains)(_PyDictViewObject *, PyObject *);
41764184
PyObject *tmp;
4177-
_Py_IDENTIFIER(intersection_update);
41784185

4186+
/* Python interpreter swaps parameters when dict view
4187+
is on right side of & */
4188+
if (!PyDictViewSet_Check(self)) {
4189+
PyObject *tmp = other;
4190+
other = self;
4191+
self = tmp;
4192+
}
4193+
4194+
len_self = dictview_len((_PyDictViewObject *)self);
4195+
4196+
/* if other is a set and self is smaller than other,
4197+
reuse set intersection logic */
4198+
if (Py_TYPE(other) == &PySet_Type && len_self <= PyObject_Size(other)) {
4199+
_Py_IDENTIFIER(intersection);
4200+
return _PyObject_CallMethodIdObjArgs(other, &PyId_intersection, self, NULL);
4201+
}
4202+
4203+
/* if other is another dict view, and it is bigger than self,
4204+
swap them */
4205+
if (PyDictViewSet_Check(other)) {
4206+
Py_ssize_t len_other = dictview_len((_PyDictViewObject *)other);
4207+
if (len_other > len_self) {
4208+
PyObject *tmp = other;
4209+
other = self;
4210+
self = tmp;
4211+
}
4212+
}
4213+
4214+
/* at this point, two things should be true
4215+
1. self is a dictview
4216+
2. if other is a dictview then it is smaller than self */
4217+
result = PySet_New(NULL);
41794218
if (result == NULL)
41804219
return NULL;
41814220

4221+
it = PyObject_GetIter(other);
4222+
4223+
_Py_IDENTIFIER(intersection_update);
41824224
tmp = _PyObject_CallMethodIdOneArg(result, &PyId_intersection_update, other);
41834225
if (tmp == NULL) {
41844226
Py_DECREF(result);
41854227
return NULL;
41864228
}
4187-
41884229
Py_DECREF(tmp);
4230+
4231+
if (PyDictKeys_Check(self)) {
4232+
dict_contains = dictkeys_contains;
4233+
}
4234+
/* else PyDictItems_Check(self) */
4235+
else {
4236+
dict_contains = dictitems_contains;
4237+
}
4238+
4239+
while ((key = PyIter_Next(it)) != NULL) {
4240+
rv = dict_contains((_PyDictViewObject *)self, key);
4241+
if (rv < 0) {
4242+
goto error;
4243+
}
4244+
if (rv) {
4245+
if (PySet_Add(result, key)) {
4246+
goto error;
4247+
}
4248+
}
4249+
Py_DECREF(key);
4250+
}
4251+
Py_DECREF(it);
4252+
if (PyErr_Occurred()) {
4253+
Py_DECREF(result);
4254+
return NULL;
4255+
}
41894256
return result;
4257+
4258+
error:
4259+
Py_DECREF(it);
4260+
Py_DECREF(result);
4261+
Py_DECREF(key);
4262+
return NULL;
41904263
}
41914264

41924265
static PyObject*

0 commit comments

Comments
 (0)