Skip to content

Commit 9fa12de

Browse files
committed
Improve while_let_on_iterator suggestion inside an FnOnce closure
1 parent 54feac1 commit 9fa12de

File tree

6 files changed

+240
-57
lines changed

6 files changed

+240
-57
lines changed

clippy_lints/src/loops/while_let_on_iterator.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ use clippy_utils::diagnostics::span_lint_and_sugg;
33
use clippy_utils::higher;
44
use clippy_utils::source::snippet_with_applicability;
55
use clippy_utils::{
6-
get_enclosing_loop_or_closure, is_refutable, is_trait_method, match_def_path, paths, visitors::is_res_used,
6+
get_enclosing_loop_or_multi_call_closure, is_refutable, is_trait_method, match_def_path, paths,
7+
visitors::is_res_used,
78
};
89
use if_chain::if_chain;
910
use rustc_errors::Applicability;
1011
use rustc_hir::intravisit::{walk_expr, Visitor};
1112
use rustc_hir::{def::Res, Expr, ExprKind, HirId, Local, Mutability, PatKind, QPath, UnOp};
1213
use rustc_lint::LateContext;
14+
use rustc_middle::hir::nested_filter::OnlyBodies;
1315
use rustc_middle::ty::adjustment::Adjust;
1416
use rustc_span::{symbol::sym, Symbol};
1517

@@ -249,6 +251,11 @@ fn needs_mutable_borrow(cx: &LateContext<'_>, iter_expr: &IterExpr, loop_expr: &
249251
used_iter: bool,
250252
}
251253
impl<'tcx> Visitor<'tcx> for AfterLoopVisitor<'_, '_, 'tcx> {
254+
type NestedFilter = OnlyBodies;
255+
fn nested_visit_map(&mut self) -> Self::Map {
256+
self.cx.tcx.hir()
257+
}
258+
252259
fn visit_expr(&mut self, e: &'tcx Expr<'_>) {
253260
if self.used_iter {
254261
return;
@@ -283,6 +290,11 @@ fn needs_mutable_borrow(cx: &LateContext<'_>, iter_expr: &IterExpr, loop_expr: &
283290
used_after: bool,
284291
}
285292
impl<'a, 'b, 'tcx> Visitor<'tcx> for NestedLoopVisitor<'a, 'b, 'tcx> {
293+
type NestedFilter = OnlyBodies;
294+
fn nested_visit_map(&mut self) -> Self::Map {
295+
self.cx.tcx.hir()
296+
}
297+
286298
fn visit_local(&mut self, l: &'tcx Local<'_>) {
287299
if !self.after_loop {
288300
l.pat.each_binding_or_first(&mut |_, id, _, _| {
@@ -320,10 +332,7 @@ fn needs_mutable_borrow(cx: &LateContext<'_>, iter_expr: &IterExpr, loop_expr: &
320332
}
321333
}
322334

323-
if let Some(e) = get_enclosing_loop_or_closure(cx.tcx, loop_expr) {
324-
// The iterator expression will be used on the next iteration (for loops), or on the next call (for
325-
// closures) unless it is declared within the enclosing expression. TODO: Check for closures
326-
// used where an `FnOnce` type is expected.
335+
if let Some(e) = get_enclosing_loop_or_multi_call_closure(cx, loop_expr) {
327336
let local_id = match iter_expr.path {
328337
Res::Local(id) => id,
329338
_ => return true,

clippy_utils/src/lib.rs

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ use rustc_middle::ty::fast_reject::SimplifiedTypeGen::{
9393
ArraySimplifiedType, BoolSimplifiedType, CharSimplifiedType, FloatSimplifiedType, IntSimplifiedType,
9494
PtrSimplifiedType, SliceSimplifiedType, StrSimplifiedType, UintSimplifiedType,
9595
};
96-
use rustc_middle::ty::{layout::IntegerExt, BorrowKind, DefIdTree, Ty, TyCtxt, TypeAndMut, TypeFoldable, UpvarCapture};
96+
use rustc_middle::ty::{
97+
layout::IntegerExt, BorrowKind, ClosureKind, DefIdTree, Ty, TyCtxt, TypeAndMut, TypeFoldable, UpvarCapture,
98+
};
9799
use rustc_middle::ty::{FloatTy, IntTy, UintTy};
98100
use rustc_semver::RustcVersion;
99101
use rustc_session::Session;
@@ -105,7 +107,7 @@ use rustc_span::{Span, DUMMY_SP};
105107
use rustc_target::abi::Integer;
106108

107109
use crate::consts::{constant, Constant};
108-
use crate::ty::{can_partially_move_ty, is_copy, is_recursively_primitive_type};
110+
use crate::ty::{can_partially_move_ty, expr_sig, is_copy, is_recursively_primitive_type, ty_is_fn_once_param};
109111
use crate::visitors::expr_visitor_no_bodies;
110112

111113
pub fn parse_msrv(msrv: &str, sess: Option<&Session>, span: Option<Span>) -> Option<RustcVersion> {
@@ -1197,16 +1199,54 @@ pub fn get_enclosing_block<'tcx>(cx: &LateContext<'tcx>, hir_id: HirId) -> Optio
11971199
}
11981200

11991201
/// Gets the loop or closure enclosing the given expression, if any.
1200-
pub fn get_enclosing_loop_or_closure<'tcx>(tcx: TyCtxt<'tcx>, expr: &Expr<'_>) -> Option<&'tcx Expr<'tcx>> {
1201-
for (_, node) in tcx.hir().parent_iter(expr.hir_id) {
1202+
pub fn get_enclosing_loop_or_multi_call_closure<'tcx>(
1203+
cx: &LateContext<'tcx>,
1204+
expr: &Expr<'_>,
1205+
) -> Option<&'tcx Expr<'tcx>> {
1206+
for (_, node) in cx.tcx.hir().parent_iter(expr.hir_id) {
12021207
match node {
1203-
Node::Expr(
1204-
e @ Expr {
1205-
kind: ExprKind::Loop(..) | ExprKind::Closure { .. },
1206-
..
1208+
Node::Expr(e) => match e.kind {
1209+
ExprKind::Closure { .. } => {
1210+
if let rustc_ty::Closure(_, subs) = cx.typeck_results().expr_ty(e).kind()
1211+
&& subs.as_closure().kind() == ClosureKind::FnOnce
1212+
{
1213+
continue;
1214+
}
1215+
let is_once = walk_to_expr_usage(cx, e, |node, id| {
1216+
let Node::Expr(e) = node else {
1217+
return None;
1218+
};
1219+
match e.kind {
1220+
ExprKind::Call(f, _) if f.hir_id == id => Some(()),
1221+
ExprKind::Call(f, args) => {
1222+
let i = args.iter().position(|arg| arg.hir_id == id)?;
1223+
let sig = expr_sig(cx, f)?;
1224+
let predicates = sig
1225+
.predicates_id()
1226+
.map_or(cx.param_env, |id| cx.tcx.param_env(id))
1227+
.caller_bounds();
1228+
sig.input(i).and_then(|ty| {
1229+
ty_is_fn_once_param(cx.tcx, ty.skip_binder(), predicates).then_some(())
1230+
})
1231+
},
1232+
ExprKind::MethodCall(_, args, _) => {
1233+
let i = args.iter().position(|arg| arg.hir_id == id)?;
1234+
let id = cx.typeck_results().type_dependent_def_id(e.hir_id)?;
1235+
let ty = cx.tcx.fn_sig(id).skip_binder().inputs()[i];
1236+
ty_is_fn_once_param(cx.tcx, ty, cx.tcx.param_env(id).caller_bounds()).then_some(())
1237+
},
1238+
_ => None,
1239+
}
1240+
})
1241+
.is_some();
1242+
if !is_once {
1243+
return Some(e);
1244+
}
12071245
},
1208-
) => return Some(e),
1209-
Node::Expr(_) | Node::Stmt(_) | Node::Block(_) | Node::Local(_) | Node::Arm(_) => (),
1246+
ExprKind::Loop(..) => return Some(e),
1247+
_ => (),
1248+
},
1249+
Node::Stmt(_) | Node::Block(_) | Node::Local(_) | Node::Arm(_) => (),
12101250
_ => break,
12111251
}
12121252
}

clippy_utils/src/ty.rs

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ pub fn all_predicates_of(tcx: TyCtxt<'_>, id: DefId) -> impl Iterator<Item = &(P
501501
/// A signature for a function like type.
502502
#[derive(Clone, Copy)]
503503
pub enum ExprFnSig<'tcx> {
504-
Sig(Binder<'tcx, FnSig<'tcx>>),
504+
Sig(Binder<'tcx, FnSig<'tcx>>, Option<DefId>),
505505
Closure(Option<&'tcx FnDecl<'tcx>>, Binder<'tcx, FnSig<'tcx>>),
506506
Trait(Binder<'tcx, Ty<'tcx>>, Option<Binder<'tcx, Ty<'tcx>>>),
507507
}
@@ -510,7 +510,7 @@ impl<'tcx> ExprFnSig<'tcx> {
510510
/// bounds only for variadic functions, otherwise this will panic.
511511
pub fn input(self, i: usize) -> Option<Binder<'tcx, Ty<'tcx>>> {
512512
match self {
513-
Self::Sig(sig) => {
513+
Self::Sig(sig, _) => {
514514
if sig.c_variadic() {
515515
sig.inputs().map_bound(|inputs| inputs.get(i).copied()).transpose()
516516
} else {
@@ -527,7 +527,7 @@ impl<'tcx> ExprFnSig<'tcx> {
527527
/// functions, otherwise this will panic.
528528
pub fn input_with_hir(self, i: usize) -> Option<(Option<&'tcx hir::Ty<'tcx>>, Binder<'tcx, Ty<'tcx>>)> {
529529
match self {
530-
Self::Sig(sig) => {
530+
Self::Sig(sig, _) => {
531531
if sig.c_variadic() {
532532
sig.inputs()
533533
.map_bound(|inputs| inputs.get(i).copied())
@@ -549,16 +549,20 @@ impl<'tcx> ExprFnSig<'tcx> {
549549
/// specified.
550550
pub fn output(self) -> Option<Binder<'tcx, Ty<'tcx>>> {
551551
match self {
552-
Self::Sig(sig) | Self::Closure(_, sig) => Some(sig.output()),
552+
Self::Sig(sig, _) | Self::Closure(_, sig) => Some(sig.output()),
553553
Self::Trait(_, output) => output,
554554
}
555555
}
556+
557+
pub fn predicates_id(&self) -> Option<DefId> {
558+
if let ExprFnSig::Sig(_, id) = *self { id } else { None }
559+
}
556560
}
557561

558562
/// If the expression is function like, get the signature for it.
559563
pub fn expr_sig<'tcx>(cx: &LateContext<'tcx>, expr: &Expr<'_>) -> Option<ExprFnSig<'tcx>> {
560564
if let Res::Def(DefKind::Fn | DefKind::Ctor(_, CtorKind::Fn) | DefKind::AssocFn, id) = path_res(cx, expr) {
561-
Some(ExprFnSig::Sig(cx.tcx.fn_sig(id)))
565+
Some(ExprFnSig::Sig(cx.tcx.fn_sig(id), Some(id)))
562566
} else {
563567
ty_sig(cx, cx.typeck_results().expr_ty_adjusted(expr).peel_refs())
564568
}
@@ -575,9 +579,9 @@ fn ty_sig<'tcx>(cx: &LateContext<'tcx>, ty: Ty<'tcx>) -> Option<ExprFnSig<'tcx>>
575579
.and_then(|id| cx.tcx.hir().fn_decl_by_hir_id(cx.tcx.hir().local_def_id_to_hir_id(id)));
576580
Some(ExprFnSig::Closure(decl, subs.as_closure().sig()))
577581
},
578-
ty::FnDef(id, subs) => Some(ExprFnSig::Sig(cx.tcx.bound_fn_sig(id).subst(cx.tcx, subs))),
582+
ty::FnDef(id, subs) => Some(ExprFnSig::Sig(cx.tcx.bound_fn_sig(id).subst(cx.tcx, subs), Some(id))),
579583
ty::Opaque(id, _) => ty_sig(cx, cx.tcx.type_of(id)),
580-
ty::FnPtr(sig) => Some(ExprFnSig::Sig(sig)),
584+
ty::FnPtr(sig) => Some(ExprFnSig::Sig(sig, None)),
581585
ty::Dynamic(bounds, _) => {
582586
let lang_items = cx.tcx.lang_items();
583587
match bounds.principal() {
@@ -793,3 +797,33 @@ pub fn variant_of_res<'tcx>(cx: &LateContext<'tcx>, res: Res) -> Option<&'tcx Va
793797
_ => None,
794798
}
795799
}
800+
801+
/// Checks if the type is a type parameter implementing `FnOnce`, but not `FnMut`.
802+
pub fn ty_is_fn_once_param<'tcx>(tcx: TyCtxt<'_>, ty: Ty<'tcx>, predicates: &'tcx [Predicate<'_>]) -> bool {
803+
let ty::Param(ty) = *ty.kind() else {
804+
return false;
805+
};
806+
let lang = tcx.lang_items();
807+
let (Some(fn_once_id), Some(fn_mut_id), Some(fn_id))
808+
= (lang.fn_once_trait(), lang.fn_mut_trait(), lang.fn_trait())
809+
else {
810+
return false;
811+
};
812+
predicates
813+
.iter()
814+
.try_fold(false, |found, p| {
815+
if let PredicateKind::Trait(p) = p.kind().skip_binder()
816+
&& let ty::Param(self_ty) = p.trait_ref.self_ty().kind()
817+
&& ty.index == self_ty.index
818+
{
819+
// This should use `super_traits_of`, but that's a private function.
820+
if p.trait_ref.def_id == fn_once_id {
821+
return Some(true);
822+
} else if p.trait_ref.def_id == fn_mut_id || p.trait_ref.def_id == fn_id {
823+
return None;
824+
}
825+
}
826+
Some(found)
827+
})
828+
.unwrap_or(false)
829+
}

tests/ui/while_let_on_iterator.fixed

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
unused_mut,
88
dead_code,
99
clippy::equatable_if_let,
10-
clippy::manual_find
10+
clippy::manual_find,
11+
clippy::redundant_closure_call
1112
)]
1213

1314
fn base() {
@@ -259,7 +260,7 @@ fn issue1924() {
259260
fn f(&mut self) -> Option<u32> {
260261
// Used as a field.
261262
for i in self.0.by_ref() {
262-
if !(3..=7).contains(&i) {
263+
if !(3..8).contains(&i) {
263264
return Some(i);
264265
}
265266
}
@@ -403,6 +404,47 @@ fn issue_8113() {
403404
}
404405
}
405406

407+
fn fn_once_closure() {
408+
let mut it = 0..10;
409+
(|| {
410+
for x in it {
411+
if x % 2 == 0 {
412+
break;
413+
}
414+
}
415+
})();
416+
417+
fn f(_: impl FnOnce()) {}
418+
let mut it = 0..10;
419+
f(|| {
420+
for x in it {
421+
if x % 2 == 0 {
422+
break;
423+
}
424+
}
425+
});
426+
427+
fn f2(_: impl FnMut()) {}
428+
let mut it = 0..10;
429+
f2(|| {
430+
for x in it.by_ref() {
431+
if x % 2 == 0 {
432+
break;
433+
}
434+
}
435+
});
436+
437+
fn f3(_: fn()) {}
438+
f3(|| {
439+
let mut it = 0..10;
440+
for x in it {
441+
if x % 2 == 0 {
442+
break;
443+
}
444+
}
445+
})
446+
}
447+
406448
fn main() {
407449
let mut it = 0..20;
408450
for _ in it {

tests/ui/while_let_on_iterator.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
unused_mut,
88
dead_code,
99
clippy::equatable_if_let,
10-
clippy::manual_find
10+
clippy::manual_find,
11+
clippy::redundant_closure_call
1112
)]
1213

1314
fn base() {
@@ -259,7 +260,7 @@ fn issue1924() {
259260
fn f(&mut self) -> Option<u32> {
260261
// Used as a field.
261262
while let Some(i) = self.0.next() {
262-
if i < 3 || i > 7 {
263+
if !(3..8).contains(&i) {
263264
return Some(i);
264265
}
265266
}
@@ -403,6 +404,47 @@ fn issue_8113() {
403404
}
404405
}
405406

407+
fn fn_once_closure() {
408+
let mut it = 0..10;
409+
(|| {
410+
while let Some(x) = it.next() {
411+
if x % 2 == 0 {
412+
break;
413+
}
414+
}
415+
})();
416+
417+
fn f(_: impl FnOnce()) {}
418+
let mut it = 0..10;
419+
f(|| {
420+
while let Some(x) = it.next() {
421+
if x % 2 == 0 {
422+
break;
423+
}
424+
}
425+
});
426+
427+
fn f2(_: impl FnMut()) {}
428+
let mut it = 0..10;
429+
f2(|| {
430+
while let Some(x) = it.next() {
431+
if x % 2 == 0 {
432+
break;
433+
}
434+
}
435+
});
436+
437+
fn f3(_: fn()) {}
438+
f3(|| {
439+
let mut it = 0..10;
440+
while let Some(x) = it.next() {
441+
if x % 2 == 0 {
442+
break;
443+
}
444+
}
445+
})
446+
}
447+
406448
fn main() {
407449
let mut it = 0..20;
408450
while let Some(..) = it.next() {

0 commit comments

Comments
 (0)