Skip to content

Commit 3e0eb86

Browse files
Create an assist to convert closure to freestanding fn
The assist converts all captures to parameters.
1 parent 5c07dba commit 3e0eb86

File tree

25 files changed

+1584
-37
lines changed

25 files changed

+1584
-37
lines changed

crates/hir-def/src/lang_item.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ impl LangItemTarget {
7474
_ => None,
7575
}
7676
}
77+
78+
pub fn as_type_alias(self) -> Option<TypeAliasId> {
79+
match self {
80+
LangItemTarget::TypeAlias(id) => Some(id),
81+
_ => None,
82+
}
83+
}
7784
}
7885

7986
#[derive(Default, Debug, Clone, PartialEq, Eq)]
@@ -117,11 +124,19 @@ impl LangItems {
117124
match def {
118125
ModuleDefId::TraitId(trait_) => {
119126
lang_items.collect_lang_item(db, trait_, LangItemTarget::Trait);
120-
db.trait_data(trait_).items.iter().for_each(|&(_, assoc_id)| {
121-
if let AssocItemId::FunctionId(f) = assoc_id {
122-
lang_items.collect_lang_item(db, f, LangItemTarget::Function);
123-
}
124-
});
127+
db.trait_data(trait_).items.iter().for_each(
128+
|&(_, assoc_id)| match assoc_id {
129+
AssocItemId::FunctionId(f) => {
130+
lang_items.collect_lang_item(db, f, LangItemTarget::Function);
131+
}
132+
AssocItemId::TypeAliasId(alias) => lang_items.collect_lang_item(
133+
db,
134+
alias,
135+
LangItemTarget::TypeAlias,
136+
),
137+
AssocItemId::ConstId(_) => {}
138+
},
139+
);
125140
}
126141
ModuleDefId::AdtId(AdtId::EnumId(e)) => {
127142
lang_items.collect_lang_item(db, e, LangItemTarget::EnumId);
@@ -453,6 +468,7 @@ language_item_table! {
453468

454469
Context, sym::Context, context, Target::Struct, GenericRequirement::None;
455470
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
471+
FutureOutput, sym::future_output, future_output, Target::TypeAlias, GenericRequirement::None;
456472

457473
Option, sym::Option, option_type, Target::Enum, GenericRequirement::None;
458474
OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None;
@@ -467,6 +483,7 @@ language_item_table! {
467483
IntoFutureIntoFuture, sym::into_future, into_future_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
468484
IntoIterIntoIter, sym::into_iter, into_iter_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
469485
IteratorNext, sym::next, next_fn, Target::Method(MethodKind::Trait { body: false}), GenericRequirement::None;
486+
Iterator, sym::iterator, iterator, Target::Trait, GenericRequirement::None;
470487

471488
PinNewUnchecked, sym::new_unchecked, new_unchecked_fn, Target::Method(MethodKind::Inherent), GenericRequirement::None;
472489

crates/hir-ty/src/infer/closure.rs

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use intern::sym;
2020
use rustc_hash::FxHashMap;
2121
use smallvec::{smallvec, SmallVec};
2222
use stdx::never;
23+
use syntax::utils::is_raw_identifier;
2324

2425
use crate::{
2526
db::{HirDatabase, InternedClosure},
@@ -242,6 +243,11 @@ impl CapturedItem {
242243
self.place.local
243244
}
244245

246+
/// Returns whether this place has any field (aka. non-deref) projections.
247+
pub fn has_field_projections(&self) -> bool {
248+
self.place.projections.iter().any(|it| !matches!(it, ProjectionElem::Deref))
249+
}
250+
245251
pub fn ty(&self, subst: &Substitution) -> Ty {
246252
self.ty.clone().substitute(Interner, utils::ClosureSubst(subst).parent_subst())
247253
}
@@ -254,6 +260,103 @@ impl CapturedItem {
254260
self.span_stacks.iter().map(|stack| *stack.last().expect("empty span stack")).collect()
255261
}
256262

263+
/// Converts the place to a name that can be inserted into source code.
264+
pub fn place_to_name(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
265+
use std::fmt::Write;
266+
267+
let body = db.body(owner);
268+
let mut result = body[self.place.local].name.unescaped().display(db.upcast()).to_string();
269+
for proj in &self.place.projections {
270+
match proj {
271+
ProjectionElem::Deref => {}
272+
ProjectionElem::Field(Either::Left(f)) => {
273+
match &*f.parent.variant_data(db.upcast()) {
274+
VariantData::Record(fields) => {
275+
result.push_str("_");
276+
result.push_str(fields[f.local_id].name.as_str())
277+
}
278+
VariantData::Tuple(fields) => {
279+
let index = fields.iter().position(|it| it.0 == f.local_id);
280+
if let Some(index) = index {
281+
write!(result, "_{index}").unwrap();
282+
}
283+
}
284+
VariantData::Unit => {}
285+
}
286+
}
287+
ProjectionElem::Field(Either::Right(f)) => write!(result, "_{}", f.index).unwrap(),
288+
&ProjectionElem::ClosureField(field) => write!(result, "_{field}").unwrap(),
289+
ProjectionElem::Index(_)
290+
| ProjectionElem::ConstantIndex { .. }
291+
| ProjectionElem::Subslice { .. }
292+
| ProjectionElem::OpaqueCast(_) => {
293+
never!("Not happen in closure capture");
294+
continue;
295+
}
296+
}
297+
}
298+
if is_raw_identifier(&result, db.crate_graph()[owner.module(db.upcast()).krate()].edition) {
299+
result.insert_str(0, "r#");
300+
}
301+
result
302+
}
303+
304+
pub fn display_place_source_code(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
305+
use std::fmt::Write;
306+
307+
let body = db.body(owner);
308+
let krate = owner.krate(db.upcast());
309+
let edition = db.crate_graph()[krate].edition;
310+
let mut result = body[self.place.local].name.display(db.upcast(), edition).to_string();
311+
for proj in &self.place.projections {
312+
match proj {
313+
// In source code autoderef kicks in.
314+
ProjectionElem::Deref => {}
315+
ProjectionElem::Field(Either::Left(f)) => {
316+
let variant_data = f.parent.variant_data(db.upcast());
317+
match &*variant_data {
318+
VariantData::Record(fields) => write!(
319+
result,
320+
".{}",
321+
fields[f.local_id].name.display(db.upcast(), edition)
322+
)
323+
.unwrap(),
324+
VariantData::Tuple(fields) => write!(
325+
result,
326+
".{}",
327+
fields.iter().position(|it| it.0 == f.local_id).unwrap_or_default()
328+
)
329+
.unwrap(),
330+
VariantData::Unit => {}
331+
}
332+
}
333+
ProjectionElem::Field(Either::Right(f)) => {
334+
let field = f.index;
335+
write!(result, ".{field}").unwrap();
336+
}
337+
&ProjectionElem::ClosureField(field) => {
338+
write!(result, ".{field}").unwrap();
339+
}
340+
ProjectionElem::Index(_)
341+
| ProjectionElem::ConstantIndex { .. }
342+
| ProjectionElem::Subslice { .. }
343+
| ProjectionElem::OpaqueCast(_) => {
344+
never!("Not happen in closure capture");
345+
continue;
346+
}
347+
}
348+
}
349+
let final_derefs_count = self
350+
.place
351+
.projections
352+
.iter()
353+
.rev()
354+
.take_while(|proj| matches!(proj, ProjectionElem::Deref))
355+
.count();
356+
result.insert_str(0, &"*".repeat(final_derefs_count));
357+
result
358+
}
359+
257360
pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
258361
let body = db.body(owner);
259362
let krate = owner.krate(db.upcast());
@@ -442,14 +545,6 @@ impl InferenceContext<'_> {
442545
});
443546
}
444547

445-
fn is_ref_span(&self, span: MirSpan) -> bool {
446-
match span {
447-
MirSpan::ExprId(expr) => matches!(self.body[expr], Expr::Ref { .. }),
448-
MirSpan::BindingId(_) => true,
449-
MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false,
450-
}
451-
}
452-
453548
fn truncate_capture_spans(&self, capture: &mut CapturedItemWithoutTy, mut truncate_to: usize) {
454549
// The first span is the identifier, and it must always remain.
455550
truncate_to += 1;
@@ -458,15 +553,15 @@ impl InferenceContext<'_> {
458553
let mut actual_truncate_to = 0;
459554
for &span in &*span_stack {
460555
actual_truncate_to += 1;
461-
if !self.is_ref_span(span) {
556+
if !span.is_ref_span(self.body) {
462557
remained -= 1;
463558
if remained == 0 {
464559
break;
465560
}
466561
}
467562
}
468563
if actual_truncate_to < span_stack.len()
469-
&& self.is_ref_span(span_stack[actual_truncate_to])
564+
&& span_stack[actual_truncate_to].is_ref_span(self.body)
470565
{
471566
// Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect.
472567
actual_truncate_to += 1;
@@ -1140,7 +1235,7 @@ impl InferenceContext<'_> {
11401235
for capture in &mut captures {
11411236
if matches!(capture.kind, CaptureKind::ByValue) {
11421237
for span_stack in &mut capture.span_stacks {
1143-
if self.is_ref_span(span_stack[span_stack.len() - 1]) {
1238+
if span_stack[span_stack.len() - 1].is_ref_span(self.body) {
11441239
span_stack.truncate(span_stack.len() - 1);
11451240
}
11461241
}

crates/hir-ty/src/mir.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ use base_db::CrateId;
1616
use chalk_ir::Mutability;
1717
use either::Either;
1818
use hir_def::{
19-
hir::{BindingId, Expr, ExprId, Ordering, PatId},
19+
body::Body,
20+
hir::{BindingAnnotation, BindingId, Expr, ExprId, Ordering, PatId},
2021
DefWithBodyId, FieldId, StaticId, TupleFieldId, UnionId, VariantId,
2122
};
2223
use la_arena::{Arena, ArenaMap, Idx, RawIdx};
@@ -1174,6 +1175,20 @@ pub enum MirSpan {
11741175
Unknown,
11751176
}
11761177

1178+
impl MirSpan {
1179+
pub fn is_ref_span(&self, body: &Body) -> bool {
1180+
match *self {
1181+
MirSpan::ExprId(expr) => matches!(body[expr], Expr::Ref { .. }),
1182+
// FIXME: Figure out if this is correct wrt. match ergonomics.
1183+
MirSpan::BindingId(binding) => matches!(
1184+
body.bindings[binding].mode,
1185+
BindingAnnotation::Ref | BindingAnnotation::RefMut
1186+
),
1187+
MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false,
1188+
}
1189+
}
1190+
}
1191+
11771192
impl_from!(ExprId, PatId for MirSpan);
11781193

11791194
impl From<&ExprId> for MirSpan {

crates/hir/src/lib.rs

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ use hir_ty::{
7878
use itertools::Itertools;
7979
use nameres::diagnostics::DefDiagnosticKind;
8080
use rustc_hash::FxHashSet;
81-
use span::{Edition, EditionedFileId, FileId, MacroCallId};
81+
use smallvec::SmallVec;
82+
use span::{Edition, EditionedFileId, FileId, MacroCallId, SyntaxContextId};
8283
use stdx::{impl_from, never};
8384
use syntax::{
8485
ast::{self, HasAttrs as _, HasGenericParams, HasName},
@@ -4114,6 +4115,15 @@ impl ClosureCapture {
41144115
Local { parent: self.owner, binding_id: self.capture.local() }
41154116
}
41164117

4118+
/// Returns whether this place has any field (aka. non-deref) projections.
4119+
pub fn has_field_projections(&self) -> bool {
4120+
self.capture.has_field_projections()
4121+
}
4122+
4123+
pub fn usages(&self) -> CaptureUsages {
4124+
CaptureUsages { parent: self.owner, spans: self.capture.spans() }
4125+
}
4126+
41174127
pub fn kind(&self) -> CaptureKind {
41184128
match self.capture.kind() {
41194129
hir_ty::CaptureKind::ByRef(
@@ -4129,6 +4139,15 @@ impl ClosureCapture {
41294139
}
41304140
}
41314141

4142+
/// Converts the place to a name that can be inserted into source code.
4143+
pub fn place_to_name(&self, db: &dyn HirDatabase) -> String {
4144+
self.capture.place_to_name(self.owner, db)
4145+
}
4146+
4147+
pub fn display_place_source_code(&self, db: &dyn HirDatabase) -> String {
4148+
self.capture.display_place_source_code(self.owner, db)
4149+
}
4150+
41324151
pub fn display_place(&self, db: &dyn HirDatabase) -> String {
41334152
self.capture.display_place(self.owner, db)
41344153
}
@@ -4142,6 +4161,74 @@ pub enum CaptureKind {
41424161
Move,
41434162
}
41444163

4164+
#[derive(Debug, Clone)]
4165+
pub struct CaptureUsages {
4166+
parent: DefWithBodyId,
4167+
spans: SmallVec<[mir::MirSpan; 3]>,
4168+
}
4169+
4170+
impl CaptureUsages {
4171+
pub fn sources(&self, db: &dyn HirDatabase) -> Vec<CaptureUsageSource> {
4172+
let (body, source_map) = db.body_with_source_map(self.parent);
4173+
let mut result = Vec::with_capacity(self.spans.len());
4174+
for &span in self.spans.iter() {
4175+
let is_ref = span.is_ref_span(&body);
4176+
match span {
4177+
mir::MirSpan::ExprId(expr) => {
4178+
if let Ok(expr) = source_map.expr_syntax(expr) {
4179+
result.push(CaptureUsageSource {
4180+
is_ref,
4181+
source: expr.map(AstPtr::wrap_left),
4182+
})
4183+
}
4184+
}
4185+
mir::MirSpan::PatId(pat) => {
4186+
if let Ok(pat) = source_map.pat_syntax(pat) {
4187+
result.push(CaptureUsageSource {
4188+
is_ref,
4189+
source: pat.map(AstPtr::wrap_right),
4190+
});
4191+
}
4192+
}
4193+
mir::MirSpan::BindingId(binding) => result.extend(
4194+
source_map
4195+
.patterns_for_binding(binding)
4196+
.iter()
4197+
.filter_map(|&pat| source_map.pat_syntax(pat).ok())
4198+
.map(|pat| CaptureUsageSource {
4199+
is_ref,
4200+
source: pat.map(AstPtr::wrap_right),
4201+
}),
4202+
),
4203+
mir::MirSpan::SelfParam | mir::MirSpan::Unknown => {
4204+
unreachable!("invalid capture usage span")
4205+
}
4206+
}
4207+
}
4208+
result
4209+
}
4210+
}
4211+
4212+
#[derive(Debug)]
4213+
pub struct CaptureUsageSource {
4214+
is_ref: bool,
4215+
source: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>,
4216+
}
4217+
4218+
impl CaptureUsageSource {
4219+
pub fn source(&self) -> AstPtr<Either<ast::Expr, ast::Pat>> {
4220+
self.source.value.clone()
4221+
}
4222+
4223+
pub fn file_id(&self) -> HirFileId {
4224+
self.source.file_id
4225+
}
4226+
4227+
pub fn is_ref(&self) -> bool {
4228+
self.is_ref
4229+
}
4230+
}
4231+
41454232
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
41464233
pub struct Type {
41474234
env: Arc<TraitEnvironment>,
@@ -4380,6 +4467,22 @@ impl Type {
43804467
method_resolution::implements_trait(&canonical_ty, db, &self.env, trait_)
43814468
}
43824469

4470+
/// This does **not** resolve `IntoFuture`, only `Future`.
4471+
pub fn future_output(self, db: &dyn HirDatabase) -> Option<Type> {
4472+
let future_output =
4473+
db.lang_item(self.env.krate, LangItem::FutureOutput)?.as_type_alias()?;
4474+
self.normalize_trait_assoc_type(db, &[], future_output.into())
4475+
}
4476+
4477+
/// This does **not** resolve `IntoIterator`, only `Iterator`.
4478+
pub fn iterator_item(self, db: &dyn HirDatabase) -> Option<Type> {
4479+
let iterator_trait = db.lang_item(self.env.krate, LangItem::Iterator)?.as_trait()?;
4480+
let iterator_item = db
4481+
.trait_data(iterator_trait)
4482+
.associated_type_by_name(&Name::new_symbol(sym::Item.clone(), SyntaxContextId::ROOT))?;
4483+
self.normalize_trait_assoc_type(db, &[], iterator_item.into())
4484+
}
4485+
43834486
/// Checks that particular type `ty` implements `std::ops::FnOnce`.
43844487
///
43854488
/// This function can be used to check if a particular type is callable, since FnOnce is a

0 commit comments

Comments
 (0)