Skip to content

[mypyc] Implement async for as a statement and in comprehensions #13444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,31 +873,24 @@ def _visit_display(


def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value:
if any(o.generator.is_async):
builder.error("async comprehensions are unimplemented", o.line)
return translate_list_comprehension(builder, o.generator)


def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value:
if any(o.generator.is_async):
builder.error("async comprehensions are unimplemented", o.line)
return translate_set_comprehension(builder, o.generator)


def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value:
if any(o.is_async):
builder.error("async comprehensions are unimplemented", o.line)

d = builder.call_c(dict_new_op, [], o.line)
loop_params = list(zip(o.indices, o.sequences, o.condlists))
d = builder.maybe_spill(builder.call_c(dict_new_op, [], o.line))
loop_params = list(zip(o.indices, o.sequences, o.condlists, o.is_async))

def gen_inner_stmts() -> None:
k = builder.accept(o.key)
v = builder.accept(o.value)
builder.call_c(dict_set_item_op, [d, k, v], o.line)
builder.call_c(dict_set_item_op, [builder.read(d), k, v], o.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
return d
return builder.read(d)


# Misc
Expand All @@ -915,9 +908,6 @@ def get_arg(arg: Expression | None) -> Value:


def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value:
if any(o.is_async):
builder.error("async comprehensions are unimplemented", o.line)

builder.warning("Treating generator comprehension as list", o.line)
return builder.call_c(iter_op, [translate_list_comprehension(builder, o)], o.line)

Expand Down
132 changes: 115 additions & 17 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,30 @@
TupleExpr,
TypeAlias,
)
from mypyc.ir.ops import BasicBlock, Branch, Integer, IntOp, Register, TupleGet, TupleSet, Value
from mypyc.ir.ops import (
BasicBlock,
Branch,
Integer,
IntOp,
LoadAddress,
LoadMem,
Register,
TupleGet,
TupleSet,
Value,
)
from mypyc.ir.rtypes import (
RTuple,
RType,
bool_rprimitive,
int_rprimitive,
is_dict_rprimitive,
is_list_rprimitive,
is_sequence_rprimitive,
is_short_int_rprimitive,
is_str_rprimitive,
is_tuple_rprimitive,
pointer_rprimitive,
short_int_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder
Expand All @@ -45,8 +58,9 @@
dict_value_iter_op,
)
from mypyc.primitives.exc_ops import no_err_occurred_op
from mypyc.primitives.generic_ops import iter_op, next_op
from mypyc.primitives.generic_ops import aiter_op, anext_op, iter_op, next_op
from mypyc.primitives.list_ops import list_append_op, list_get_item_unsafe_op, new_list_set_item_op
from mypyc.primitives.misc_ops import stop_async_iteration_op
from mypyc.primitives.registry import CFunctionDescription
from mypyc.primitives.set_ops import set_add_op

Expand All @@ -59,6 +73,7 @@ def for_loop_helper(
expr: Expression,
body_insts: GenFunc,
else_insts: GenFunc | None,
is_async: bool,
line: int,
) -> None:
"""Generate IR for a loop.
Expand All @@ -81,7 +96,9 @@ def for_loop_helper(
# Determine where we want to exit, if our condition check fails.
normal_loop_exit = else_block if else_insts is not None else exit_block

for_gen = make_for_loop_generator(builder, index, expr, body_block, normal_loop_exit, line)
for_gen = make_for_loop_generator(
builder, index, expr, body_block, normal_loop_exit, line, is_async=is_async
)

builder.push_loop_stack(step_block, exit_block)
condition_block = BasicBlock()
Expand Down Expand Up @@ -220,32 +237,33 @@ def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Valu
if val is not None:
return val

list_ops = builder.new_list_op([], gen.line)
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
list_ops = builder.maybe_spill(builder.new_list_op([], gen.line))

loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))

def gen_inner_stmts() -> None:
e = builder.accept(gen.left_expr)
builder.call_c(list_append_op, [list_ops, e], gen.line)
builder.call_c(list_append_op, [builder.read(list_ops), e], gen.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
return list_ops
return builder.read(list_ops)


def translate_set_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
set_ops = builder.new_set_op([], gen.line)
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
set_ops = builder.maybe_spill(builder.new_set_op([], gen.line))
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))

def gen_inner_stmts() -> None:
e = builder.accept(gen.left_expr)
builder.call_c(set_add_op, [set_ops, e], gen.line)
builder.call_c(set_add_op, [builder.read(set_ops), e], gen.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
return set_ops
return builder.read(set_ops)


def comprehension_helper(
builder: IRBuilder,
loop_params: list[tuple[Lvalue, Expression, list[Expression]]],
loop_params: list[tuple[Lvalue, Expression, list[Expression], bool]],
gen_inner_stmts: Callable[[], None],
line: int,
) -> None:
Expand All @@ -260,20 +278,26 @@ def comprehension_helper(
gen_inner_stmts: function to generate the IR for the body of the innermost loop
"""

def handle_loop(loop_params: list[tuple[Lvalue, Expression, list[Expression]]]) -> None:
def handle_loop(loop_params: list[tuple[Lvalue, Expression, list[Expression], bool]]) -> None:
"""Generate IR for a loop.

Given a list of (index, expression, [conditions]) tuples, generate IR
for the nested loops the list defines.
"""
index, expr, conds = loop_params[0]
index, expr, conds, is_async = loop_params[0]
for_loop_helper(
builder, index, expr, lambda: loop_contents(conds, loop_params[1:]), None, line
builder,
index,
expr,
lambda: loop_contents(conds, loop_params[1:]),
None,
is_async=is_async,
line=line,
)

def loop_contents(
conds: list[Expression],
remaining_loop_params: list[tuple[Lvalue, Expression, list[Expression]]],
remaining_loop_params: list[tuple[Lvalue, Expression, list[Expression], bool]],
) -> None:
"""Generate the body of the loop.

Expand Down Expand Up @@ -319,13 +343,23 @@ def make_for_loop_generator(
body_block: BasicBlock,
loop_exit: BasicBlock,
line: int,
is_async: bool = False,
nested: bool = False,
) -> ForGenerator:
"""Return helper object for generating a for loop over an iterable.

If "nested" is True, this is a nested iterator such as "e" in "enumerate(e)".
"""

# Do an async loop if needed. async is always generic
if is_async:
expr_reg = builder.accept(expr)
async_obj = ForAsyncIterable(builder, index, body_block, loop_exit, line, nested)
item_type = builder._analyze_iterable_item_type(expr)
item_rtype = builder.type_to_rtype(item_type)
async_obj.init(expr_reg, item_rtype)
return async_obj

rtyp = builder.node_type(expr)
if is_sequence_rprimitive(rtyp):
# Special case "for x in <list>".
Expand Down Expand Up @@ -500,7 +534,7 @@ def load_len(self, expr: Value | AssignmentTarget) -> Value:


class ForIterable(ForGenerator):
"""Generate IR for a for loop over an arbitrary iterable (the normal case)."""
"""Generate IR for a for loop over an arbitrary iterable (the general case)."""

def need_cleanup(self) -> bool:
# Create a new cleanup block for when the loop is finished.
Expand Down Expand Up @@ -548,6 +582,70 @@ def gen_cleanup(self) -> None:
self.builder.call_c(no_err_occurred_op, [], self.line)


class ForAsyncIterable(ForGenerator):
"""Generate IR for an async for loop."""

def init(self, expr_reg: Value, target_type: RType) -> None:
# Define targets to contain the expression, along with the
# iterator that will be used for the for-loop. We are inside
# of a generator function, so we will spill these into
# environment class.
builder = self.builder
iter_reg = builder.call_c(aiter_op, [expr_reg], self.line)
builder.maybe_spill(expr_reg)
self.iter_target = builder.maybe_spill(iter_reg)
self.target_type = target_type
self.stop_reg = Register(bool_rprimitive)

def gen_condition(self) -> None:
# This does the test and fetches the next value
# try:
# TARGET = await type(iter).__anext__(iter)
# stop = False
# except StopAsyncIteration:
# stop = True
#
# What a pain.
# There are optimizations available here if we punch through some abstractions.

from mypyc.irbuild.statement import emit_await, transform_try_except

builder = self.builder
line = self.line

def except_match() -> Value:
addr = builder.add(LoadAddress(pointer_rprimitive, stop_async_iteration_op.src, line))
return builder.add(LoadMem(stop_async_iteration_op.type, addr))

def try_body() -> None:
awaitable = builder.call_c(anext_op, [builder.read(self.iter_target)], line)
self.next_reg = emit_await(builder, awaitable, line)
builder.assign(self.stop_reg, builder.false(), -1)

def except_body() -> None:
builder.assign(self.stop_reg, builder.true(), line)

transform_try_except(
builder, try_body, [((except_match, line), None, except_body)], None, line
)

builder.add(Branch(self.stop_reg, self.loop_exit, self.body_block, Branch.BOOL))

def begin_body(self) -> None:
# Assign the value obtained from await __anext__ to the
# lvalue so that it can be referenced by code in the body of the loop.
builder = self.builder
line = self.line
# We unbox here so that iterating with tuple unpacking generates a tuple based
# unpack instead of an iterator based one.
next_reg = builder.coerce(self.next_reg, self.target_type, line)
builder.assign(builder.get_assignment_target(self.index), next_reg, line)

def gen_step(self) -> None:
# Nothing to do here, since we get the next item as part of gen_condition().
pass


def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> Value:
"""Emit a potentially unsafe index into a target."""
# This doesn't really fit nicely into any of our data-driven frameworks
Expand Down
8 changes: 5 additions & 3 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def any_all_helper(
) -> Value:
retval = Register(bool_rprimitive)
builder.assign(retval, initial_value(), -1)
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
true_block, false_block, exit_block = BasicBlock(), BasicBlock(), BasicBlock()

def gen_inner_stmts() -> None:
Expand Down Expand Up @@ -417,7 +417,9 @@ def gen_inner_stmts() -> None:
call_expr = builder.accept(gen_expr.left_expr)
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)

loop_params = list(zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists))
loop_params = list(
zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async)
)
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)

return retval
Expand Down Expand Up @@ -467,7 +469,7 @@ def gen_inner_stmts() -> None:
builder.assign(retval, builder.accept(gen.left_expr), gen.left_expr.line)
builder.goto(exit_block)

loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)

# Now we need the case for when nothing got hit. If there was
Expand Down
19 changes: 12 additions & 7 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
)

GenFunc = Callable[[], None]
ValueGenFunc = Callable[[], Value]


def transform_block(builder: IRBuilder, block: Block) -> None:
Expand Down Expand Up @@ -327,17 +328,16 @@ def transform_while_stmt(builder: IRBuilder, s: WhileStmt) -> None:


def transform_for_stmt(builder: IRBuilder, s: ForStmt) -> None:
if s.is_async:
builder.error("async for is unimplemented", s.line)

def body() -> None:
builder.accept(s.body)

def else_block() -> None:
assert s.else_body is not None
builder.accept(s.else_body)

for_loop_helper(builder, s.index, s.expr, body, else_block if s.else_body else None, s.line)
for_loop_helper(
builder, s.index, s.expr, body, else_block if s.else_body else None, s.is_async, s.line
)


def transform_break_stmt(builder: IRBuilder, node: BreakStmt) -> None:
Expand All @@ -362,7 +362,7 @@ def transform_raise_stmt(builder: IRBuilder, s: RaiseStmt) -> None:
def transform_try_except(
builder: IRBuilder,
body: GenFunc,
handlers: Sequence[tuple[Expression | None, Expression | None, GenFunc]],
handlers: Sequence[tuple[tuple[ValueGenFunc, int] | None, Expression | None, GenFunc]],
else_body: GenFunc | None,
line: int,
) -> None:
Expand Down Expand Up @@ -399,8 +399,9 @@ def transform_try_except(
for type, var, handler_body in handlers:
next_block = None
if type:
type_f, type_line = type
next_block, body_block = BasicBlock(), BasicBlock()
matches = builder.call_c(exc_matches_op, [builder.accept(type)], type.line)
matches = builder.call_c(exc_matches_op, [type_f()], type_line)
builder.add(Branch(matches, body_block, next_block, Branch.BOOL))
builder.activate_block(body_block)
if var:
Expand Down Expand Up @@ -451,8 +452,12 @@ def body() -> None:
def make_handler(body: Block) -> GenFunc:
return lambda: builder.accept(body)

def make_entry(type: Expression) -> tuple[ValueGenFunc, int]:
return (lambda: builder.accept(type), type.line)

handlers = [
(type, var, make_handler(body)) for type, var, body in zip(t.types, t.vars, t.handlers)
(make_entry(type) if type else None, var, make_handler(body))
for type, var, body in zip(t.types, t.vars, t.handlers)
]
else_body = (lambda: builder.accept(t.else_body)) if t.else_body else None
transform_try_except(builder, body, handlers, else_body, t.line)
Expand Down
4 changes: 4 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,

PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
PyObject *func);

PyObject *CPy_GetAIter(PyObject *obj);
PyObject *CPy_GetANext(PyObject *aiter);

#ifdef __cplusplus
}
#endif
Expand Down
Loading