1
- use hir:: db:: AstDatabase ;
1
+ use hir:: db:: ExpandDatabase ;
2
2
use ide_db:: { assists:: Assist , source_change:: SourceChange } ;
3
- use syntax:: AstNode ;
4
3
use syntax:: { ast, SyntaxNode } ;
4
+ use syntax:: { match_ast, AstNode } ;
5
5
use text_edit:: TextEdit ;
6
6
7
7
use crate :: { fix, Diagnostic , DiagnosticsContext } ;
@@ -19,10 +19,15 @@ pub(crate) fn missing_unsafe(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsaf
19
19
}
20
20
21
21
fn fixes ( ctx : & DiagnosticsContext < ' _ > , d : & hir:: MissingUnsafe ) -> Option < Vec < Assist > > {
22
+ // The fixit will not work correctly for macro expansions, so we don't offer it in that case.
23
+ if d. expr . file_id . is_macro ( ) {
24
+ return None ;
25
+ }
26
+
22
27
let root = ctx. sema . db . parse_or_expand ( d. expr . file_id ) ?;
23
28
let expr = d. expr . value . to_node ( & root) ;
24
29
25
- let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( & expr) ;
30
+ let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( & expr) ? ;
26
31
27
32
let replacement = format ! ( "unsafe {{ {} }}" , node_to_add_unsafe_block. text( ) ) ;
28
33
let edit = TextEdit :: replace ( node_to_add_unsafe_block. text_range ( ) , replacement) ;
@@ -42,72 +47,51 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
42
47
// - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }`
43
48
// - `&unsafe_expr` -> `unsafe { &unsafe_expr }`
44
49
// - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }`
45
- fn pick_best_node_to_add_unsafe_block ( unsafe_expr : & ast:: Expr ) -> SyntaxNode {
50
+ fn pick_best_node_to_add_unsafe_block ( unsafe_expr : & ast:: Expr ) -> Option < SyntaxNode > {
46
51
// The `unsafe_expr` might be:
47
52
// - `ast::CallExpr`: call an unsafe function
48
53
// - `ast::MethodCallExpr`: call an unsafe method
49
54
// - `ast::PrefixExpr`: dereference a raw pointer
50
55
// - `ast::PathExpr`: access a static mut variable
51
- for node in unsafe_expr. syntax ( ) . ancestors ( ) {
52
- let Some ( parent) = node. parent ( ) else {
53
- return node;
54
- } ;
55
- match parent. kind ( ) {
56
- syntax:: SyntaxKind :: METHOD_CALL_EXPR => {
57
- // Check if the `node` is the receiver of the method call
58
- let method_call_expr = ast:: MethodCallExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
59
- if method_call_expr
60
- . receiver ( )
61
- . map ( |receiver| {
62
- receiver. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) )
63
- } )
64
- . unwrap_or ( false )
65
- {
66
- // Actually, I think it's not necessary to check whether the
67
- // text range of the `node` (which is the ancestor of the
68
- // `unsafe_expr`) is contained in the text range of the
69
- // receiver. The `node` could potentially be the receiver, the
70
- // method name, or the argument list. Since the `node` is the
71
- // ancestor of the unsafe_expr, it cannot be the method name.
72
- // Additionally, if the `node` is the argument list, the loop
73
- // would break at least when `parent` reaches the argument list.
74
- //
75
- // Dispite this, I still check the text range because I think it
76
- // makes the code easier to understand.
77
- continue ;
78
- }
79
- return node;
80
- }
81
- syntax:: SyntaxKind :: FIELD_EXPR | syntax:: SyntaxKind :: REF_EXPR => continue ,
82
- syntax:: SyntaxKind :: BIN_EXPR => {
83
- // Check if the `node` is the left-hand side of an assignment
84
- let is_left_hand_side_of_assignment = {
85
- let bin_expr = ast:: BinExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
86
- if let Some ( ast:: BinaryOp :: Assignment { .. } ) = bin_expr. op_kind ( ) {
87
- let is_left_hand_side = bin_expr
88
- . lhs ( )
89
- . map ( |lhs| lhs. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) ) )
90
- . unwrap_or ( false ) ;
91
- is_left_hand_side
92
- } else {
93
- false
56
+ for ( node, parent) in
57
+ unsafe_expr. syntax ( ) . ancestors ( ) . zip ( unsafe_expr. syntax ( ) . ancestors ( ) . skip ( 1 ) )
58
+ {
59
+ match_ast ! {
60
+ match parent {
61
+ // If the `parent` is a `MethodCallExpr`, that means the `node`
62
+ // is the receiver of the method call, because only the receiver
63
+ // can be a direct child of a method call. The method name
64
+ // itself is not an expression but a `NameRef`, and an argument
65
+ // is a direct child of an `ArgList`.
66
+ ast:: MethodCallExpr ( _) => continue ,
67
+ ast:: FieldExpr ( _) => continue ,
68
+ ast:: RefExpr ( _) => continue ,
69
+ ast:: BinExpr ( it) => {
70
+ // Check if the `node` is the left-hand side of an
71
+ // assignment, if so, we don't want to wrap it in an unsafe
72
+ // block, e.g. `unsafe_expr += 1`
73
+ let is_left_hand_side_of_assignment = {
74
+ if let Some ( ast:: BinaryOp :: Assignment { .. } ) = it. op_kind( ) {
75
+ it. lhs( ) . map( |lhs| lhs. syntax( ) . text_range( ) . contains_range( node. text_range( ) ) ) . unwrap_or( false )
76
+ } else {
77
+ false
78
+ }
79
+ } ;
80
+ if !is_left_hand_side_of_assignment {
81
+ return Some ( node) ;
94
82
}
95
- } ;
96
- if !is_left_hand_side_of_assignment {
97
- return node;
98
- }
99
- }
100
- _ => {
101
- return node;
83
+ } ,
84
+ _ => { return Some ( node) ; }
85
+
102
86
}
103
87
}
104
88
}
105
- unsafe_expr . syntax ( ) . clone ( )
89
+ None
106
90
}
107
91
108
92
#[ cfg( test) ]
109
93
mod tests {
110
- use crate :: tests:: { check_diagnostics, check_fix} ;
94
+ use crate :: tests:: { check_diagnostics, check_fix, check_no_fix } ;
111
95
112
96
#[ test]
113
97
fn missing_unsafe_diagnostic_with_raw_ptr ( ) {
@@ -467,4 +451,19 @@ fn main() {
467
451
"# ,
468
452
)
469
453
}
454
+
455
+ #[ test]
456
+ fn unsafe_expr_in_macro_call ( ) {
457
+ check_no_fix (
458
+ r#"
459
+ unsafe fn foo() -> u8 {
460
+ 0
461
+ }
462
+
463
+ fn main() {
464
+ let x = format!("foo: {}", foo$0());
465
+ }
466
+ "# ,
467
+ )
468
+ }
470
469
}
0 commit comments