Skip to content

Commit 427896d

Browse files
Construct body for by-move coroutine closure output
1 parent fc4fff4 commit 427896d

File tree

24 files changed

+233
-15
lines changed

24 files changed

+233
-15
lines changed

compiler/rustc_borrowck/src/type_check/input_output.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
8585
self.tcx(),
8686
ty::CoroutineArgsParts {
8787
parent_args: args.parent_args(),
88+
kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()),
8889
resume_ty: next_ty_var(),
8990
yield_ty: next_ty_var(),
9091
witness: next_ty_var(),

compiler/rustc_const_eval/src/interpret/terminator.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
546546
| ty::InstanceDef::ReifyShim(..)
547547
| ty::InstanceDef::ClosureOnceShim { .. }
548548
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
549+
| ty::InstanceDef::CoroutineByMoveShim { .. }
549550
| ty::InstanceDef::FnPtrShim(..)
550551
| ty::InstanceDef::DropGlue(..)
551552
| ty::InstanceDef::CloneShim(..)

compiler/rustc_hir_typeck/src/callee.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
183183
coroutine_closure_sig.to_coroutine(
184184
self.tcx,
185185
closure_args.parent_args(),
186+
closure_args.kind_ty(),
186187
self.tcx.coroutine_for_closure(def_id),
187188
tupled_upvars_ty,
188189
),

compiler/rustc_hir_typeck/src/closure.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
175175
interior,
176176
));
177177

178+
let kind_ty = match kind {
179+
hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) => self
180+
.next_ty_var(TypeVariableOrigin {
181+
kind: TypeVariableOriginKind::ClosureSynthetic,
182+
span: expr_span,
183+
}),
184+
_ => tcx.types.unit,
185+
};
186+
178187
let coroutine_args = ty::CoroutineArgs::new(
179188
tcx,
180189
ty::CoroutineArgsParts {
181190
parent_args,
191+
kind_ty,
182192
resume_ty,
183193
yield_ty,
184194
return_ty: liberated_sig.output(),
@@ -256,6 +266,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
256266
sig.to_coroutine(
257267
tcx,
258268
parent_args,
269+
closure_kind_ty,
259270
tcx.coroutine_for_closure(expr_def_id),
260271
coroutine_upvars_ty,
261272
)

compiler/rustc_hir_typeck/src/upvar.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
393393
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
394394
coroutine_captures_by_ref_ty,
395395
);
396+
397+
let ty::Coroutine(_, args) = *self.typeck_results.borrow().expr_ty(body.value).kind()
398+
else {
399+
bug!();
400+
};
401+
self.demand_eqtype(
402+
span,
403+
args.as_coroutine().kind_ty(),
404+
Ty::from_closure_kind(self.tcx, closure_kind),
405+
);
396406
}
397407

398408
self.log_closure_min_capture_info(closure_def_id, span);

compiler/rustc_middle/src/mir/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ pub struct CoroutineInfo<'tcx> {
262262
/// Coroutine drop glue. This field is populated after the state transform pass.
263263
pub coroutine_drop: Option<Body<'tcx>>,
264264

265+
/// The body of the coroutine, modified to take its upvars by move.
266+
/// TODO:
267+
pub by_move_body: Option<Body<'tcx>>,
268+
265269
/// The layout of a coroutine. This field is populated after the state transform pass.
266270
pub coroutine_layout: Option<CoroutineLayout<'tcx>>,
267271

@@ -281,6 +285,7 @@ impl<'tcx> CoroutineInfo<'tcx> {
281285
coroutine_kind,
282286
yield_ty: Some(yield_ty),
283287
resume_ty: Some(resume_ty),
288+
by_move_body: None,
284289
coroutine_drop: None,
285290
coroutine_layout: None,
286291
}

compiler/rustc_middle/src/mir/mono.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> {
403403
| InstanceDef::Virtual(..)
404404
| InstanceDef::ClosureOnceShim { .. }
405405
| InstanceDef::ConstructCoroutineInClosureShim { .. }
406+
| InstanceDef::CoroutineByMoveShim { .. }
406407
| InstanceDef::DropGlue(..)
407408
| InstanceDef::CloneShim(..)
408409
| InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/mir/visit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ macro_rules! make_mir_visitor {
346346
ty::InstanceDef::ThreadLocalShim(_def_id) |
347347
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
348348
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
349+
ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } |
349350
ty::InstanceDef::DropGlue(_def_id, None) => {}
350351

351352
ty::InstanceDef::FnPtrShim(_def_id, ty) |

compiler/rustc_middle/src/ty/instance.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ pub enum InstanceDef<'tcx> {
101101
target_kind: ty::ClosureKind,
102102
},
103103

104+
/// TODO:
105+
CoroutineByMoveShim { coroutine_def_id: DefId },
106+
104107
/// Compiler-generated accessor for thread locals which returns a reference to the thread local
105108
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
106109
/// native support.
@@ -186,6 +189,7 @@ impl<'tcx> InstanceDef<'tcx> {
186189
coroutine_closure_def_id: def_id,
187190
target_kind: _,
188191
}
192+
| ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id }
189193
| InstanceDef::DropGlue(def_id, _)
190194
| InstanceDef::CloneShim(def_id, _)
191195
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -206,6 +210,7 @@ impl<'tcx> InstanceDef<'tcx> {
206210
| InstanceDef::Intrinsic(..)
207211
| InstanceDef::ClosureOnceShim { .. }
208212
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
213+
| ty::InstanceDef::CoroutineByMoveShim { .. }
209214
| InstanceDef::DropGlue(..)
210215
| InstanceDef::CloneShim(..)
211216
| InstanceDef::FnPtrAddrShim(..) => None,
@@ -302,6 +307,7 @@ impl<'tcx> InstanceDef<'tcx> {
302307
| InstanceDef::DropGlue(_, Some(_)) => false,
303308
InstanceDef::ClosureOnceShim { .. }
304309
| InstanceDef::ConstructCoroutineInClosureShim { .. }
310+
| InstanceDef::CoroutineByMoveShim { .. }
305311
| InstanceDef::DropGlue(..)
306312
| InstanceDef::Item(_)
307313
| InstanceDef::Intrinsic(..)
@@ -340,6 +346,7 @@ fn fmt_instance(
340346
InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
341347
InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
342348
InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
349+
InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"),
343350
InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
344351
InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
345352
InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),
@@ -631,7 +638,19 @@ impl<'tcx> Instance<'tcx> {
631638
};
632639

633640
if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) {
634-
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
641+
let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind()
642+
else {
643+
bug!()
644+
};
645+
646+
if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() {
647+
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
648+
} else {
649+
Some(Instance {
650+
def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id },
651+
args,
652+
})
653+
}
635654
} else {
636655
// All other methods should be defaulted methods of the built-in trait.
637656
// This is important for `Iterator`'s combinators, but also useful for

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,6 +1681,7 @@ impl<'tcx> TyCtxt<'tcx> {
16811681
| ty::InstanceDef::Virtual(..)
16821682
| ty::InstanceDef::ClosureOnceShim { .. }
16831683
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
1684+
| ty::InstanceDef::CoroutineByMoveShim { .. }
16841685
| ty::InstanceDef::DropGlue(..)
16851686
| ty::InstanceDef::CloneShim(..)
16861687
| ty::InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/ty/sty.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -399,13 +399,15 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
399399
self,
400400
tcx: TyCtxt<'tcx>,
401401
parent_args: &'tcx [GenericArg<'tcx>],
402+
kind_ty: Ty<'tcx>,
402403
coroutine_def_id: DefId,
403404
tupled_upvars_ty: Ty<'tcx>,
404405
) -> Ty<'tcx> {
405406
let coroutine_args = ty::CoroutineArgs::new(
406407
tcx,
407408
ty::CoroutineArgsParts {
408409
parent_args,
410+
kind_ty,
409411
resume_ty: self.resume_ty,
410412
yield_ty: self.yield_ty,
411413
return_ty: self.return_ty,
@@ -436,7 +438,13 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
436438
env_region,
437439
);
438440

439-
self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty)
441+
self.to_coroutine(
442+
tcx,
443+
parent_args,
444+
Ty::from_closure_kind(tcx, closure_kind),
445+
coroutine_def_id,
446+
tupled_upvars_ty,
447+
)
440448
}
441449

442450
/// Given a closure kind, compute the tupled upvars that the given coroutine would return.
@@ -488,6 +496,8 @@ pub struct CoroutineArgs<'tcx> {
488496
pub struct CoroutineArgsParts<'tcx> {
489497
/// This is the args of the typeck root.
490498
pub parent_args: &'tcx [GenericArg<'tcx>],
499+
// TODO: why
500+
pub kind_ty: Ty<'tcx>,
491501
pub resume_ty: Ty<'tcx>,
492502
pub yield_ty: Ty<'tcx>,
493503
pub return_ty: Ty<'tcx>,
@@ -506,6 +516,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
506516
pub fn new(tcx: TyCtxt<'tcx>, parts: CoroutineArgsParts<'tcx>) -> CoroutineArgs<'tcx> {
507517
CoroutineArgs {
508518
args: tcx.mk_args_from_iter(parts.parent_args.iter().copied().chain([
519+
parts.kind_ty.into(),
509520
parts.resume_ty.into(),
510521
parts.yield_ty.into(),
511522
parts.return_ty.into(),
@@ -519,16 +530,23 @@ impl<'tcx> CoroutineArgs<'tcx> {
519530
/// The ordering assumed here must match that used by `CoroutineArgs::new` above.
520531
fn split(self) -> CoroutineArgsParts<'tcx> {
521532
match self.args[..] {
522-
[ref parent_args @ .., resume_ty, yield_ty, return_ty, witness, tupled_upvars_ty] => {
523-
CoroutineArgsParts {
524-
parent_args,
525-
resume_ty: resume_ty.expect_ty(),
526-
yield_ty: yield_ty.expect_ty(),
527-
return_ty: return_ty.expect_ty(),
528-
witness: witness.expect_ty(),
529-
tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
530-
}
531-
}
533+
[
534+
ref parent_args @ ..,
535+
kind_ty,
536+
resume_ty,
537+
yield_ty,
538+
return_ty,
539+
witness,
540+
tupled_upvars_ty,
541+
] => CoroutineArgsParts {
542+
parent_args,
543+
kind_ty: kind_ty.expect_ty(),
544+
resume_ty: resume_ty.expect_ty(),
545+
yield_ty: yield_ty.expect_ty(),
546+
return_ty: return_ty.expect_ty(),
547+
witness: witness.expect_ty(),
548+
tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
549+
},
532550
_ => bug!("coroutine args missing synthetics"),
533551
}
534552
}
@@ -538,6 +556,11 @@ impl<'tcx> CoroutineArgs<'tcx> {
538556
self.split().parent_args
539557
}
540558

559+
// TODO:
560+
pub fn kind_ty(self) -> Ty<'tcx> {
561+
self.split().kind_ty
562+
}
563+
541564
/// This describes the types that can be contained in a coroutine.
542565
/// It will be a type variable initially and unified in the last stages of typeck of a body.
543566
/// It contains a tuple of all the types that could end up on a coroutine frame.
@@ -1628,7 +1651,7 @@ impl<'tcx> Ty<'tcx> {
16281651
) -> Ty<'tcx> {
16291652
debug_assert_eq!(
16301653
coroutine_args.len(),
1631-
tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5,
1654+
tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 6,
16321655
"coroutine constructed with incorrect number of substitutions"
16331656
);
16341657
Ty::new(tcx, Coroutine(def_id, coroutine_args))

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
5151
//! Otherwise it drops all the values in scope at the last suspension point.
5252
53+
mod by_move_body;
54+
pub use by_move_body::ByMoveBody;
55+
5356
use crate::abort_unwinding_calls;
5457
use crate::deref_separator::deref_finder;
5558
use crate::errors;
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use rustc_data_structures::fx::FxIndexSet;
2+
use rustc_hir as hir;
3+
use rustc_middle::mir::visit::MutVisitor;
4+
use rustc_middle::mir::{self, MirPass};
5+
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt};
6+
use rustc_target::abi::FieldIdx;
7+
8+
pub struct ByMoveBody;
9+
10+
impl<'tcx> MirPass<'tcx> for ByMoveBody {
11+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
12+
let Some(coroutine_def_id) = body.source.def_id().as_local() else {
13+
return;
14+
};
15+
let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
16+
tcx.coroutine_kind(coroutine_def_id)
17+
else {
18+
return;
19+
};
20+
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
21+
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() };
22+
if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce {
23+
return;
24+
}
25+
26+
let mut by_ref_fields = FxIndexSet::default();
27+
let by_move_upvars = Ty::new_tup_from_iter(
28+
tcx,
29+
tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
30+
if capture.is_by_ref() {
31+
by_ref_fields.insert(FieldIdx::from_usize(idx));
32+
}
33+
capture.place.ty()
34+
}),
35+
);
36+
let by_move_coroutine_ty = Ty::new_coroutine(
37+
tcx,
38+
coroutine_def_id.to_def_id(),
39+
ty::CoroutineArgs::new(
40+
tcx,
41+
ty::CoroutineArgsParts {
42+
parent_args: args.as_coroutine().parent_args(),
43+
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
44+
resume_ty: args.as_coroutine().resume_ty(),
45+
yield_ty: args.as_coroutine().yield_ty(),
46+
return_ty: args.as_coroutine().return_ty(),
47+
witness: args.as_coroutine().witness(),
48+
tupled_upvars_ty: by_move_upvars,
49+
},
50+
)
51+
.args,
52+
);
53+
54+
let mut by_move_body = body.clone();
55+
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
56+
by_move_body.source = mir::MirSource {
57+
instance: InstanceDef::CoroutineByMoveShim {
58+
coroutine_def_id: coroutine_def_id.to_def_id(),
59+
},
60+
promoted: None,
61+
};
62+
63+
body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
64+
}
65+
}
66+
67+
struct MakeByMoveBody<'tcx> {
68+
tcx: TyCtxt<'tcx>,
69+
by_ref_fields: FxIndexSet<FieldIdx>,
70+
by_move_coroutine_ty: Ty<'tcx>,
71+
}
72+
73+
impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
74+
fn tcx(&self) -> TyCtxt<'tcx> {
75+
self.tcx
76+
}
77+
78+
fn visit_place(
79+
&mut self,
80+
place: &mut mir::Place<'tcx>,
81+
context: mir::visit::PlaceContext,
82+
location: mir::Location,
83+
) {
84+
if place.local == ty::CAPTURE_STRUCT_LOCAL
85+
&& !place.projection.is_empty()
86+
&& let mir::ProjectionElem::Field(idx, ty) = place.projection[0]
87+
&& self.by_ref_fields.contains(&idx)
88+
{
89+
let (begin, end) = place.projection[1..].split_first().unwrap();
90+
assert_eq!(*begin, mir::ProjectionElem::Deref);
91+
*place = mir::Place {
92+
local: place.local,
93+
projection: self.tcx.mk_place_elems_from_iter(
94+
[mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)]
95+
.into_iter()
96+
.chain(end.iter().copied()),
97+
),
98+
};
99+
}
100+
self.super_place(place, context, location);
101+
}
102+
103+
fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
104+
if local == ty::CAPTURE_STRUCT_LOCAL {
105+
local_decl.ty = self.by_move_coroutine_ty;
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)