Skip to content

Commit 2d06ba7

Browse files
authored
[mypyc] Add primitive ops for dictionary methods (#8742)
This covers both view and list versions (it looks like both are used relatively often, at least in mypy). This accompanies #8725 to account for cases where `keys`/`values`/`items` appear in non-loop contexts.
1 parent 8f6a1cc commit 2d06ba7

File tree

6 files changed

+201
-4
lines changed

6 files changed

+201
-4
lines changed

mypyc/irbuild/specialize.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
)
2323
from mypyc.ir.rtypes import (
2424
RType, RTuple, str_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive,
25-
bool_rprimitive
25+
bool_rprimitive, is_dict_rprimitive
2626
)
27+
from mypyc.primitives.dict_ops import dict_keys_op, dict_values_op, dict_items_op
2728
from mypyc.primitives.misc_ops import true_op, false_op
2829
from mypyc.irbuild.builder import IRBuilder
2930
from mypyc.irbuild.for_helpers import translate_list_comprehension, comprehension_helper
@@ -77,6 +78,34 @@ def translate_len(
7778
return None
7879

7980

81+
@specialize_function('builtins.list')
82+
def dict_methods_fast_path(
83+
builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]:
84+
# Specialize a common case when list() is called on a dictionary view
85+
# method call, for example foo = list(bar.keys()).
86+
if not (len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]):
87+
return None
88+
arg = expr.args[0]
89+
if not (isinstance(arg, CallExpr) and not arg.args
90+
and isinstance(arg.callee, MemberExpr)):
91+
return None
92+
base = arg.callee.expr
93+
attr = arg.callee.name
94+
rtype = builder.node_type(base)
95+
if not (is_dict_rprimitive(rtype) and attr in ('keys', 'values', 'items')):
96+
return None
97+
98+
obj = builder.accept(base)
99+
# Note that it is not safe to use fast methods on dict subclasses, so
100+
# the corresponding helpers in CPy.h fallback to (inlined) generic logic.
101+
if attr == 'keys':
102+
return builder.primitive_op(dict_keys_op, [obj], expr.line)
103+
elif attr == 'values':
104+
return builder.primitive_op(dict_values_op, [obj], expr.line)
105+
else:
106+
return builder.primitive_op(dict_items_op, [obj], expr.line)
107+
108+
80109
@specialize_function('builtins.tuple')
81110
@specialize_function('builtins.set')
82111
@specialize_function('builtins.dict')

mypyc/lib-rt/CPy.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,84 @@ static tuple_T3OOO CPy_GetExcInfo(void) {
13771377
return ret;
13781378
}
13791379

1380+
static PyObject *CPyDict_KeysView(PyObject *dict) {
1381+
if (PyDict_CheckExact(dict)){
1382+
return _CPyDictView_New(dict, &PyDictKeys_Type);
1383+
}
1384+
return PyObject_CallMethod(dict, "keys", NULL);
1385+
}
1386+
1387+
static PyObject *CPyDict_ValuesView(PyObject *dict) {
1388+
if (PyDict_CheckExact(dict)){
1389+
return _CPyDictView_New(dict, &PyDictValues_Type);
1390+
}
1391+
return PyObject_CallMethod(dict, "values", NULL);
1392+
}
1393+
1394+
static PyObject *CPyDict_ItemsView(PyObject *dict) {
1395+
if (PyDict_CheckExact(dict)){
1396+
return _CPyDictView_New(dict, &PyDictItems_Type);
1397+
}
1398+
return PyObject_CallMethod(dict, "items", NULL);
1399+
}
1400+
1401+
static PyObject *CPyDict_Keys(PyObject *dict) {
1402+
if PyDict_CheckExact(dict) {
1403+
return PyDict_Keys(dict);
1404+
}
1405+
// Inline generic fallback logic to also return a list.
1406+
PyObject *list = PyList_New(0);
1407+
PyObject *view = PyObject_CallMethod(dict, "keys", NULL);
1408+
if (view == NULL) {
1409+
return NULL;
1410+
}
1411+
PyObject *res = _PyList_Extend((PyListObject *)list, view);
1412+
Py_DECREF(view);
1413+
if (res == NULL) {
1414+
return NULL;
1415+
}
1416+
Py_DECREF(res);
1417+
return list;
1418+
}
1419+
1420+
static PyObject *CPyDict_Values(PyObject *dict) {
1421+
if PyDict_CheckExact(dict) {
1422+
return PyDict_Values(dict);
1423+
}
1424+
// Inline generic fallback logic to also return a list.
1425+
PyObject *list = PyList_New(0);
1426+
PyObject *view = PyObject_CallMethod(dict, "values", NULL);
1427+
if (view == NULL) {
1428+
return NULL;
1429+
}
1430+
PyObject *res = _PyList_Extend((PyListObject *)list, view);
1431+
Py_DECREF(view);
1432+
if (res == NULL) {
1433+
return NULL;
1434+
}
1435+
Py_DECREF(res);
1436+
return list;
1437+
}
1438+
1439+
static PyObject *CPyDict_Items(PyObject *dict) {
1440+
if PyDict_CheckExact(dict) {
1441+
return PyDict_Items(dict);
1442+
}
1443+
// Inline generic fallback logic to also return a list.
1444+
PyObject *list = PyList_New(0);
1445+
PyObject *view = PyObject_CallMethod(dict, "items", NULL);
1446+
if (view == NULL) {
1447+
return NULL;
1448+
}
1449+
PyObject *res = _PyList_Extend((PyListObject *)list, view);
1450+
Py_DECREF(view);
1451+
if (res == NULL) {
1452+
return NULL;
1453+
}
1454+
Py_DECREF(res);
1455+
return list;
1456+
}
1457+
13801458
static PyObject *CPyDict_GetKeysIter(PyObject *dict) {
13811459
if (PyDict_CheckExact(dict)) {
13821460
// Return dict itself to indicate we can use fast path instead.

mypyc/lib-rt/pythonsupport.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,25 @@ CPyGen_SetStopIterationValue(PyObject *value)
366366
return 0;
367367
}
368368

369+
// Copied from dictobject.c and dictobject.h, these are not Public before
370+
// Python 3.8. Also remove some error checks that we do in the callers.
371+
typedef struct {
372+
PyObject_HEAD
373+
PyDictObject *dv_dict;
374+
} _CPyDictViewObject;
375+
376+
static PyObject *
377+
_CPyDictView_New(PyObject *dict, PyTypeObject *type)
378+
{
379+
_CPyDictViewObject *dv = PyObject_GC_New(_CPyDictViewObject, type);
380+
if (dv == NULL)
381+
return NULL;
382+
Py_INCREF(dict);
383+
dv->dv_dict = (PyDictObject *)dict;
384+
PyObject_GC_Track(dv);
385+
return (PyObject *)dv;
386+
}
387+
369388
#ifdef __cplusplus
370389
}
371390
#endif

mypyc/primitives/dict_ops.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypyc.ir.ops import EmitterInterface, ERR_FALSE, ERR_MAGIC, ERR_NEVER
66
from mypyc.ir.rtypes import (
77
dict_rprimitive, object_rprimitive, bool_rprimitive, int_rprimitive,
8-
dict_next_rtuple_single, dict_next_rtuple_pair
8+
list_rprimitive, dict_next_rtuple_single, dict_next_rtuple_pair
99
)
1010

1111
from mypyc.primitives.registry import (
@@ -125,6 +125,60 @@ def emit_new_dict(emitter: EmitterInterface, args: List[str], dest: str) -> None
125125
error_kind=ERR_MAGIC,
126126
emit=call_emit('CPyDict_FromAny'))
127127

128+
# dict.keys()
129+
method_op(
130+
name='keys',
131+
arg_types=[dict_rprimitive],
132+
result_type=object_rprimitive,
133+
error_kind=ERR_MAGIC,
134+
emit=call_emit('CPyDict_KeysView')
135+
)
136+
137+
# dict.values()
138+
method_op(
139+
name='values',
140+
arg_types=[dict_rprimitive],
141+
result_type=object_rprimitive,
142+
error_kind=ERR_MAGIC,
143+
emit=call_emit('CPyDict_ValuesView')
144+
)
145+
146+
# dict.items()
147+
method_op(
148+
name='items',
149+
arg_types=[dict_rprimitive],
150+
result_type=object_rprimitive,
151+
error_kind=ERR_MAGIC,
152+
emit=call_emit('CPyDict_ItemsView')
153+
)
154+
155+
# list(dict.keys())
156+
dict_keys_op = custom_op(
157+
name='keys',
158+
arg_types=[dict_rprimitive],
159+
result_type=list_rprimitive,
160+
error_kind=ERR_MAGIC,
161+
emit=call_emit('CPyDict_Keys')
162+
)
163+
164+
# list(dict.values())
165+
dict_values_op = custom_op(
166+
name='values',
167+
arg_types=[dict_rprimitive],
168+
result_type=list_rprimitive,
169+
error_kind=ERR_MAGIC,
170+
emit=call_emit('CPyDict_Values')
171+
)
172+
173+
# list(dict.items())
174+
dict_items_op = custom_op(
175+
name='items',
176+
arg_types=[dict_rprimitive],
177+
result_type=list_rprimitive,
178+
error_kind=ERR_MAGIC,
179+
emit=call_emit('CPyDict_Items')
180+
)
181+
128182

129183
def emit_len(emitter: EmitterInterface, args: List[str], dest: str) -> None:
130184
temp = emitter.temp_name()

mypyc/test-data/fixtures/typing-full.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta):
131131
@overload
132132
def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass
133133
def values(self) -> Iterable[T_co]: pass # Approximate return type
134+
def items(self) -> Iterable[Tuple[T, T_co]]: pass # Approximate return type
134135
def __len__(self) -> int: ...
135136
def __contains__(self, arg: object) -> int: pass
136137

mypyc/test-data/run.test

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ update(s, [5, 4, 3])
989989
assert s == {1, 2, 3, 4, 5}
990990

991991
[case testDictStuff]
992-
from typing import Dict, Any
992+
from typing import Dict, Any, List, Set, Tuple
993993
from defaultdictwrap import make_dict
994994

995995
def f(x: int) -> int:
@@ -1035,14 +1035,22 @@ def u(x: int) -> int:
10351035
d.update(x=x)
10361036
return d['x']
10371037

1038+
def get_content(d: Dict[int, int]) -> Tuple[List[int], List[int], List[Tuple[int, int]]]:
1039+
return list(d.keys()), list(d.values()), list(d.items())
1040+
1041+
def get_content_set(d: Dict[int, int]) -> Tuple[Set[int], Set[int], Set[Tuple[int, int]]]:
1042+
return set(d.keys()), set(d.values()), set(d.items())
10381043
[file defaultdictwrap.py]
10391044
from typing import Dict
10401045
from collections import defaultdict # type: ignore
10411046
def make_dict() -> Dict[str, int]:
10421047
return defaultdict(int)
10431048

10441049
[file driver.py]
1045-
from native import f, g, h, u, make_dict1, make_dict2, update_dict
1050+
from collections import OrderedDict
1051+
from native import (
1052+
f, g, h, u, make_dict1, make_dict2, update_dict, get_content, get_content_set
1053+
)
10461054
assert f(1) == 2
10471055
assert f(2) == 1
10481056
assert g() == 30
@@ -1064,6 +1072,14 @@ update_dict(d, object.__dict__)
10641072
assert d == dict(object.__dict__)
10651073

10661074
assert u(10) == 10
1075+
assert get_content({1: 2}) == ([1], [2], [(1, 2)])
1076+
od = OrderedDict([(1, 2), (3, 4)])
1077+
assert get_content(od) == ([1, 3], [2, 4], [(1, 2), (3, 4)])
1078+
od.move_to_end(1)
1079+
assert get_content(od) == ([3, 1], [4, 2], [(3, 4), (1, 2)])
1080+
assert get_content_set({1: 2}) == ({1}, {2}, {(1, 2)})
1081+
assert get_content_set(od) == ({1, 3}, {2, 4}, {(1, 2), (3, 4)})
1082+
[typing fixtures/typing-full.pyi]
10671083

10681084
[case testDictIterationMethodsRun]
10691085
from typing import Dict

0 commit comments

Comments
 (0)