Skip to content

Desugar yield in async gen correctly, ensure gen always returns unit #119061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,12 +917,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
let poll_expr = {
let awaitee = self.expr_ident(span, awaitee_ident, awaitee_pat_hid);
let ref_mut_awaitee = self.expr_mut_addr_of(span, awaitee);
let task_context = if let Some(task_context_hid) = self.task_context {
self.expr_ident_mut(span, task_context_ident, task_context_hid)
} else {
// Use of `await` outside of an async context, we cannot use `task_context` here.
self.expr_err(span, self.tcx.sess.span_delayed_bug(span, "no task_context hir id"))

let Some(task_context_hid) = self.task_context else {
unreachable!("use of `await` outside of an async context.");
};

let task_context = self.expr_ident_mut(span, task_context_ident, task_context_hid);

let new_unchecked = self.expr_call_lang_item_fn_mut(
span,
hir::LangItem::PinNewUnchecked,
Expand Down Expand Up @@ -991,16 +992,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
);
let yield_expr = self.arena.alloc(yield_expr);

if let Some(task_context_hid) = self.task_context {
let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
let assign =
self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span)));
self.stmt_expr(span, assign)
} else {
// Use of `await` outside of an async context. Return `yield_expr` so that we can
// proceed with type checking.
self.stmt(span, hir::StmtKind::Semi(yield_expr))
}
let Some(task_context_hid) = self.task_context else {
unreachable!("use of `await` outside of an async context.");
};

let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
let assign =
self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span)));
self.stmt_expr(span, assign)
};

let loop_block = self.block_all(span, arena_vec![self; inner_match_stmt, yield_stmt], None);
Expand Down Expand Up @@ -1635,19 +1634,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
};

let mut yielded =
let yielded =
opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span));

if is_async_gen {
// yield async_gen_ready($expr);
yielded = self.expr_call_lang_item_fn(
// `yield $expr` is transformed into `task_context = yield async_gen_ready($expr)`.
// This ensures that we store our resumed `ResumeContext` correctly, and also that
// the apparent value of the `yield` expression is `()`.
let wrapped_yielded = self.expr_call_lang_item_fn(
span,
hir::LangItem::AsyncGenReady,
std::slice::from_ref(yielded),
);
}
let yield_expr = self.arena.alloc(
self.expr(span, hir::ExprKind::Yield(wrapped_yielded, hir::YieldSource::Yield)),
);

hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
let Some(task_context_hid) = self.task_context else {
unreachable!("use of `await` outside of an async context.");
};
let task_context_ident = Ident::with_dummy_span(sym::_task_context);
let lhs = self.expr_ident(span, task_context_ident, task_context_hid);

hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span))
} else {
hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
}
}

/// Desugar `ExprForLoop` from: `[opt_ident]: for <pat> in <head> <body>` into:
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,9 +650,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
},
)
}
// For a `gen {}` block created as a `gen fn` body, we need the return type to be
// ().
Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => self.tcx.types.unit,
// All `gen {}` and `async gen {}` must return unit.
Some(hir::CoroutineKind::Gen(_) | hir::CoroutineKind::AsyncGen(_)) => {
self.tcx.types.unit
}

_ => astconv.ty_infer(None, decl.output.span()),
},
Expand Down
17 changes: 17 additions & 0 deletions tests/ui/coroutine/async-gen-yield-ty-is-unit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// compile-flags: --edition 2024 -Zunstable-options
// check-pass

#![feature(async_iterator, gen_blocks, noop_waker)]

use std::{async_iter::AsyncIterator, pin::pin, task::{Context, Waker}};

async gen fn gen_fn() -> &'static str {
yield "hello"
}

pub fn main() {
let async_iterator = pin!(gen_fn());
let waker = Waker::noop();
let ctx = &mut Context::from_waker(&waker);
async_iterator.poll_next(ctx);
}
20 changes: 20 additions & 0 deletions tests/ui/coroutine/return-types-diverge.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// compile-flags: --edition 2024 -Zunstable-options
// check-pass

#![feature(gen_blocks)]

fn diverge() -> ! { loop {} }

async gen fn async_gen_fn() -> i32 { diverge() }

gen fn gen_fn() -> i32 { diverge() }

fn async_gen_block() {
async gen { yield (); diverge() };
}

fn gen_block() {
gen { yield (); diverge() };
}

fn main() {}
21 changes: 21 additions & 0 deletions tests/ui/coroutine/return-types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// compile-flags: --edition 2024 -Zunstable-options

#![feature(gen_blocks)]

async gen fn async_gen_fn() -> i32 { 0 }
//~^ ERROR mismatched types

gen fn gen_fn() -> i32 { 0 }
//~^ ERROR mismatched types

fn async_gen_block() {
async gen { yield (); 1 };
//~^ ERROR mismatched types
}

fn gen_block() {
gen { yield (); 1 };
//~^ ERROR mismatched types
}

fn main() {}
31 changes: 31 additions & 0 deletions tests/ui/coroutine/return-types.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
error[E0308]: mismatched types
--> $DIR/return-types.rs:5:38
|
LL | async gen fn async_gen_fn() -> i32 { 0 }
| --- ^ expected `()`, found integer
| |
| expected `()` because of return type

error[E0308]: mismatched types
--> $DIR/return-types.rs:8:26
|
LL | gen fn gen_fn() -> i32 { 0 }
| --- ^ expected `()`, found integer
| |
| expected `()` because of return type

error[E0308]: mismatched types
--> $DIR/return-types.rs:12:27
|
LL | async gen { yield (); 1 };
| ^ expected `()`, found integer

error[E0308]: mismatched types
--> $DIR/return-types.rs:17:21
|
LL | gen { yield (); 1 };
| ^ expected `()`, found integer

error: aborting due to 4 previous errors

For more information about this error, try `rustc --explain E0308`.