Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 77c40f8

Browse files
committed
Implement type inference for generator and yield expressions
1 parent aeeb9e0 commit 77c40f8

File tree

5 files changed

+78
-18
lines changed

5 files changed

+78
-18
lines changed

crates/hir-ty/src/db.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
116116
fn intern_impl_trait_id(&self, id: ImplTraitId) -> InternedOpaqueTyId;
117117
#[salsa::interned]
118118
fn intern_closure(&self, id: (DefWithBodyId, ExprId)) -> InternedClosureId;
119+
#[salsa::interned]
120+
fn intern_generator(&self, id: (DefWithBodyId, ExprId)) -> InternedGeneratorId;
119121

120122
#[salsa::invoke(chalk_db::associated_ty_data_query)]
121123
fn associated_ty_data(&self, id: chalk_db::AssocTypeId) -> Arc<chalk_db::AssociatedTyDatum>;
@@ -218,6 +220,10 @@ impl_intern_key!(InternedOpaqueTyId);
218220
pub struct InternedClosureId(salsa::InternId);
219221
impl_intern_key!(InternedClosureId);
220222

223+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
224+
pub struct InternedGeneratorId(salsa::InternId);
225+
impl_intern_key!(InternedGeneratorId);
226+
221227
/// This exists just for Chalk, because Chalk just has a single `FnDefId` where
222228
/// we have different IDs for struct and enum variant constructors.
223229
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]

crates/hir-ty/src/infer.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ pub struct InferenceResult {
332332
/// unresolved or missing subpatterns or subpatterns of mismatched types.
333333
pub type_of_pat: ArenaMap<PatId, Ty>,
334334
type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>,
335-
/// Interned Unknown to return references to.
335+
/// Interned common types to return references to.
336336
standard_types: InternedStandardTypes,
337337
/// Stores the types which were implicitly dereferenced in pattern binding modes.
338338
pub pat_adjustments: FxHashMap<PatId, Vec<Ty>>,
@@ -412,6 +412,8 @@ pub(crate) struct InferenceContext<'a> {
412412
/// closures, but currently this is the only field that will change there,
413413
/// so it doesn't make sense.
414414
return_ty: Ty,
415+
/// The resume type and the yield type, respectively, of the generator being inferred.
416+
resume_yield_tys: Option<(Ty, Ty)>,
415417
diverges: Diverges,
416418
breakables: Vec<BreakableContext>,
417419
}
@@ -476,6 +478,7 @@ impl<'a> InferenceContext<'a> {
476478
table: unify::InferenceTable::new(db, trait_env.clone()),
477479
trait_env,
478480
return_ty: TyKind::Error.intern(Interner), // set in collect_fn_signature
481+
resume_yield_tys: None,
479482
db,
480483
owner,
481484
body,

crates/hir-ty/src/infer/closure.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::{
1212
use super::{Expectation, InferenceContext};
1313

1414
impl InferenceContext<'_> {
15+
// This function handles both closures and generators.
1516
pub(super) fn deduce_closure_type_from_expectations(
1617
&mut self,
1718
closure_expr: ExprId,
@@ -27,6 +28,11 @@ impl InferenceContext<'_> {
2728
// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
2829
let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
2930

31+
// Generators are not Fn* so return early.
32+
if matches!(closure_ty.kind(Interner), TyKind::Generator(..)) {
33+
return;
34+
}
35+
3036
// Deduction based on the expected `dyn Fn` is done separately.
3137
if let TyKind::Dyn(dyn_ty) = expected_ty.kind(Interner) {
3238
if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) {

crates/hir-ty/src/infer/expr.rs

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ use chalk_ir::{
1010
cast::Cast, fold::Shift, DebruijnIndex, GenericArgData, Mutability, TyVariableKind,
1111
};
1212
use hir_def::{
13-
expr::{ArithOp, Array, BinaryOp, CmpOp, Expr, ExprId, LabelId, Literal, Statement, UnaryOp},
13+
expr::{
14+
ArithOp, Array, BinaryOp, ClosureKind, CmpOp, Expr, ExprId, LabelId, Literal, Statement,
15+
UnaryOp,
16+
},
1417
generics::TypeOrConstParamData,
1518
path::{GenericArg, GenericArgs},
1619
resolver::resolver_for_expr,
@@ -216,7 +219,7 @@ impl<'a> InferenceContext<'a> {
216219
self.diverges = Diverges::Maybe;
217220
TyBuilder::unit()
218221
}
219-
Expr::Closure { body, args, ret_type, arg_types, closure_kind: _ } => {
222+
Expr::Closure { body, args, ret_type, arg_types, closure_kind } => {
220223
assert_eq!(args.len(), arg_types.len());
221224

222225
let mut sig_tys = Vec::new();
@@ -244,20 +247,40 @@ impl<'a> InferenceContext<'a> {
244247
),
245248
})
246249
.intern(Interner);
247-
let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
248-
let closure_ty =
249-
TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone()))
250-
.intern(Interner);
250+
251+
let (ty, resume_yield_tys) = if matches!(closure_kind, ClosureKind::Generator(_)) {
252+
// FIXME: report error when there are more than 1 parameter.
253+
let resume_ty = match sig_tys.first() {
254+
// When `sig_tys.len() == 1` the first type is the return type, not the
255+
// first parameter type.
256+
Some(ty) if sig_tys.len() > 1 => ty.clone(),
257+
_ => self.result.standard_types.unit.clone(),
258+
};
259+
let yield_ty = self.table.new_type_var();
260+
261+
let subst = TyBuilder::subst_for_generator(self.db, self.owner)
262+
.push(resume_ty.clone())
263+
.push(yield_ty.clone())
264+
.push(ret_ty.clone())
265+
.build();
266+
267+
let generator_id = self.db.intern_generator((self.owner, tgt_expr)).into();
268+
let generator_ty = TyKind::Generator(generator_id, subst).intern(Interner);
269+
270+
(generator_ty, Some((resume_ty, yield_ty)))
271+
} else {
272+
let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
273+
let closure_ty =
274+
TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone()))
275+
.intern(Interner);
276+
277+
(closure_ty, None)
278+
};
251279

252280
// Eagerly try to relate the closure type with the expected
253281
// type, otherwise we often won't have enough information to
254282
// infer the body.
255-
self.deduce_closure_type_from_expectations(
256-
tgt_expr,
257-
&closure_ty,
258-
&sig_ty,
259-
expected,
260-
);
283+
self.deduce_closure_type_from_expectations(tgt_expr, &ty, &sig_ty, expected);
261284

262285
// Now go through the argument patterns
263286
for (arg_pat, arg_ty) in args.iter().zip(sig_tys) {
@@ -266,15 +289,18 @@ impl<'a> InferenceContext<'a> {
266289

267290
let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
268291
let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone());
292+
let prev_resume_yield_tys =
293+
mem::replace(&mut self.resume_yield_tys, resume_yield_tys);
269294

270295
self.with_breakable_ctx(BreakableKind::Border, self.err_ty(), None, |this| {
271296
this.infer_expr_coerce(*body, &Expectation::has_type(ret_ty));
272297
});
273298

274299
self.diverges = prev_diverges;
275300
self.return_ty = prev_ret_ty;
301+
self.resume_yield_tys = prev_resume_yield_tys;
276302

277-
closure_ty
303+
ty
278304
}
279305
Expr::Call { callee, args, .. } => {
280306
let callee_ty = self.infer_expr(*callee, &Expectation::none());
@@ -423,11 +449,18 @@ impl<'a> InferenceContext<'a> {
423449
TyKind::Never.intern(Interner)
424450
}
425451
Expr::Yield { expr } => {
426-
// FIXME: track yield type for coercion
427-
if let Some(expr) = expr {
428-
self.infer_expr(*expr, &Expectation::none());
452+
if let Some((resume_ty, yield_ty)) = self.resume_yield_tys.clone() {
453+
if let Some(expr) = expr {
454+
self.infer_expr_coerce(*expr, &Expectation::has_type(yield_ty));
455+
} else {
456+
let unit = self.result.standard_types.unit.clone();
457+
let _ = self.coerce(Some(tgt_expr), &unit, &yield_ty);
458+
}
459+
resume_ty
460+
} else {
461+
// FIXME: report error (yield expr in non-generator)
462+
TyKind::Error.intern(Interner)
429463
}
430-
TyKind::Never.intern(Interner)
431464
}
432465
Expr::RecordLit { path, fields, spread, .. } => {
433466
let (ty, def_id) = self.resolve_variant(path.as_deref(), false);

crates/hir-ty/src/mapping.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,18 @@ impl From<crate::db::InternedClosureId> for chalk_ir::ClosureId<Interner> {
103103
}
104104
}
105105

106+
impl From<chalk_ir::GeneratorId<Interner>> for crate::db::InternedGeneratorId {
107+
fn from(id: chalk_ir::GeneratorId<Interner>) -> Self {
108+
Self::from_intern_id(id.0)
109+
}
110+
}
111+
112+
impl From<crate::db::InternedGeneratorId> for chalk_ir::GeneratorId<Interner> {
113+
fn from(id: crate::db::InternedGeneratorId) -> Self {
114+
chalk_ir::GeneratorId(id.as_intern_id())
115+
}
116+
}
117+
106118
pub fn to_foreign_def_id(id: TypeAliasId) -> ForeignDefId {
107119
chalk_ir::ForeignDefId(salsa::InternKey::as_intern_id(&id))
108120
}

0 commit comments

Comments
 (0)