Skip to content

Commit bcf60ac

Browse files
authored
[mypyc] Support __pow__, __rpow__, and __ipow__ dunders (#14616)
Unlike every other slot, power slots are ternary. Some special casing had to be done in generate_bin_op_wrapper() to support the third slot argument. Annoyingly, pow() also has these unique behaviours: - Ternary pow() does NOT fallback to `__rpow__` if `__pow__` returns `NotImplemented` unlike binary ops. - Ternary pow() does NOT try the right operand's `__rpow__` first if it's a subclass of the left operand and redefines `__rpow__` unlike binary ops. Add in the fact it's allowed and common to only define `__(r|i)pow__` to take two arguments (actually mypy won't let you define `__rpow__` to take three arguments) and the patch becomes frustratingly non-trivial. Towards mypyc/mypyc#553. Fixes mypyc/mypyc#907.
1 parent d586070 commit bcf60ac

File tree

8 files changed

+297
-25
lines changed

8 files changed

+297
-25
lines changed

mypyc/codegen/emitclass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
generate_dunder_wrapper,
1414
generate_get_wrapper,
1515
generate_hash_wrapper,
16+
generate_ipow_wrapper,
1617
generate_len_wrapper,
1718
generate_richcompare_wrapper,
1819
generate_set_del_item_wrapper,
@@ -109,6 +110,11 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
109110
"__ior__": ("nb_inplace_or", generate_dunder_wrapper),
110111
"__ixor__": ("nb_inplace_xor", generate_dunder_wrapper),
111112
"__imatmul__": ("nb_inplace_matrix_multiply", generate_dunder_wrapper),
113+
# Ternary operations. (yes, really)
114+
# These are special cased in generate_bin_op_wrapper().
115+
"__pow__": ("nb_power", generate_bin_op_wrapper),
116+
"__rpow__": ("nb_power", generate_bin_op_wrapper),
117+
"__ipow__": ("nb_inplace_power", generate_ipow_wrapper),
112118
}
113119

114120
AS_ASYNC_SLOT_DEFS: SlotTable = {

mypyc/codegen/emitwrapper.py

Lines changed: 88 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,32 @@ def generate_dunder_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
301301
return gen.wrapper_name()
302302

303303

304+
def generate_ipow_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
305+
"""Generate a wrapper for native __ipow__.
306+
307+
Since __ipow__ fills a ternary slot, but almost no one defines __ipow__ to take three
308+
arguments, the wrapper needs to tweaked to force it to accept three arguments.
309+
"""
310+
gen = WrapperGenerator(cl, emitter)
311+
gen.set_target(fn)
312+
assert len(fn.args) in (2, 3), "__ipow__ should only take 2 or 3 arguments"
313+
gen.arg_names = ["self", "exp", "mod"]
314+
gen.emit_header()
315+
gen.emit_arg_processing()
316+
handle_third_pow_argument(
317+
fn,
318+
emitter,
319+
gen,
320+
if_unsupported=[
321+
'PyErr_SetString(PyExc_TypeError, "__ipow__ takes 2 positional arguments but 3 were given");',
322+
"return NULL;",
323+
],
324+
)
325+
gen.emit_call()
326+
gen.finish()
327+
return gen.wrapper_name()
328+
329+
304330
def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
305331
"""Generates a wrapper for a native binary dunder method.
306332
@@ -311,13 +337,16 @@ def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
311337
"""
312338
gen = WrapperGenerator(cl, emitter)
313339
gen.set_target(fn)
314-
gen.arg_names = ["left", "right"]
340+
if fn.name in ("__pow__", "__rpow__"):
341+
gen.arg_names = ["left", "right", "mod"]
342+
else:
343+
gen.arg_names = ["left", "right"]
315344
wrapper_name = gen.wrapper_name()
316345

317346
gen.emit_header()
318347
if fn.name not in reverse_op_methods and fn.name in reverse_op_method_names:
319348
# There's only a reverse operator method.
320-
generate_bin_op_reverse_only_wrapper(emitter, gen)
349+
generate_bin_op_reverse_only_wrapper(fn, emitter, gen)
321350
else:
322351
rmethod = reverse_op_methods[fn.name]
323352
fn_rev = cl.get_method(rmethod)
@@ -334,6 +363,7 @@ def generate_bin_op_forward_only_wrapper(
334363
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
335364
) -> None:
336365
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
366+
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
337367
gen.emit_call(not_implemented_handler="goto typefail;")
338368
gen.emit_error_handling()
339369
emitter.emit_label("typefail")
@@ -352,19 +382,16 @@ def generate_bin_op_forward_only_wrapper(
352382
# if not isinstance(other, int):
353383
# return NotImplemented
354384
# ...
355-
rmethod = reverse_op_methods[fn.name]
356-
emitter.emit_line(f"_Py_IDENTIFIER({rmethod});")
357-
emitter.emit_line(
358-
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
359-
op_methods_to_symbols[fn.name], rmethod
360-
)
361-
)
385+
generate_bin_op_reverse_dunder_call(fn, emitter, reverse_op_methods[fn.name])
362386
gen.finish()
363387

364388

365-
def generate_bin_op_reverse_only_wrapper(emitter: Emitter, gen: WrapperGenerator) -> None:
389+
def generate_bin_op_reverse_only_wrapper(
390+
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
391+
) -> None:
366392
gen.arg_names = ["right", "left"]
367393
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
394+
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
368395
gen.emit_call()
369396
gen.emit_error_handling()
370397
emitter.emit_label("typefail")
@@ -390,7 +417,14 @@ def generate_bin_op_both_wrappers(
390417
)
391418
)
392419
gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
393-
gen.emit_call(not_implemented_handler="goto typefail;")
420+
handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail2;"])
421+
# Ternary __rpow__ calls aren't a thing so immediately bail
422+
# if ternary __pow__ returns NotImplemented.
423+
if fn.name == "__pow__" and len(fn.args) == 3:
424+
fwd_not_implemented_handler = "goto typefail2;"
425+
else:
426+
fwd_not_implemented_handler = "goto typefail;"
427+
gen.emit_call(not_implemented_handler=fwd_not_implemented_handler)
394428
gen.emit_error_handling()
395429
emitter.emit_line("}")
396430
emitter.emit_label("typefail")
@@ -402,22 +436,59 @@ def generate_bin_op_both_wrappers(
402436
gen.set_target(fn_rev)
403437
gen.arg_names = ["right", "left"]
404438
gen.emit_arg_processing(error=GotoHandler("typefail2"), raise_exception=False)
439+
handle_third_pow_argument(fn_rev, emitter, gen, if_unsupported=["goto typefail2;"])
405440
gen.emit_call()
406441
gen.emit_error_handling()
407442
emitter.emit_line("} else {")
408-
emitter.emit_line(f"_Py_IDENTIFIER({fn_rev.name});")
409-
emitter.emit_line(
410-
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
411-
op_methods_to_symbols[fn.name], fn_rev.name
412-
)
413-
)
443+
generate_bin_op_reverse_dunder_call(fn, emitter, fn_rev.name)
414444
emitter.emit_line("}")
415445
emitter.emit_label("typefail2")
416446
emitter.emit_line("Py_INCREF(Py_NotImplemented);")
417447
emitter.emit_line("return Py_NotImplemented;")
418448
gen.finish()
419449

420450

451+
def generate_bin_op_reverse_dunder_call(fn: FuncIR, emitter: Emitter, rmethod: str) -> None:
452+
if fn.name in ("__pow__", "__rpow__"):
453+
# Ternary pow() will never call the reverse dunder.
454+
emitter.emit_line("if (obj_mod == Py_None) {")
455+
emitter.emit_line(f"_Py_IDENTIFIER({rmethod});")
456+
emitter.emit_line(
457+
'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format(
458+
op_methods_to_symbols[fn.name], rmethod
459+
)
460+
)
461+
if fn.name in ("__pow__", "__rpow__"):
462+
emitter.emit_line("} else {")
463+
emitter.emit_line("Py_INCREF(Py_NotImplemented);")
464+
emitter.emit_line("return Py_NotImplemented;")
465+
emitter.emit_line("}")
466+
467+
468+
def handle_third_pow_argument(
469+
fn: FuncIR, emitter: Emitter, gen: WrapperGenerator, *, if_unsupported: list[str]
470+
) -> None:
471+
if fn.name not in ("__pow__", "__rpow__", "__ipow__"):
472+
return
473+
474+
if (fn.name in ("__pow__", "__ipow__") and len(fn.args) == 2) or fn.name == "__rpow__":
475+
# If the power dunder only supports two arguments and the third
476+
# argument (AKA mod) is set to a non-default value, simply bail.
477+
#
478+
# Importantly, this prevents any ternary __rpow__ calls from
479+
# happening (as per the language specification).
480+
emitter.emit_line("if (obj_mod != Py_None) {")
481+
for line in if_unsupported:
482+
emitter.emit_line(line)
483+
emitter.emit_line("}")
484+
# The slot wrapper will receive three arguments, but the call only
485+
# supports two so make sure that the third argument isn't passed
486+
# along. This is needed as two-argument __(i)pow__ is allowed and
487+
# rather common.
488+
if len(gen.arg_names) == 3:
489+
gen.arg_names.pop()
490+
491+
421492
RICHCOMPARE_OPS = {
422493
"__lt__": "Py_LT",
423494
"__gt__": "Py_GT",

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ CPyTagged CPyObject_Hash(PyObject *o);
344344
PyObject *CPyObject_GetAttr3(PyObject *v, PyObject *name, PyObject *defl);
345345
PyObject *CPyIter_Next(PyObject *iter);
346346
PyObject *CPyNumber_Power(PyObject *base, PyObject *index);
347+
PyObject *CPyNumber_InPlacePower(PyObject *base, PyObject *index);
347348
PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);
348349

349350

mypyc/lib-rt/generic_ops.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ PyObject *CPyNumber_Power(PyObject *base, PyObject *index)
4141
return PyNumber_Power(base, index, Py_None);
4242
}
4343

44+
PyObject *CPyNumber_InPlacePower(PyObject *base, PyObject *index)
45+
{
46+
return PyNumber_InPlacePower(base, index, Py_None);
47+
}
48+
4449
PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
4550
PyObject *start_obj = CPyTagged_AsObject(start);
4651
PyObject *end_obj = CPyTagged_AsObject(end);

mypyc/primitives/generic_ops.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,25 @@
109109
priority=0,
110110
)
111111

112-
binary_op(
113-
name="**",
114-
arg_types=[object_rprimitive, object_rprimitive],
115-
return_type=object_rprimitive,
116-
error_kind=ERR_MAGIC,
117-
c_function_name="CPyNumber_Power",
118-
priority=0,
119-
)
112+
for op, c_function in (("**", "CPyNumber_Power"), ("**=", "CPyNumber_InPlacePower")):
113+
binary_op(
114+
name=op,
115+
arg_types=[object_rprimitive, object_rprimitive],
116+
return_type=object_rprimitive,
117+
error_kind=ERR_MAGIC,
118+
c_function_name=c_function,
119+
priority=0,
120+
)
121+
122+
for arg_count, c_function in ((2, "CPyNumber_Power"), (3, "PyNumber_Power")):
123+
function_op(
124+
name="builtins.pow",
125+
arg_types=[object_rprimitive] * arg_count,
126+
return_type=object_rprimitive,
127+
error_kind=ERR_MAGIC,
128+
c_function_name=c_function,
129+
priority=0,
130+
)
120131

121132
binary_op(
122133
name="in",

mypyc/test-data/fixtures/ir.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,21 @@ def __divmod__(self, other: T_contra) -> T_co: ...
2222
class __SupportsRDivMod(Protocol[T_contra, T_co]):
2323
def __rdivmod__(self, other: T_contra) -> T_co: ...
2424

25+
_M = TypeVar("_M", contravariant=True)
26+
27+
class __SupportsPow2(Protocol[T_contra, T_co]):
28+
def __pow__(self, other: T_contra) -> T_co: ...
29+
30+
class __SupportsPow3NoneOnly(Protocol[T_contra, T_co]):
31+
def __pow__(self, other: T_contra, modulo: None = ...) -> T_co: ...
32+
33+
class __SupportsPow3(Protocol[T_contra, _M, T_co]):
34+
def __pow__(self, other: T_contra, modulo: _M) -> T_co: ...
35+
36+
__SupportsSomeKindOfPow = Union[
37+
__SupportsPow2[Any, Any], __SupportsPow3NoneOnly[Any, Any] | __SupportsPow3[Any, Any, Any]
38+
]
39+
2540
class object:
2641
def __init__(self) -> None: pass
2742
def __eq__(self, x: object) -> bool: pass
@@ -99,6 +114,7 @@ def __add__(self, n: float) -> float: pass
99114
def __sub__(self, n: float) -> float: pass
100115
def __mul__(self, n: float) -> float: pass
101116
def __truediv__(self, n: float) -> float: pass
117+
def __pow__(self, n: float) -> float: pass
102118
def __neg__(self) -> float: pass
103119
def __pos__(self) -> float: pass
104120
def __abs__(self) -> float: pass
@@ -318,6 +334,12 @@ def abs(x: __SupportsAbs[T]) -> T: ...
318334
def divmod(x: __SupportsDivMod[T_contra, T_co], y: T_contra) -> T_co: ...
319335
@overload
320336
def divmod(x: T_contra, y: __SupportsRDivMod[T_contra, T_co]) -> T_co: ...
337+
@overload
338+
def pow(base: __SupportsPow2[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
339+
@overload
340+
def pow(base: __SupportsPow3NoneOnly[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
341+
@overload
342+
def pow(base: __SupportsPow3[T_contra, _M, T_co], exp: T_contra, mod: _M) -> T_co: ...
321343
def exit() -> None: ...
322344
def min(x: T, y: T) -> T: ...
323345
def max(x: T, y: T) -> T: ...

mypyc/test-data/irbuild-any.test

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ L0:
201201
[case testFunctionBasedOps]
202202
def f() -> None:
203203
a = divmod(5, 2)
204+
def f2() -> int:
205+
return pow(2, 5)
206+
def f3() -> float:
207+
return pow(2, 5, 3)
204208
[out]
205209
def f():
206210
r0, r1, r2 :: object
@@ -212,4 +216,25 @@ L0:
212216
r3 = unbox(tuple[float, float], r2)
213217
a = r3
214218
return 1
219+
def f2():
220+
r0, r1, r2 :: object
221+
r3 :: int
222+
L0:
223+
r0 = object 2
224+
r1 = object 5
225+
r2 = CPyNumber_Power(r0, r1)
226+
r3 = unbox(int, r2)
227+
return r3
228+
def f3():
229+
r0, r1, r2, r3 :: object
230+
r4 :: int
231+
r5 :: object
232+
L0:
233+
r0 = object 2
234+
r1 = object 5
235+
r2 = object 3
236+
r3 = PyNumber_Power(r0, r1, r2)
237+
r4 = unbox(int, r3)
238+
r5 = box(int, r4)
239+
return r5
215240

0 commit comments

Comments
 (0)