Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit bf0322c

Browse files
committed
pick the best ancestor expr of unsafe expr to add unsafe block. Thanks! @Veykril
1 parent 82780d8 commit bf0322c

File tree

1 file changed

+226
-40
lines changed

1 file changed

+226
-40
lines changed

crates/ide-diagnostics/src/handlers/missing_unsafe.rs

Lines changed: 226 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use hir::db::AstDatabase;
22
use ide_db::{assists::Assist, source_change::SourceChange};
3-
use syntax::ast::{ExprStmt, LetStmt};
43
use syntax::AstNode;
54
use syntax::{ast, SyntaxNode};
65
use text_edit::TextEdit;
@@ -23,7 +22,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
2322
let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?;
2423
let expr = d.expr.value.to_node(&root);
2524

26-
let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(ctx, &expr);
25+
let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr);
2726

2827
let replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text());
2928
let edit = TextEdit::replace(node_to_add_unsafe_block.text_range(), replacement);
@@ -32,39 +31,78 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
3231
Some(vec![fix("add_unsafe", "Add unsafe block", source_change, expr.syntax().text_range())])
3332
}
3433

35-
// Find the let statement or expression statement closest to the `expr` in the
36-
// ancestor chain.
37-
//
38-
// Why don't we just add an unsafe block around the `expr`?
39-
//
40-
// Consider this example:
41-
// ```
42-
// STATIC_MUT += 1;
43-
// ```
44-
// We can't add an unsafe block to the left-hand side of an assignment.
45-
// ```
46-
// unsafe { STATIC_MUT } += 1;
47-
// ```
48-
//
49-
// Or this example:
50-
// ```
51-
// let z = STATIC_MUT.a;
52-
// ```
53-
// We can't add an unsafe block like this:
54-
// ```
55-
// let z = unsafe { STATIC_MUT } .a;
56-
// ```
57-
fn pick_best_node_to_add_unsafe_block(
58-
ctx: &DiagnosticsContext<'_>,
59-
expr: &ast::Expr,
60-
) -> SyntaxNode {
61-
let Some(let_or_expr_stmt) = ctx.sema.ancestors_with_macros(expr.syntax().clone()).find(|node| {
62-
LetStmt::can_cast(node.kind()) || ExprStmt::can_cast(node.kind())
63-
}) else {
64-
// Is this reachable?
65-
return expr.syntax().clone();
66-
};
67-
let_or_expr_stmt
34+
// Pick the first ancestor expression of the unsafe `expr` that is not a
35+
// receiver of a method call, a field access, the left-hand side of an
36+
// assignment, or a reference. As all of those cases would incur a forced move
37+
// if wrapped which might not be wanted. That is:
38+
// - `unsafe_expr.foo` -> `unsafe { unsafe_expr.foo }`
39+
// - `unsafe_expr.foo.bar` -> `unsafe { unsafe_expr.foo.bar }`
40+
// - `unsafe_expr.foo()` -> `unsafe { unsafe_expr.foo() }`
41+
// - `unsafe_expr.foo.bar()` -> `unsafe { unsafe_expr.foo.bar() }`
42+
// - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }`
43+
// - `&unsafe_expr` -> `unsafe { &unsafe_expr }`
44+
// - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }`
45+
fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> SyntaxNode {
46+
// The `unsafe_expr` might be:
47+
// - `ast::CallExpr`: call an unsafe function
48+
// - `ast::MethodCallExpr`: call an unsafe method
49+
// - `ast::PrefixExpr`: dereference a raw pointer
50+
// - `ast::PathExpr`: access a static mut variable
51+
for node in unsafe_expr.syntax().ancestors() {
52+
let Some(parent) = node.parent() else {
53+
return node;
54+
};
55+
match parent.kind() {
56+
syntax::SyntaxKind::METHOD_CALL_EXPR => {
57+
// Check if the `node` is the receiver of the method call
58+
let method_call_expr = ast::MethodCallExpr::cast(parent.clone()).unwrap();
59+
if method_call_expr
60+
.receiver()
61+
.map(|receiver| {
62+
receiver.syntax().text_range().contains_range(node.text_range())
63+
})
64+
.unwrap_or(false)
65+
{
66+
// Actually, I think it's not necessary to check whether the
67+
// text range of the `node` (which is the ancestor of the
68+
// `unsafe_expr`) is contained in the text range of the
69+
// receiver. The `node` could potentially be the receiver, the
70+
// method name, or the argument list. Since the `node` is the
71+
// ancestor of the unsafe_expr, it cannot be the method name.
72+
// Additionally, if the `node` is the argument list, the loop
73+
// would break at least when `parent` reaches the argument list.
74+
//
75+
// Dispite this, I still check the text range because I think it
76+
// makes the code easier to understand.
77+
continue;
78+
}
79+
return node;
80+
}
81+
syntax::SyntaxKind::FIELD_EXPR | syntax::SyntaxKind::REF_EXPR => continue,
82+
syntax::SyntaxKind::BIN_EXPR => {
83+
// Check if the `node` is the left-hand side of an assignment
84+
let is_left_hand_side_of_assignment = {
85+
let bin_expr = ast::BinExpr::cast(parent.clone()).unwrap();
86+
if let Some(ast::BinaryOp::Assignment { .. }) = bin_expr.op_kind() {
87+
let is_left_hand_side = bin_expr
88+
.lhs()
89+
.map(|lhs| lhs.syntax().text_range().contains_range(node.text_range()))
90+
.unwrap_or(false);
91+
is_left_hand_side
92+
} else {
93+
false
94+
}
95+
};
96+
if !is_left_hand_side_of_assignment {
97+
return node;
98+
}
99+
}
100+
_ => {
101+
return node;
102+
}
103+
}
104+
}
105+
unsafe_expr.syntax().clone()
68106
}
69107

70108
#[cfg(test)]
@@ -168,7 +206,7 @@ fn main() {
168206
r#"
169207
fn main() {
170208
let x = &5 as *const usize;
171-
unsafe { let z = *x; }
209+
let z = unsafe { *x };
172210
}
173211
"#,
174212
);
@@ -192,7 +230,7 @@ unsafe fn func() {
192230
let z = *x;
193231
}
194232
fn main() {
195-
unsafe { func(); }
233+
unsafe { func() };
196234
}
197235
"#,
198236
)
@@ -224,7 +262,7 @@ impl S {
224262
}
225263
fn main() {
226264
let s = S(5);
227-
unsafe { s.func(); }
265+
unsafe { s.func() };
228266
}
229267
"#,
230268
)
@@ -252,7 +290,7 @@ struct Ty {
252290
static mut STATIC_MUT: Ty = Ty { a: 0 };
253291
254292
fn main() {
255-
unsafe { let x = STATIC_MUT.a; }
293+
let x = unsafe { STATIC_MUT.a };
256294
}
257295
"#,
258296
)
@@ -276,7 +314,155 @@ extern "rust-intrinsic" {
276314
}
277315
278316
fn main() {
279-
unsafe { let _ = floorf32(12.0); }
317+
let _ = unsafe { floorf32(12.0) };
318+
}
319+
"#,
320+
)
321+
}
322+
323+
#[test]
324+
fn unsafe_expr_as_a_receiver_of_a_method_call() {
325+
check_fix(
326+
r#"
327+
unsafe fn foo() -> String {
328+
"string".to_string()
329+
}
330+
331+
fn main() {
332+
foo$0().len();
333+
}
334+
"#,
335+
r#"
336+
unsafe fn foo() -> String {
337+
"string".to_string()
338+
}
339+
340+
fn main() {
341+
unsafe { foo().len() };
342+
}
343+
"#,
344+
)
345+
}
346+
347+
#[test]
348+
fn unsafe_expr_as_an_argument_of_a_method_call() {
349+
check_fix(
350+
r#"
351+
static mut STATIC_MUT: u8 = 0;
352+
353+
fn main() {
354+
let mut v = vec![];
355+
v.push(STATIC_MUT$0);
356+
}
357+
"#,
358+
r#"
359+
static mut STATIC_MUT: u8 = 0;
360+
361+
fn main() {
362+
let mut v = vec![];
363+
v.push(unsafe { STATIC_MUT });
364+
}
365+
"#,
366+
)
367+
}
368+
369+
#[test]
370+
fn unsafe_expr_as_left_hand_side_of_assignment() {
371+
check_fix(
372+
r#"
373+
static mut STATIC_MUT: u8 = 0;
374+
375+
fn main() {
376+
STATIC_MUT$0 = 1;
377+
}
378+
"#,
379+
r#"
380+
static mut STATIC_MUT: u8 = 0;
381+
382+
fn main() {
383+
unsafe { STATIC_MUT = 1 };
384+
}
385+
"#,
386+
)
387+
}
388+
389+
#[test]
390+
fn unsafe_expr_as_right_hand_side_of_assignment() {
391+
check_fix(
392+
r#"
393+
static mut STATIC_MUT: u8 = 0;
394+
395+
fn main() {
396+
let x;
397+
x = STATIC_MUT$0;
398+
}
399+
"#,
400+
r#"
401+
static mut STATIC_MUT: u8 = 0;
402+
403+
fn main() {
404+
let x;
405+
x = unsafe { STATIC_MUT };
406+
}
407+
"#,
408+
)
409+
}
410+
411+
#[test]
412+
fn unsafe_expr_in_binary_plus() {
413+
check_fix(
414+
r#"
415+
static mut STATIC_MUT: u8 = 0;
416+
417+
fn main() {
418+
let x = STATIC_MUT$0 + 1;
419+
}
420+
"#,
421+
r#"
422+
static mut STATIC_MUT: u8 = 0;
423+
424+
fn main() {
425+
let x = unsafe { STATIC_MUT } + 1;
426+
}
427+
"#,
428+
)
429+
}
430+
431+
#[test]
432+
fn ref_to_unsafe_expr() {
433+
check_fix(
434+
r#"
435+
static mut STATIC_MUT: u8 = 0;
436+
437+
fn main() {
438+
let x = &STATIC_MUT$0;
439+
}
440+
"#,
441+
r#"
442+
static mut STATIC_MUT: u8 = 0;
443+
444+
fn main() {
445+
let x = unsafe { &STATIC_MUT };
446+
}
447+
"#,
448+
)
449+
}
450+
451+
#[test]
452+
fn ref_ref_to_unsafe_expr() {
453+
check_fix(
454+
r#"
455+
static mut STATIC_MUT: u8 = 0;
456+
457+
fn main() {
458+
let x = &&STATIC_MUT$0;
459+
}
460+
"#,
461+
r#"
462+
static mut STATIC_MUT: u8 = 0;
463+
464+
fn main() {
465+
let x = unsafe { &&STATIC_MUT };
280466
}
281467
"#,
282468
)

0 commit comments

Comments
 (0)