1
+ use hir:: ModuleDef ;
1
2
use ide_db:: {
2
3
assists:: { AssistId , AssistKind } ,
3
4
defs:: Definition ,
4
- search:: { FileReference , SearchScope , UsageSearchResult } ,
5
+ helpers:: mod_path_to_ast,
6
+ imports:: insert_use:: { insert_use, ImportScope } ,
7
+ search:: { FileReference , UsageSearchResult } ,
5
8
source_change:: SourceChangeBuilder ,
6
9
} ;
10
+ use itertools:: Itertools ;
7
11
use syntax:: {
8
12
ast:: {
9
13
self ,
@@ -48,6 +52,7 @@ use crate::assist_context::{AssistContext, Assists};
48
52
pub ( crate ) fn bool_to_enum ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
49
53
let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
50
54
find_bool_node ( ctx) ?;
55
+ let target_module = ctx. sema . scope ( & target_node) ?. module ( ) ;
51
56
52
57
let target = name. syntax ( ) . text_range ( ) ;
53
58
acc. add (
@@ -64,13 +69,10 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
64
69
replace_bool_expr ( edit, initializer) ;
65
70
}
66
71
67
- let usages = definition
68
- . usages ( & ctx. sema )
69
- . in_scope ( & SearchScope :: single_file ( ctx. file_id ( ) ) )
70
- . all ( ) ;
71
- replace_usages ( edit, & usages) ;
72
+ let usages = definition. usages ( & ctx. sema ) . all ( ) ;
72
73
73
- add_enum_def ( edit, ctx, & usages, target_node) ;
74
+ add_enum_def ( edit, ctx, & usages, target_node, & target_module) ;
75
+ replace_usages ( edit, ctx, & usages, & target_module) ;
74
76
} ,
75
77
)
76
78
}
@@ -186,8 +188,45 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
186
188
}
187
189
188
190
/// Replaces all usages of the target identifier, both when read and written to.
189
- fn replace_usages ( edit : & mut SourceChangeBuilder , usages : & UsageSearchResult ) {
190
- for ( _, references) in usages. iter ( ) {
191
+ fn replace_usages (
192
+ edit : & mut SourceChangeBuilder ,
193
+ ctx : & AssistContext < ' _ > ,
194
+ usages : & UsageSearchResult ,
195
+ target_module : & hir:: Module ,
196
+ ) {
197
+ for ( file_id, references) in usages. iter ( ) {
198
+ edit. edit_file ( * file_id) ;
199
+
200
+ // add imports across modules where needed
201
+ references
202
+ . iter ( )
203
+ . filter_map ( |FileReference { name, .. } | {
204
+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( name, scope. module ( ) ) )
205
+ } )
206
+ . unique_by ( |name_and_module| name_and_module. 1 )
207
+ . filter ( |( _, module) | module != target_module)
208
+ . filter_map ( |( name, module) | {
209
+ let import_scope = ImportScope :: find_insert_use_container ( name. syntax ( ) , & ctx. sema ) ;
210
+ let mod_path = module. find_use_path_prefixed (
211
+ ctx. sema . db ,
212
+ ModuleDef :: Module ( * target_module) ,
213
+ ctx. config . insert_use . prefix_kind ,
214
+ ctx. config . prefer_no_std ,
215
+ ) ;
216
+ import_scope. zip ( mod_path)
217
+ } )
218
+ . for_each ( |( import_scope, mod_path) | {
219
+ let import_scope = match import_scope {
220
+ ImportScope :: File ( it) => ImportScope :: File ( edit. make_mut ( it) ) ,
221
+ ImportScope :: Module ( it) => ImportScope :: Module ( edit. make_mut ( it) ) ,
222
+ ImportScope :: Block ( it) => ImportScope :: Block ( edit. make_mut ( it) ) ,
223
+ } ;
224
+ let path =
225
+ make:: path_concat ( mod_path_to_ast ( & mod_path) , make:: path_from_text ( "Bool" ) ) ;
226
+ insert_use ( & import_scope, path, & ctx. config . insert_use ) ;
227
+ } ) ;
228
+
229
+ // replace the usages in expressions
191
230
references
192
231
. into_iter ( )
193
232
. filter_map ( |FileReference { range, name, .. } | match name {
@@ -213,7 +252,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) {
213
252
let record_field = edit. make_mut ( record_field) ;
214
253
let enum_expr = bool_expr_to_enum_expr ( initializer) ;
215
254
record_field. replace_expr ( enum_expr) ;
216
- } else if name_ref. syntax ( ) . ancestors ( ) . find_map ( ast:: Expr :: cast) . is_some ( ) {
255
+ } else if name_ref. syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
217
256
// for any other usage in an expression, replace it with a check that it is the true variant
218
257
edit. replace ( range, format ! ( "{} == Bool::True" , name_ref. text( ) ) ) ;
219
258
}
@@ -255,8 +294,15 @@ fn add_enum_def(
255
294
ctx : & AssistContext < ' _ > ,
256
295
usages : & UsageSearchResult ,
257
296
target_node : SyntaxNode ,
297
+ target_module : & hir:: Module ,
258
298
) {
259
- let make_enum_pub = usages. iter ( ) . any ( |( file_id, _) | file_id != & ctx. file_id ( ) ) ;
299
+ let make_enum_pub = usages
300
+ . iter ( )
301
+ . flat_map ( |( _, refs) | refs)
302
+ . filter_map ( |FileReference { name, .. } | {
303
+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| scope. module ( ) )
304
+ } )
305
+ . any ( |module| & module != target_module) ;
260
306
let enum_def = make_bool_enum ( make_enum_pub) ;
261
307
262
308
let indent = IndentLevel :: from_node ( & target_node) ;
@@ -649,7 +695,7 @@ fn main() {
649
695
"# ,
650
696
r#"
651
697
#[derive(PartialEq, Eq)]
652
- enum $0Bool { True, False }
698
+ enum Bool { True, False }
653
699
654
700
struct Foo {
655
701
bar: Bool,
@@ -713,6 +759,162 @@ fn main() {
713
759
)
714
760
}
715
761
762
+ #[ test]
763
+ fn const_in_module ( ) {
764
+ check_assist (
765
+ bool_to_enum,
766
+ r#"
767
+ fn main() {
768
+ if foo::FOO {
769
+ println!("foo");
770
+ }
771
+ }
772
+
773
+ mod foo {
774
+ pub const $0FOO: bool = true;
775
+ }
776
+ "# ,
777
+ r#"
778
+ use foo::Bool;
779
+
780
+ fn main() {
781
+ if foo::FOO == Bool::True {
782
+ println!("foo");
783
+ }
784
+ }
785
+
786
+ mod foo {
787
+ #[derive(PartialEq, Eq)]
788
+ pub enum Bool { True, False }
789
+
790
+ pub const FOO: Bool = Bool::True;
791
+ }
792
+ "# ,
793
+ )
794
+ }
795
+
796
+ #[ test]
797
+ fn const_in_module_with_import ( ) {
798
+ check_assist (
799
+ bool_to_enum,
800
+ r#"
801
+ fn main() {
802
+ use foo::FOO;
803
+
804
+ if FOO {
805
+ println!("foo");
806
+ }
807
+ }
808
+
809
+ mod foo {
810
+ pub const $0FOO: bool = true;
811
+ }
812
+ "# ,
813
+ r#"
814
+ use crate::foo::Bool;
815
+
816
+ fn main() {
817
+ use foo::FOO;
818
+
819
+ if FOO == Bool::True {
820
+ println!("foo");
821
+ }
822
+ }
823
+
824
+ mod foo {
825
+ #[derive(PartialEq, Eq)]
826
+ pub enum Bool { True, False }
827
+
828
+ pub const FOO: Bool = Bool::True;
829
+ }
830
+ "# ,
831
+ )
832
+ }
833
+
834
+ #[ test]
835
+ fn const_cross_file ( ) {
836
+ check_assist (
837
+ bool_to_enum,
838
+ r#"
839
+ //- /main.rs
840
+ mod foo;
841
+
842
+ fn main() {
843
+ if foo::FOO {
844
+ println!("foo");
845
+ }
846
+ }
847
+
848
+ //- /foo.rs
849
+ pub const $0FOO: bool = true;
850
+ "# ,
851
+ r#"
852
+ //- /main.rs
853
+ use foo::Bool;
854
+
855
+ mod foo;
856
+
857
+ fn main() {
858
+ if foo::FOO == Bool::True {
859
+ println!("foo");
860
+ }
861
+ }
862
+
863
+ //- /foo.rs
864
+ #[derive(PartialEq, Eq)]
865
+ pub enum Bool { True, False }
866
+
867
+ pub const FOO: Bool = Bool::True;
868
+ "# ,
869
+ )
870
+ }
871
+
872
+ #[ test]
873
+ fn const_cross_file_and_module ( ) {
874
+ check_assist (
875
+ bool_to_enum,
876
+ r#"
877
+ //- /main.rs
878
+ mod foo;
879
+
880
+ fn main() {
881
+ use foo::bar;
882
+
883
+ if bar::BAR {
884
+ println!("foo");
885
+ }
886
+ }
887
+
888
+ //- /foo.rs
889
+ pub mod bar {
890
+ pub const $0BAR: bool = false;
891
+ }
892
+ "# ,
893
+ r#"
894
+ //- /main.rs
895
+ use crate::foo::bar::Bool;
896
+
897
+ mod foo;
898
+
899
+ fn main() {
900
+ use foo::bar;
901
+
902
+ if bar::BAR == Bool::True {
903
+ println!("foo");
904
+ }
905
+ }
906
+
907
+ //- /foo.rs
908
+ pub mod bar {
909
+ #[derive(PartialEq, Eq)]
910
+ pub enum Bool { True, False }
911
+
912
+ pub const BAR: Bool = Bool::False;
913
+ }
914
+ "# ,
915
+ )
916
+ }
917
+
716
918
#[ test]
717
919
fn const_non_bool ( ) {
718
920
cov_mark:: check!( not_applicable_non_bool_const) ;
0 commit comments