Skip to content

Fix short-if formatting #2778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use matches::rewrite_match;
use overflow;
use pairs::{rewrite_all_pairs, rewrite_pair, PairParts};
use patterns::{can_be_overflowed_pat, is_short_pattern, TuplePatField};
use rewrite::{Rewrite, RewriteContext};
use rewrite::{Rewrite, RewriteContext, RewriteStmt};
use shape::{Indent, Shape};
use spanned::Spanned;
use string::{rewrite_string, StringFormat};
Expand Down Expand Up @@ -455,7 +455,7 @@ fn rewrite_single_line_block(
) -> Option<String> {
if is_simple_block(block, attrs, context.codemap) {
let expr_shape = shape.offset_left(last_line_width(prefix))?;
let expr_str = block.stmts[0].rewrite(context, expr_shape)?;
let expr_str = block.stmts[0].rewrite(context, expr_shape, false)?;
let label_str = rewrite_label(label);
let result = format!("{}{}{{ {} }}", prefix, label_str, expr_str);
if result.len() <= shape.width && !result.contains('\n') {
Expand Down Expand Up @@ -531,8 +531,13 @@ fn rewrite_block(
result
}

impl Rewrite for ast::Stmt {
fn rewrite(&self, context: &RewriteContext, shape: Shape) -> Option<String> {
impl RewriteStmt for ast::Stmt {
fn rewrite(
&self,
context: &RewriteContext,
shape: Shape,
last_stmt_is_if: bool,
) -> Option<String> {
skip_out_of_file_lines_range!(context, self.span());

let result = match self.node {
Expand All @@ -545,7 +550,12 @@ impl Rewrite for ast::Stmt {
};

let shape = shape.sub_width(suffix.len())?;
format_expr(ex, ExprType::Statement, context, shape).map(|s| s + suffix)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this is incorrect. We should only put the if expression in a single line if it is the last statement of the block.

let expr_type = if last_stmt_is_if {
ExprType::SubExpression
} else {
ExprType::Statement
};
format_expr(ex, expr_type, context, shape).map(|s| s + suffix)
}
ast::StmtKind::Mac(..) | ast::StmtKind::Item(..) => None,
};
Expand Down Expand Up @@ -634,11 +644,7 @@ fn to_control_flow(expr: &ast::Expr, expr_type: ExprType) -> Option<ControlFlow>
}

fn choose_matcher(pats: &[&ast::Pat]) -> &'static str {
if pats.is_empty() {
""
} else {
"let"
}
if pats.is_empty() { "" } else { "let" }
}

impl<'a> ControlFlow<'a> {
Expand Down Expand Up @@ -748,11 +754,12 @@ impl<'a> ControlFlow<'a> {

let new_width = width.checked_sub(pat_expr_str.len() + fixed_cost)?;
let expr = &self.block.stmts[0];
let if_str = expr.rewrite(context, Shape::legacy(new_width, Indent::empty()))?;
let if_str = expr.rewrite(context, Shape::legacy(new_width, Indent::empty()), false)?;

let new_width = new_width.checked_sub(if_str.len())?;
let else_expr = &else_node.stmts[0];
let else_str = else_expr.rewrite(context, Shape::legacy(new_width, Indent::empty()))?;
let else_str =
else_expr.rewrite(context, Shape::legacy(new_width, Indent::empty()), false)?;

if if_str.contains('\n') || else_str.contains('\n') {
return None;
Expand Down Expand Up @@ -1134,6 +1141,16 @@ pub fn stmt_is_expr(stmt: &ast::Stmt) -> bool {
}
}

pub(crate) fn stmt_is_if(stmt: &ast::Stmt) -> bool {
match stmt.node {
ast::StmtKind::Semi(ref e) | ast::StmtKind::Expr(ref e) => match e.node {
ast::ExprKind::If(..) => true,
_ => false,
},
_ => false,
}
}

pub fn is_unsafe_block(block: &ast::Block) -> bool {
if let ast::BlockCheckMode::Unsafe(..) = block.rules {
true
Expand Down
4 changes: 2 additions & 2 deletions src/items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use expr::{
use lists::{definitive_tactic, itemize_list, write_list, ListFormatting, ListItem, Separator};
use macros::{rewrite_macro, MacroPosition};
use overflow;
use rewrite::{Rewrite, RewriteContext};
use rewrite::{Rewrite, RewriteContext, RewriteStmt};
use shape::{Indent, Shape};
use spanned::Spanned;
use utils::*;
Expand Down Expand Up @@ -399,7 +399,7 @@ impl<'a> FmtVisitor<'a> {
.map(|s| s + suffix)
.or_else(|| Some(self.snippet(e.span).to_owned()))
}
None => stmt.rewrite(&self.get_context(), self.shape()),
None => stmt.rewrite(&self.get_context(), self.shape(), false),
}
} else {
None
Expand Down
9 changes: 9 additions & 0 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ pub trait Rewrite {
fn rewrite(&self, context: &RewriteContext, shape: Shape) -> Option<String>;
}

pub trait RewriteStmt {
fn rewrite(
&self,
context: &RewriteContext,
shape: Shape,
last_stmt_is_if: bool,
) -> Option<String>;
}

#[derive(Clone)]
pub struct RewriteContext<'a> {
pub parse_session: &'a ParseSess,
Expand Down
10 changes: 6 additions & 4 deletions src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ use attr::*;
use codemap::{LineRangeUtils, SpanUtils};
use comment::{CodeCharKind, CommentCodeSlices, FindUncommented};
use config::{BraceStyle, Config};
use expr::stmt_is_if;
use items::{
format_impl, format_trait, format_trait_alias, is_mod_decl, is_use_item,
rewrite_associated_impl_type, rewrite_associated_type, rewrite_extern_crate,
rewrite_type_alias, FnSig, StaticParts, StructParts,
};
use macros::{rewrite_macro, rewrite_macro_def, MacroPosition};
use rewrite::{Rewrite, RewriteContext};
use rewrite::{Rewrite, RewriteContext, RewriteStmt};
use shape::{Indent, Shape};
use spanned::Spanned;
use utils::{
Expand Down Expand Up @@ -79,7 +80,7 @@ impl<'b, 'a: 'b> FmtVisitor<'a> {
Shape::indented(self.block_indent, self.config)
}

fn visit_stmt(&mut self, stmt: &ast::Stmt) {
fn visit_stmt(&mut self, stmt: &ast::Stmt, last_stmt_is_if: bool) {
debug!(
"visit_stmt: {:?} {:?}",
self.codemap.lookup_char_pos(stmt.span.lo()),
Expand All @@ -94,7 +95,7 @@ impl<'b, 'a: 'b> FmtVisitor<'a> {
if contains_skip(get_attrs_from_stmt(stmt)) {
self.push_skipped_with_span(stmt.span());
} else {
let rewrite = stmt.rewrite(&self.get_context(), self.shape());
let rewrite = stmt.rewrite(&self.get_context(), self.shape(), last_stmt_is_if);
self.push_rewrite(stmt.span(), rewrite)
}
}
Expand Down Expand Up @@ -662,7 +663,8 @@ impl<'b, 'a: 'b> FmtVisitor<'a> {
.collect();

if items.is_empty() {
self.visit_stmt(&stmts[0]);
let last_stmt_is_if = stmts[0] == stmts[stmts.len() - 1] && stmt_is_if(&stmts[0]);
self.visit_stmt(&stmts[0], last_stmt_is_if);
self.walk_stmts(&stmts[1..]);
} else {
self.visit_items_with_reordering(&items);
Expand Down
40 changes: 40 additions & 0 deletions tests/source/one_line_if.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
fn plain_if(x: bool) -> u8 {
if x {
0
} else {
1
}
}

fn paren_if(x: bool) -> u8 {
(if x { 0 } else { 1 })
}

fn let_if(x: bool) -> u8 {
let x = if x {
foo()
} else {
bar()
};
x
}

fn return_if(x: bool) -> u8 {
return if x {
0
} else {
1
};
}

fn multi_if() {
use std::io;
if x { foo() } else { bar() }
if x { foo() } else { bar() }
}

fn middle_if() {
use std::io;
if x { foo() } else { bar() }
let x = 1;
}
36 changes: 36 additions & 0 deletions tests/target/one_line_if.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
fn plain_if(x: bool) -> u8 {
if x { 0 } else { 1 }
}

fn paren_if(x: bool) -> u8 {
(if x { 0 } else { 1 })
}

fn let_if(x: bool) -> u8 {
let x = if x { foo() } else { bar() };
x
}

fn return_if(x: bool) -> u8 {
return if x { 0 } else { 1 };
}

fn multi_if() {
use std::io;
if x {
foo()
} else {
bar()
}
if x { foo() } else { bar() }
}

fn middle_if() {
use std::io;
if x {
foo()
} else {
bar()
}
let x = 1;
}