Skip to content

Commit 9e53db2

Browse files
author
Côme ALLART
committed
refactor: use hir to test if a value is returned
1 parent 80a6868 commit 9e53db2

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

crates/ide_assists/src/handlers/generate_documentation_template.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ fn safety_builder(ast_func: &ast::Fn) -> Option<Vec<String>> {
159159
fn gen_ex_template(ast_func: &ast::Fn, ctx: &AssistContext) -> Option<Vec<String>> {
160160
let (mut lines, ex_helper) = gen_ex_start_helper(ast_func, ctx)?;
161161
// Call the function, check result
162-
if returns_a_value(ast_func) {
162+
if returns_a_value(ast_func, ctx) {
163163
if count_parameters(&ex_helper.param_list) < 3 {
164164
lines.push(format!("assert_eq!({}, );", ex_helper.function_call));
165165
} else {
@@ -183,7 +183,7 @@ fn gen_ex_template(ast_func: &ast::Fn, ctx: &AssistContext) -> Option<Vec<String
183183
/// `None` if the function has a `self` parameter but is not in an `impl`.
184184
fn gen_panic_ex_template(ast_func: &ast::Fn, ctx: &AssistContext) -> Option<Vec<String>> {
185185
let (mut lines, ex_helper) = gen_ex_start_helper(ast_func, ctx)?;
186-
match returns_a_value(ast_func) {
186+
match returns_a_value(ast_func, ctx) {
187187
true => lines.push(format!("let _ = {}; // panics", ex_helper.function_call)),
188188
false => lines.push(format!("{}; // panics", ex_helper.function_call)),
189189
}
@@ -424,11 +424,12 @@ fn return_type(ast_func: &ast::Fn) -> Option<ast::Type> {
424424
}
425425

426426
/// Helper function to determine if the function returns some data
427-
fn returns_a_value(ast_func: &ast::Fn) -> bool {
428-
match return_type(ast_func) {
429-
Some(ret_type) => !["()", "!"].contains(&ret_type.to_string().as_str()),
430-
None => false,
431-
}
427+
fn returns_a_value(ast_func: &ast::Fn, ctx: &AssistContext) -> bool {
428+
ctx.sema
429+
.to_def(ast_func)
430+
.map(|hir_func| hir_func.ret_type(ctx.db()))
431+
.map(|ret_ty| !ret_ty.is_unit() && !ret_ty.is_never())
432+
.unwrap_or(false)
432433
}
433434

434435
#[cfg(test)]

0 commit comments

Comments
 (0)