Skip to content

Commit 2e13aed

Browse files
committed
feat: support cross module imports
1 parent 136a9db commit 2e13aed

File tree

1 file changed

+214
-12
lines changed

1 file changed

+214
-12
lines changed

crates/ide-assists/src/handlers/bool_to_enum.rs

Lines changed: 214 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
use hir::ModuleDef;
12
use ide_db::{
23
assists::{AssistId, AssistKind},
34
defs::Definition,
4-
search::{FileReference, SearchScope, UsageSearchResult},
5+
helpers::mod_path_to_ast,
6+
imports::insert_use::{insert_use, ImportScope},
7+
search::{FileReference, UsageSearchResult},
58
source_change::SourceChangeBuilder,
69
};
10+
use itertools::Itertools;
711
use syntax::{
812
ast::{
913
self,
@@ -48,6 +52,7 @@ use crate::assist_context::{AssistContext, Assists};
4852
pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
4953
let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
5054
find_bool_node(ctx)?;
55+
let target_module = ctx.sema.scope(&target_node)?.module();
5156

5257
let target = name.syntax().text_range();
5358
acc.add(
@@ -64,13 +69,10 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
6469
replace_bool_expr(edit, initializer);
6570
}
6671

67-
let usages = definition
68-
.usages(&ctx.sema)
69-
.in_scope(&SearchScope::single_file(ctx.file_id()))
70-
.all();
71-
replace_usages(edit, &usages);
72+
let usages = definition.usages(&ctx.sema).all();
7273

73-
add_enum_def(edit, ctx, &usages, target_node);
74+
add_enum_def(edit, ctx, &usages, target_node, &target_module);
75+
replace_usages(edit, ctx, &usages, &target_module);
7476
},
7577
)
7678
}
@@ -186,8 +188,45 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
186188
}
187189

188190
/// Replaces all usages of the target identifier, both when read and written to.
189-
fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) {
190-
for (_, references) in usages.iter() {
191+
fn replace_usages(
192+
edit: &mut SourceChangeBuilder,
193+
ctx: &AssistContext<'_>,
194+
usages: &UsageSearchResult,
195+
target_module: &hir::Module,
196+
) {
197+
for (file_id, references) in usages.iter() {
198+
edit.edit_file(*file_id);
199+
200+
// add imports across modules where needed
201+
references
202+
.iter()
203+
.filter_map(|FileReference { name, .. }| {
204+
ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module()))
205+
})
206+
.unique_by(|name_and_module| name_and_module.1)
207+
.filter(|(_, module)| module != target_module)
208+
.filter_map(|(name, module)| {
209+
let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
210+
let mod_path = module.find_use_path_prefixed(
211+
ctx.sema.db,
212+
ModuleDef::Module(*target_module),
213+
ctx.config.insert_use.prefix_kind,
214+
ctx.config.prefer_no_std,
215+
);
216+
import_scope.zip(mod_path)
217+
})
218+
.for_each(|(import_scope, mod_path)| {
219+
let import_scope = match import_scope {
220+
ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
221+
ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
222+
ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
223+
};
224+
let path =
225+
make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool"));
226+
insert_use(&import_scope, path, &ctx.config.insert_use);
227+
});
228+
229+
// replace the usages in expressions
191230
references
192231
.into_iter()
193232
.filter_map(|FileReference { range, name, .. }| match name {
@@ -213,7 +252,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) {
213252
let record_field = edit.make_mut(record_field);
214253
let enum_expr = bool_expr_to_enum_expr(initializer);
215254
record_field.replace_expr(enum_expr);
216-
} else if name_ref.syntax().ancestors().find_map(ast::Expr::cast).is_some() {
255+
} else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
217256
// for any other usage in an expression, replace it with a check that it is the true variant
218257
edit.replace(range, format!("{} == Bool::True", name_ref.text()));
219258
}
@@ -255,8 +294,15 @@ fn add_enum_def(
255294
ctx: &AssistContext<'_>,
256295
usages: &UsageSearchResult,
257296
target_node: SyntaxNode,
297+
target_module: &hir::Module,
258298
) {
259-
let make_enum_pub = usages.iter().any(|(file_id, _)| file_id != &ctx.file_id());
299+
let make_enum_pub = usages
300+
.iter()
301+
.flat_map(|(_, refs)| refs)
302+
.filter_map(|FileReference { name, .. }| {
303+
ctx.sema.scope(name.syntax()).map(|scope| scope.module())
304+
})
305+
.any(|module| &module != target_module);
260306
let enum_def = make_bool_enum(make_enum_pub);
261307

262308
let indent = IndentLevel::from_node(&target_node);
@@ -649,7 +695,7 @@ fn main() {
649695
"#,
650696
r#"
651697
#[derive(PartialEq, Eq)]
652-
enum $0Bool { True, False }
698+
enum Bool { True, False }
653699
654700
struct Foo {
655701
bar: Bool,
@@ -713,6 +759,162 @@ fn main() {
713759
)
714760
}
715761

762+
#[test]
763+
fn const_in_module() {
764+
check_assist(
765+
bool_to_enum,
766+
r#"
767+
fn main() {
768+
if foo::FOO {
769+
println!("foo");
770+
}
771+
}
772+
773+
mod foo {
774+
pub const $0FOO: bool = true;
775+
}
776+
"#,
777+
r#"
778+
use foo::Bool;
779+
780+
fn main() {
781+
if foo::FOO == Bool::True {
782+
println!("foo");
783+
}
784+
}
785+
786+
mod foo {
787+
#[derive(PartialEq, Eq)]
788+
pub enum Bool { True, False }
789+
790+
pub const FOO: Bool = Bool::True;
791+
}
792+
"#,
793+
)
794+
}
795+
796+
#[test]
797+
fn const_in_module_with_import() {
798+
check_assist(
799+
bool_to_enum,
800+
r#"
801+
fn main() {
802+
use foo::FOO;
803+
804+
if FOO {
805+
println!("foo");
806+
}
807+
}
808+
809+
mod foo {
810+
pub const $0FOO: bool = true;
811+
}
812+
"#,
813+
r#"
814+
use crate::foo::Bool;
815+
816+
fn main() {
817+
use foo::FOO;
818+
819+
if FOO == Bool::True {
820+
println!("foo");
821+
}
822+
}
823+
824+
mod foo {
825+
#[derive(PartialEq, Eq)]
826+
pub enum Bool { True, False }
827+
828+
pub const FOO: Bool = Bool::True;
829+
}
830+
"#,
831+
)
832+
}
833+
834+
#[test]
835+
fn const_cross_file() {
836+
check_assist(
837+
bool_to_enum,
838+
r#"
839+
//- /main.rs
840+
mod foo;
841+
842+
fn main() {
843+
if foo::FOO {
844+
println!("foo");
845+
}
846+
}
847+
848+
//- /foo.rs
849+
pub const $0FOO: bool = true;
850+
"#,
851+
r#"
852+
//- /main.rs
853+
use foo::Bool;
854+
855+
mod foo;
856+
857+
fn main() {
858+
if foo::FOO == Bool::True {
859+
println!("foo");
860+
}
861+
}
862+
863+
//- /foo.rs
864+
#[derive(PartialEq, Eq)]
865+
pub enum Bool { True, False }
866+
867+
pub const FOO: Bool = Bool::True;
868+
"#,
869+
)
870+
}
871+
872+
#[test]
873+
fn const_cross_file_and_module() {
874+
check_assist(
875+
bool_to_enum,
876+
r#"
877+
//- /main.rs
878+
mod foo;
879+
880+
fn main() {
881+
use foo::bar;
882+
883+
if bar::BAR {
884+
println!("foo");
885+
}
886+
}
887+
888+
//- /foo.rs
889+
pub mod bar {
890+
pub const $0BAR: bool = false;
891+
}
892+
"#,
893+
r#"
894+
//- /main.rs
895+
use crate::foo::bar::Bool;
896+
897+
mod foo;
898+
899+
fn main() {
900+
use foo::bar;
901+
902+
if bar::BAR == Bool::True {
903+
println!("foo");
904+
}
905+
}
906+
907+
//- /foo.rs
908+
pub mod bar {
909+
#[derive(PartialEq, Eq)]
910+
pub enum Bool { True, False }
911+
912+
pub const BAR: Bool = Bool::False;
913+
}
914+
"#,
915+
)
916+
}
917+
716918
#[test]
717919
fn const_non_bool() {
718920
cov_mark::check!(not_applicable_non_bool_const);

0 commit comments

Comments
 (0)