1
1
use hir:: db:: AstDatabase ;
2
2
use ide_db:: { assists:: Assist , source_change:: SourceChange } ;
3
- use syntax:: ast:: { ExprStmt , LetStmt } ;
4
3
use syntax:: AstNode ;
5
4
use syntax:: { ast, SyntaxNode } ;
6
5
use text_edit:: TextEdit ;
@@ -23,7 +22,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
23
22
let root = ctx. sema . db . parse_or_expand ( d. expr . file_id ) ?;
24
23
let expr = d. expr . value . to_node ( & root) ;
25
24
26
- let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( ctx , & expr) ;
25
+ let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( & expr) ;
27
26
28
27
let replacement = format ! ( "unsafe {{ {} }}" , node_to_add_unsafe_block. text( ) ) ;
29
28
let edit = TextEdit :: replace ( node_to_add_unsafe_block. text_range ( ) , replacement) ;
@@ -32,39 +31,78 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
32
31
Some ( vec ! [ fix( "add_unsafe" , "Add unsafe block" , source_change, expr. syntax( ) . text_range( ) ) ] )
33
32
}
34
33
35
- // Find the let statement or expression statement closest to the `expr` in the
36
- // ancestor chain.
37
- //
38
- // Why don't we just add an unsafe block around the `expr`?
39
- //
40
- // Consider this example:
41
- // ```
42
- // STATIC_MUT += 1;
43
- // ```
44
- // We can't add an unsafe block to the left-hand side of an assignment.
45
- // ```
46
- // unsafe { STATIC_MUT } += 1;
47
- // ```
48
- //
49
- // Or this example:
50
- // ```
51
- // let z = STATIC_MUT.a;
52
- // ```
53
- // We can't add an unsafe block like this:
54
- // ```
55
- // let z = unsafe { STATIC_MUT } .a;
56
- // ```
57
- fn pick_best_node_to_add_unsafe_block (
58
- ctx : & DiagnosticsContext < ' _ > ,
59
- expr : & ast:: Expr ,
60
- ) -> SyntaxNode {
61
- let Some ( let_or_expr_stmt) = ctx. sema . ancestors_with_macros ( expr. syntax ( ) . clone ( ) ) . find ( |node| {
62
- LetStmt :: can_cast ( node. kind ( ) ) || ExprStmt :: can_cast ( node. kind ( ) )
63
- } ) else {
64
- // Is this reachable?
65
- return expr. syntax ( ) . clone ( ) ;
66
- } ;
67
- let_or_expr_stmt
34
+ // Pick the first ancestor expression of the unsafe `expr` that is not a
35
+ // receiver of a method call, a field access, the left-hand side of an
36
+ // assignment, or a reference. As all of those cases would incur a forced move
37
+ // if wrapped which might not be wanted. That is:
38
+ // - `unsafe_expr.foo` -> `unsafe { unsafe_expr.foo }`
39
+ // - `unsafe_expr.foo.bar` -> `unsafe { unsafe_expr.foo.bar }`
40
+ // - `unsafe_expr.foo()` -> `unsafe { unsafe_expr.foo() }`
41
+ // - `unsafe_expr.foo.bar()` -> `unsafe { unsafe_expr.foo.bar() }`
42
+ // - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }`
43
+ // - `&unsafe_expr` -> `unsafe { &unsafe_expr }`
44
+ // - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }`
45
+ fn pick_best_node_to_add_unsafe_block ( unsafe_expr : & ast:: Expr ) -> SyntaxNode {
46
+ // The `unsafe_expr` might be:
47
+ // - `ast::CallExpr`: call an unsafe function
48
+ // - `ast::MethodCallExpr`: call an unsafe method
49
+ // - `ast::PrefixExpr`: dereference a raw pointer
50
+ // - `ast::PathExpr`: access a static mut variable
51
+ for node in unsafe_expr. syntax ( ) . ancestors ( ) {
52
+ let Some ( parent) = node. parent ( ) else {
53
+ return node;
54
+ } ;
55
+ match parent. kind ( ) {
56
+ syntax:: SyntaxKind :: METHOD_CALL_EXPR => {
57
+ // Check if the `node` is the receiver of the method call
58
+ let method_call_expr = ast:: MethodCallExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
59
+ if method_call_expr
60
+ . receiver ( )
61
+ . map ( |receiver| {
62
+ receiver. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) )
63
+ } )
64
+ . unwrap_or ( false )
65
+ {
66
+ // Actually, I think it's not necessary to check whether the
67
+ // text range of the `node` (which is the ancestor of the
68
+ // `unsafe_expr`) is contained in the text range of the
69
+ // receiver. The `node` could potentially be the receiver, the
70
+ // method name, or the argument list. Since the `node` is the
71
+ // ancestor of the unsafe_expr, it cannot be the method name.
72
+ // Additionally, if the `node` is the argument list, the loop
73
+ // would break at least when `parent` reaches the argument list.
74
+ //
75
+ // Dispite this, I still check the text range because I think it
76
+ // makes the code easier to understand.
77
+ continue ;
78
+ }
79
+ return node;
80
+ }
81
+ syntax:: SyntaxKind :: FIELD_EXPR | syntax:: SyntaxKind :: REF_EXPR => continue ,
82
+ syntax:: SyntaxKind :: BIN_EXPR => {
83
+ // Check if the `node` is the left-hand side of an assignment
84
+ let is_left_hand_side_of_assignment = {
85
+ let bin_expr = ast:: BinExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
86
+ if let Some ( ast:: BinaryOp :: Assignment { .. } ) = bin_expr. op_kind ( ) {
87
+ let is_left_hand_side = bin_expr
88
+ . lhs ( )
89
+ . map ( |lhs| lhs. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) ) )
90
+ . unwrap_or ( false ) ;
91
+ is_left_hand_side
92
+ } else {
93
+ false
94
+ }
95
+ } ;
96
+ if !is_left_hand_side_of_assignment {
97
+ return node;
98
+ }
99
+ }
100
+ _ => {
101
+ return node;
102
+ }
103
+ }
104
+ }
105
+ unsafe_expr. syntax ( ) . clone ( )
68
106
}
69
107
70
108
#[ cfg( test) ]
@@ -168,7 +206,7 @@ fn main() {
168
206
r#"
169
207
fn main() {
170
208
let x = &5 as *const usize;
171
- unsafe { let z = *x; }
209
+ let z = unsafe { *x };
172
210
}
173
211
"# ,
174
212
) ;
@@ -192,7 +230,7 @@ unsafe fn func() {
192
230
let z = *x;
193
231
}
194
232
fn main() {
195
- unsafe { func(); }
233
+ unsafe { func() };
196
234
}
197
235
"# ,
198
236
)
@@ -224,7 +262,7 @@ impl S {
224
262
}
225
263
fn main() {
226
264
let s = S(5);
227
- unsafe { s.func(); }
265
+ unsafe { s.func() };
228
266
}
229
267
"# ,
230
268
)
@@ -252,7 +290,7 @@ struct Ty {
252
290
static mut STATIC_MUT: Ty = Ty { a: 0 };
253
291
254
292
fn main() {
255
- unsafe { let x = STATIC_MUT.a; }
293
+ let x = unsafe { STATIC_MUT.a };
256
294
}
257
295
"# ,
258
296
)
@@ -276,7 +314,155 @@ extern "rust-intrinsic" {
276
314
}
277
315
278
316
fn main() {
279
- unsafe { let _ = floorf32(12.0); }
317
+ let _ = unsafe { floorf32(12.0) };
318
+ }
319
+ "# ,
320
+ )
321
+ }
322
+
323
+ #[ test]
324
+ fn unsafe_expr_as_a_receiver_of_a_method_call ( ) {
325
+ check_fix (
326
+ r#"
327
+ unsafe fn foo() -> String {
328
+ "string".to_string()
329
+ }
330
+
331
+ fn main() {
332
+ foo$0().len();
333
+ }
334
+ "# ,
335
+ r#"
336
+ unsafe fn foo() -> String {
337
+ "string".to_string()
338
+ }
339
+
340
+ fn main() {
341
+ unsafe { foo().len() };
342
+ }
343
+ "# ,
344
+ )
345
+ }
346
+
347
+ #[ test]
348
+ fn unsafe_expr_as_an_argument_of_a_method_call ( ) {
349
+ check_fix (
350
+ r#"
351
+ static mut STATIC_MUT: u8 = 0;
352
+
353
+ fn main() {
354
+ let mut v = vec![];
355
+ v.push(STATIC_MUT$0);
356
+ }
357
+ "# ,
358
+ r#"
359
+ static mut STATIC_MUT: u8 = 0;
360
+
361
+ fn main() {
362
+ let mut v = vec![];
363
+ v.push(unsafe { STATIC_MUT });
364
+ }
365
+ "# ,
366
+ )
367
+ }
368
+
369
+ #[ test]
370
+ fn unsafe_expr_as_left_hand_side_of_assignment ( ) {
371
+ check_fix (
372
+ r#"
373
+ static mut STATIC_MUT: u8 = 0;
374
+
375
+ fn main() {
376
+ STATIC_MUT$0 = 1;
377
+ }
378
+ "# ,
379
+ r#"
380
+ static mut STATIC_MUT: u8 = 0;
381
+
382
+ fn main() {
383
+ unsafe { STATIC_MUT = 1 };
384
+ }
385
+ "# ,
386
+ )
387
+ }
388
+
389
+ #[ test]
390
+ fn unsafe_expr_as_right_hand_side_of_assignment ( ) {
391
+ check_fix (
392
+ r#"
393
+ static mut STATIC_MUT: u8 = 0;
394
+
395
+ fn main() {
396
+ let x;
397
+ x = STATIC_MUT$0;
398
+ }
399
+ "# ,
400
+ r#"
401
+ static mut STATIC_MUT: u8 = 0;
402
+
403
+ fn main() {
404
+ let x;
405
+ x = unsafe { STATIC_MUT };
406
+ }
407
+ "# ,
408
+ )
409
+ }
410
+
411
+ #[ test]
412
+ fn unsafe_expr_in_binary_plus ( ) {
413
+ check_fix (
414
+ r#"
415
+ static mut STATIC_MUT: u8 = 0;
416
+
417
+ fn main() {
418
+ let x = STATIC_MUT$0 + 1;
419
+ }
420
+ "# ,
421
+ r#"
422
+ static mut STATIC_MUT: u8 = 0;
423
+
424
+ fn main() {
425
+ let x = unsafe { STATIC_MUT } + 1;
426
+ }
427
+ "# ,
428
+ )
429
+ }
430
+
431
+ #[ test]
432
+ fn ref_to_unsafe_expr ( ) {
433
+ check_fix (
434
+ r#"
435
+ static mut STATIC_MUT: u8 = 0;
436
+
437
+ fn main() {
438
+ let x = &STATIC_MUT$0;
439
+ }
440
+ "# ,
441
+ r#"
442
+ static mut STATIC_MUT: u8 = 0;
443
+
444
+ fn main() {
445
+ let x = unsafe { &STATIC_MUT };
446
+ }
447
+ "# ,
448
+ )
449
+ }
450
+
451
+ #[ test]
452
+ fn ref_ref_to_unsafe_expr ( ) {
453
+ check_fix (
454
+ r#"
455
+ static mut STATIC_MUT: u8 = 0;
456
+
457
+ fn main() {
458
+ let x = &&STATIC_MUT$0;
459
+ }
460
+ "# ,
461
+ r#"
462
+ static mut STATIC_MUT: u8 = 0;
463
+
464
+ fn main() {
465
+ let x = unsafe { &&STATIC_MUT };
280
466
}
281
467
"# ,
282
468
)
0 commit comments