Skip to content

Commit c04480a

Browse files
Fix few bugs in closure capture computation, and add tests
Also create a test infrastructure for capture computation.
1 parent fa00326 commit c04480a

File tree

7 files changed

+302
-16
lines changed

7 files changed

+302
-16
lines changed

crates/hir-ty/src/infer/closure.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,12 @@ impl HirPlace {
208208
ty
209209
}
210210

211-
fn capture_kind_of_truncated_place(
212-
&self,
213-
mut current_capture: CaptureKind,
214-
len: usize,
215-
) -> CaptureKind {
211+
fn mut_to_closure_capture(&self, mut current_capture: CaptureKind) -> CaptureKind {
216212
if let CaptureKind::ByRef(BorrowKind::Mut {
217213
kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow,
218214
}) = current_capture
219215
{
220-
if self.projections[len..].iter().any(|it| *it == ProjectionElem::Deref) {
216+
if self.projections.iter().any(|it| *it == ProjectionElem::Deref) {
221217
current_capture =
222218
CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture });
223219
}
@@ -866,11 +862,8 @@ impl InferenceContext<'_> {
866862
};
867863
match prev_index {
868864
Some(p) => {
869-
let len = self.current_captures[p].place.projections.len();
870-
let kind_after_truncate =
871-
item.place.capture_kind_of_truncated_place(item.kind, len);
872865
self.current_captures[p].kind =
873-
cmp::max(kind_after_truncate, self.current_captures[p].kind);
866+
cmp::max(item.kind, self.current_captures[p].kind);
874867
}
875868
None => {
876869
hash_map.insert(item.place.clone(), self.current_captures.len());
@@ -1017,7 +1010,7 @@ impl InferenceContext<'_> {
10171010
unreachable!("Closure expression id is always closure");
10181011
};
10191012
self.consume_expr(*body);
1020-
for item in &self.current_captures {
1013+
for item in &mut self.current_captures {
10211014
if matches!(
10221015
item.kind,
10231016
CaptureKind::ByRef(BorrowKind::Mut {
@@ -1029,6 +1022,11 @@ impl InferenceContext<'_> {
10291022
// MIR. I didn't do that due duplicate diagnostics.
10301023
self.result.mutated_bindings_in_closure.insert(item.place.local);
10311024
}
1025+
1026+
// rustc does that only for some captures, not all, because it only uses unique captures
1027+
// in editions prior to 2021.
1028+
// We use them in all editions, so we need this here.
1029+
item.kind = item.place.mut_to_closure_capture(item.kind);
10321030
}
10331031
self.restrict_precision_for_unsafe();
10341032
// `closure_kind` should be done before adjust_for_move_closure

crates/hir-ty/src/mir.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ pub enum TerminatorKind {
636636
},
637637
}
638638

639+
// Order of variants in this enum matter: they are used to compare borrow kinds.
639640
#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
640641
pub enum BorrowKind {
641642
/// Data must be immutable and is aliasable.
@@ -666,15 +667,16 @@ pub enum BorrowKind {
666667
Mut { kind: MutBorrowKind },
667668
}
668669

670+
// Order of variants in this enum matter: they are used to compare borrow kinds.
669671
#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
670672
pub enum MutBorrowKind {
673+
/// Data must be immutable but not aliasable. This kind of borrow cannot currently
674+
/// be expressed by the user and is used only in implicit closure bindings.
675+
ClosureCapture,
671676
Default,
672677
/// This borrow arose from method-call auto-ref
673678
/// (i.e., adjustment::Adjust::Borrow).
674679
TwoPhasedBorrow,
675-
/// Data must be immutable but not aliasable. This kind of borrow cannot currently
676-
/// be expressed by the user and is used only in implicit closure bindings.
677-
ClosureCapture,
678680
}
679681

680682
impl BorrowKind {

crates/hir-ty/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod closure_captures;
12
mod coercion;
23
mod diagnostics;
34
mod display_source_code;
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
use base_db::salsa::InternKey;
2+
use expect_test::{expect, Expect};
3+
use hir_def::db::DefDatabase;
4+
use itertools::Itertools;
5+
use test_fixture::WithFixture;
6+
7+
use crate::db::{HirDatabase, InternedClosureId};
8+
use crate::display::HirDisplay;
9+
use crate::test_db::TestDB;
10+
11+
use super::visit_module;
12+
13+
fn check_closure_captures(ra_fixture: &str, expect: Expect) {
14+
let (db, file_id) = TestDB::with_single_file(ra_fixture);
15+
let module = db.module_for_file(file_id);
16+
let def_map = module.def_map(&db);
17+
18+
let mut defs = Vec::new();
19+
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
20+
21+
let mut captures_info = Vec::new();
22+
for def in defs {
23+
let infer = db.infer(def);
24+
let db = &db;
25+
captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| {
26+
let closure = db.lookup_intern_closure(InternedClosureId::from_intern_id(closure_id.0));
27+
let (_, source_map) = db.body_with_source_map(closure.0);
28+
let closure_text_range = source_map
29+
.expr_syntax(closure.1)
30+
.expect("failed to map closure to SyntaxNode")
31+
.value
32+
.text_range();
33+
captures.iter().flat_map(move |capture| {
34+
// FIXME: Deduplicate this with hir::Local::sources().
35+
let (body, source_map) = db.body_with_source_map(closure.0);
36+
let local_text_ranges = match body.self_param.zip(source_map.self_param_syntax()) {
37+
Some((param, source)) if param == capture.local() => {
38+
vec![source.file_syntax(db).text_range()]
39+
}
40+
_ => source_map
41+
.patterns_for_binding(capture.local())
42+
.iter()
43+
.map(|&definition| {
44+
let src = source_map.pat_syntax(definition).unwrap();
45+
src.file_syntax(db).text_range()
46+
})
47+
.collect(),
48+
};
49+
let place = capture.display_place(closure.0, db);
50+
let capture_ty = capture.ty.skip_binders().display_test(db).to_string();
51+
local_text_ranges.into_iter().map(move |local_text_range| {
52+
(
53+
closure_text_range,
54+
local_text_range,
55+
place.clone(),
56+
capture_ty.clone(),
57+
capture.kind(),
58+
)
59+
})
60+
})
61+
}));
62+
}
63+
captures_info.sort_unstable_by_key(|(closure_text_range, local_text_range, ..)| {
64+
(closure_text_range.start(), local_text_range.start())
65+
});
66+
67+
let rendered = captures_info
68+
.iter()
69+
.map(|(closure_text_range, local_text_range, place, capture_ty, capture_kind)| {
70+
format!(
71+
"{closure_text_range:?};{local_text_range:?} {capture_kind:?} {place} {capture_ty}"
72+
)
73+
})
74+
.join("\n");
75+
76+
expect.assert_eq(&rendered);
77+
}
78+
79+
#[test]
80+
fn deref_in_let() {
81+
check_closure_captures(
82+
r#"
83+
//- minicore:copy
84+
fn main() {
85+
let a = &mut true;
86+
let closure = || { let b = *a; };
87+
}
88+
"#,
89+
expect!["53..71;0..75 ByRef(Shared) *a &'? bool"],
90+
);
91+
}
92+
93+
#[test]
94+
fn deref_then_ref_pattern() {
95+
check_closure_captures(
96+
r#"
97+
//- minicore:copy
98+
fn main() {
99+
let a = &mut true;
100+
let closure = || { let &mut ref b = a; };
101+
}
102+
"#,
103+
expect!["53..79;0..83 ByRef(Shared) *a &'? bool"],
104+
);
105+
check_closure_captures(
106+
r#"
107+
//- minicore:copy
108+
fn main() {
109+
let a = &mut true;
110+
let closure = || { let &mut ref mut b = a; };
111+
}
112+
"#,
113+
expect!["53..83;0..87 ByRef(Mut { kind: ClosureCapture }) *a &'? mut bool"],
114+
);
115+
}
116+
117+
#[test]
118+
fn unique_borrow() {
119+
check_closure_captures(
120+
r#"
121+
//- minicore:copy
122+
fn main() {
123+
let a = &mut true;
124+
let closure = || { *a = false; };
125+
}
126+
"#,
127+
expect!["53..71;0..75 ByRef(Mut { kind: ClosureCapture }) *a &'? mut bool"],
128+
);
129+
}
130+
131+
#[test]
132+
fn deref_ref_mut() {
133+
check_closure_captures(
134+
r#"
135+
//- minicore:copy
136+
fn main() {
137+
let a = &mut true;
138+
let closure = || { let ref mut b = *a; };
139+
}
140+
"#,
141+
expect!["53..79;0..83 ByRef(Mut { kind: ClosureCapture }) *a &'? mut bool"],
142+
);
143+
}
144+
145+
#[test]
146+
fn let_else_not_consuming() {
147+
check_closure_captures(
148+
r#"
149+
//- minicore:copy
150+
fn main() {
151+
let a = &mut true;
152+
let closure = || { let _ = *a else { return; }; };
153+
}
154+
"#,
155+
expect!["53..88;0..92 ByRef(Shared) *a &'? bool"],
156+
);
157+
}
158+
159+
#[test]
160+
fn consume() {
161+
check_closure_captures(
162+
r#"
163+
//- minicore:copy
164+
struct NonCopy;
165+
fn main() {
166+
let a = NonCopy;
167+
let closure = || { let b = a; };
168+
}
169+
"#,
170+
expect!["67..84;0..88 ByValue a NonCopy"],
171+
);
172+
}
173+
174+
#[test]
175+
fn ref_to_upvar() {
176+
check_closure_captures(
177+
r#"
178+
//- minicore:copy
179+
struct NonCopy;
180+
fn main() {
181+
let mut a = NonCopy;
182+
let closure = || { let b = &a; };
183+
let closure = || { let c = &mut a; };
184+
}
185+
"#,
186+
expect![[r#"
187+
71..89;0..135 ByRef(Shared) a &'? NonCopy
188+
109..131;0..135 ByRef(Mut { kind: Default }) a &'? mut NonCopy"#]],
189+
);
190+
}
191+
192+
#[test]
193+
fn field() {
194+
check_closure_captures(
195+
r#"
196+
//- minicore:copy
197+
struct Foo { a: i32, b: i32 }
198+
fn main() {
199+
let a = Foo { a: 0, b: 0 };
200+
let closure = || { let b = a.a; };
201+
}
202+
"#,
203+
expect!["92..111;0..115 ByRef(Shared) a.a &'? i32"],
204+
);
205+
}
206+
207+
#[test]
208+
fn fields_different_mode() {
209+
check_closure_captures(
210+
r#"
211+
//- minicore:copy
212+
struct NonCopy;
213+
struct Foo { a: i32, b: i32, c: NonCopy, d: bool }
214+
fn main() {
215+
let mut a = Foo { a: 0, b: 0 };
216+
let closure = || {
217+
let b = &a.a;
218+
let c = &mut a.b;
219+
let d = a.c;
220+
};
221+
}
222+
"#,
223+
expect![[r#"
224+
133..212;0..216 ByRef(Shared) a.a &'? i32
225+
133..212;0..216 ByRef(Mut { kind: Default }) a.b &'? mut i32
226+
133..212;0..216 ByValue a.c NonCopy"#]],
227+
);
228+
}
229+
230+
#[test]
231+
fn autoref() {
232+
check_closure_captures(
233+
r#"
234+
//- minicore:copy
235+
struct Foo;
236+
impl Foo {
237+
fn imm(&self) {}
238+
fn mut_(&mut self) {}
239+
}
240+
fn main() {
241+
let mut a = Foo;
242+
let closure = || a.imm();
243+
let closure = || a.mut_();
244+
}
245+
"#,
246+
expect![[r#"
247+
123..133;0..168 ByRef(Shared) a &'? Foo
248+
153..164;0..168 ByRef(Mut { kind: Default }) a &'? mut Foo"#]],
249+
);
250+
}
251+
252+
#[test]
253+
fn captures_priority() {
254+
check_closure_captures(
255+
r#"
256+
//- minicore:copy
257+
struct NonCopy;
258+
fn main() {
259+
let mut a = &mut true;
260+
// Max ByRef(Mut { kind: Default })
261+
let closure = || {
262+
*a = false;
263+
let b = &mut a;
264+
};
265+
// Max ByRef(Mut { kind: ClosureCapture })
266+
let closure = || {
267+
let b = *a;
268+
let c = &mut *a;
269+
};
270+
// Max ByValue
271+
let mut a = NonCopy;
272+
let closure = || {
273+
let b = a;
274+
let c = &mut a;
275+
let d = &a;
276+
};
277+
}
278+
"#,
279+
expect![[r#"
280+
113..167;0..434 ByRef(Mut { kind: Default }) a &'? mut &'? mut bool
281+
234..293;0..434 ByRef(Mut { kind: ClosureCapture }) *a &'? mut bool
282+
357..430;0..434 ByValue a NonCopy"#]],
283+
);
284+
}

crates/hir/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4134,6 +4134,7 @@ impl ClosureCapture {
41344134
}
41354135
}
41364136

4137+
#[derive(Clone, Copy, PartialEq, Eq)]
41374138
pub enum CaptureKind {
41384139
SharedRef,
41394140
UniqueSharedRef,

crates/ide/src/hover/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ fn main() {
383383
384384
## Captures
385385
* `x.f1` by move
386-
* `(*x.f2.0.0).f` by mutable borrow
386+
* `(*x.f2.0.0).f` by unique immutable borrow ([read more](https://doc.rust-lang.org/stable/reference/types/closure.html#unique-immutable-borrows-in-captures))
387387
"#]],
388388
);
389389
check(

crates/ide/src/inlay_hints/closure_captures.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ fn main() {
181181
// ^ (
182182
// ^ &mut baz
183183
// ^ , $
184-
// ^ &mut *qux
184+
// ^ &unique *qux
185185
// ^ )
186186
baz = NonCopy;
187187
*qux = NonCopy;

0 commit comments

Comments
 (0)