Skip to content

Commit b37317b

Browse files
committed
Check if there are any overlapping patterns between equal arm bodies in match_same_arm
1 parent d23ddab commit b37317b

File tree

2 files changed

+256
-10
lines changed

2 files changed

+256
-10
lines changed

clippy_lints/src/matches/match_same_arms.rs

Lines changed: 243 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,53 @@
11
use clippy_utils::diagnostics::span_lint_and_then;
22
use clippy_utils::source::snippet;
33
use clippy_utils::{path_to_local, search_same, SpanlessEq, SpanlessHash};
4-
use rustc_hir::{Arm, Expr, HirId, HirIdMap, HirIdSet, Pat, PatKind};
4+
use rustc_ast::ast::LitKind;
5+
use rustc_hir::def_id::DefId;
6+
use rustc_hir::{Arm, Expr, ExprKind, HirId, HirIdMap, HirIdSet, Pat, PatKind, RangeEnd};
57
use rustc_lint::LateContext;
8+
use rustc_span::Symbol;
69
use std::collections::hash_map::Entry;
710

811
use super::MATCH_SAME_ARMS;
912

10-
pub(crate) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) {
13+
pub(super) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) {
1114
let hash = |&(_, arm): &(usize, &Arm<'_>)| -> u64 {
1215
let mut h = SpanlessHash::new(cx);
1316
h.hash_expr(arm.body);
1417
h.finish()
1518
};
1619

20+
let resolved_pats: Vec<_> = arms.iter().map(|a| ResolvedPat::from_pat(cx, a.pat)).collect();
21+
22+
// The furthast forwards a pattern can move without semantic changes
23+
let forwards_blocking_idxs: Vec<_> = resolved_pats
24+
.iter()
25+
.enumerate()
26+
.map(|(i, pat)| {
27+
resolved_pats[i + 1..]
28+
.iter()
29+
.enumerate()
30+
.find_map(|(j, other)| pat.can_also_match(other).then(|| i + 1 + j))
31+
.unwrap_or(resolved_pats.len())
32+
})
33+
.collect();
34+
35+
// The furthast backwards a pattern can move without semantic changes
36+
let backwards_blocking_idxs: Vec<_> = resolved_pats
37+
.iter()
38+
.enumerate()
39+
.map(|(i, pat)| {
40+
resolved_pats[..i]
41+
.iter()
42+
.enumerate()
43+
.rev()
44+
.zip(forwards_blocking_idxs[..i].iter().copied().rev())
45+
.skip_while(|&(_, forward_block)| forward_block > i)
46+
.find_map(|((j, other), forward_block)| (forward_block == i || pat.can_also_match(other)).then(|| j))
47+
.unwrap_or(0)
48+
})
49+
.collect();
50+
1751
let eq = |&(lindex, lhs): &(usize, &Arm<'_>), &(rindex, rhs): &(usize, &Arm<'_>)| -> bool {
1852
let min_index = usize::min(lindex, rindex);
1953
let max_index = usize::max(lindex, rindex);
@@ -42,14 +76,16 @@ pub(crate) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) {
4276
}
4377
};
4478
// Arms with a guard are ignored, those can’t always be merged together
45-
// This is also the case for arms in-between each there is an arm with a guard
46-
(min_index..=max_index).all(|index| arms[index].guard.is_none())
47-
&& SpanlessEq::new(cx)
48-
.expr_fallback(eq_fallback)
49-
.eq_expr(lhs.body, rhs.body)
50-
// these checks could be removed to allow unused bindings
51-
&& bindings_eq(lhs.pat, local_map.keys().copied().collect())
52-
&& bindings_eq(rhs.pat, local_map.values().copied().collect())
79+
// If both arms overlap with an arm in between then these can't be merged either.
80+
!(backwards_blocking_idxs[max_index] > min_index && forwards_blocking_idxs[min_index] < max_index)
81+
&& lhs.guard.is_none()
82+
&& rhs.guard.is_none()
83+
&& SpanlessEq::new(cx)
84+
.expr_fallback(eq_fallback)
85+
.eq_expr(lhs.body, rhs.body)
86+
// these checks could be removed to allow unused bindings
87+
&& bindings_eq(lhs.pat, local_map.keys().copied().collect())
88+
&& bindings_eq(rhs.pat, local_map.values().copied().collect())
5389
};
5490

5591
let indexed_arms: Vec<(usize, &Arm<'_>)> = arms.iter().enumerate().collect();
@@ -92,6 +128,203 @@ pub(crate) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) {
92128
}
93129
}
94130

131+
#[derive(Debug)]
132+
enum ResolvedPat<'hir> {
133+
Wild,
134+
Struct(Option<DefId>, Vec<(Symbol, ResolvedPat<'hir>)>),
135+
Sequence(Option<DefId>, Vec<ResolvedPat<'hir>>, Option<usize>),
136+
Or(Vec<ResolvedPat<'hir>>),
137+
Path(Option<DefId>),
138+
LitStr(Symbol),
139+
LitBytes(&'hir [u8]),
140+
LitInt(u128),
141+
LitBool(bool),
142+
Range(PatRange),
143+
}
144+
145+
#[derive(Debug)]
146+
struct PatRange {
147+
start: u128,
148+
end: u128,
149+
bounds: RangeEnd,
150+
}
151+
impl PatRange {
152+
fn contains(&self, x: u128) -> bool {
153+
x >= self.start
154+
&& match self.bounds {
155+
RangeEnd::Included => x <= self.end,
156+
RangeEnd::Excluded => x < self.end,
157+
}
158+
}
159+
160+
fn overlaps(&self, other: &Self) -> bool {
161+
!(self.is_empty() || other.is_empty())
162+
&& match self.bounds {
163+
RangeEnd::Included => self.end >= other.start,
164+
RangeEnd::Excluded => self.end > other.start,
165+
}
166+
&& match other.bounds {
167+
RangeEnd::Included => self.start <= other.end,
168+
RangeEnd::Excluded => self.start < other.end,
169+
}
170+
}
171+
172+
fn is_empty(&self) -> bool {
173+
match self.bounds {
174+
RangeEnd::Included => false,
175+
RangeEnd::Excluded => self.start == self.end,
176+
}
177+
}
178+
}
179+
180+
impl<'hir> ResolvedPat<'hir> {
181+
fn from_pat(cx: &LateContext<'_>, pat: &'hir Pat<'_>) -> Self {
182+
match pat.kind {
183+
PatKind::Wild | PatKind::Binding(.., None) => Self::Wild,
184+
PatKind::Binding(.., Some(pat)) | PatKind::Box(pat) | PatKind::Ref(pat, _) => Self::from_pat(cx, pat),
185+
PatKind::Struct(ref path, fields, _) => {
186+
let mut fields: Vec<_> = fields
187+
.iter()
188+
.map(|f| (f.ident.name, Self::from_pat(cx, f.pat)))
189+
.collect();
190+
fields.sort_by_key(|&(name, _)| name);
191+
Self::Struct(cx.qpath_res(path, pat.hir_id).opt_def_id(), fields)
192+
},
193+
PatKind::TupleStruct(ref path, pats, wild_idx) => Self::Sequence(
194+
cx.qpath_res(path, pat.hir_id).opt_def_id(),
195+
pats.iter().map(|pat| Self::from_pat(cx, pat)).collect(),
196+
wild_idx,
197+
),
198+
PatKind::Or(pats) => Self::Or(pats.iter().map(|pat| Self::from_pat(cx, pat)).collect()),
199+
PatKind::Path(ref path) => Self::Path(cx.qpath_res(path, pat.hir_id).opt_def_id()),
200+
PatKind::Tuple(pats, wild_idx) => {
201+
Self::Sequence(None, pats.iter().map(|pat| Self::from_pat(cx, pat)).collect(), wild_idx)
202+
},
203+
PatKind::Lit(e) => match &e.kind {
204+
ExprKind::Lit(lit) => match lit.node {
205+
LitKind::Str(sym, _) => Self::LitStr(sym),
206+
LitKind::ByteStr(ref bytes) => Self::LitBytes(&**bytes),
207+
LitKind::Byte(val) => Self::LitInt(val.into()),
208+
LitKind::Char(val) => Self::LitInt(val.into()),
209+
LitKind::Int(val, _) => Self::LitInt(val),
210+
LitKind::Bool(val) => Self::LitBool(val),
211+
LitKind::Float(..) | LitKind::Err(_) => Self::Wild,
212+
},
213+
_ => Self::Wild,
214+
},
215+
PatKind::Range(start, end, bounds) => {
216+
let start = match start {
217+
None => 0,
218+
Some(e) => match &e.kind {
219+
ExprKind::Lit(lit) => match lit.node {
220+
LitKind::Int(val, _) => val,
221+
LitKind::Char(val) => val.into(),
222+
LitKind::Byte(val) => val.into(),
223+
_ => return Self::Wild,
224+
},
225+
_ => return Self::Wild,
226+
},
227+
};
228+
let (end, bounds) = match end {
229+
None => (u128::MAX, RangeEnd::Included),
230+
Some(e) => match &e.kind {
231+
ExprKind::Lit(lit) => match lit.node {
232+
LitKind::Int(val, _) => (val, bounds),
233+
LitKind::Char(val) => (val.into(), bounds),
234+
LitKind::Byte(val) => (val.into(), bounds),
235+
_ => return Self::Wild,
236+
},
237+
_ => return Self::Wild,
238+
},
239+
};
240+
Self::Range(PatRange { start, end, bounds })
241+
},
242+
PatKind::Slice(pats, wild, pats2) => Self::Sequence(
243+
None,
244+
pats.iter()
245+
.chain(pats2.iter())
246+
.map(|pat| Self::from_pat(cx, pat))
247+
.collect(),
248+
wild.map(|_| pats.len()),
249+
),
250+
}
251+
}
252+
253+
/// Checks if two patterns overlap in the values they can match assuming they are for the same
254+
/// type.
255+
fn can_also_match(&self, other: &Self) -> bool {
256+
match (self, other) {
257+
(Self::Wild, _) | (_, Self::Wild) => true,
258+
(Self::Or(pats), other) | (other, Self::Or(pats)) => pats.iter().any(|pat| pat.can_also_match(other)),
259+
(Self::Struct(lpath, lfields), Self::Struct(rpath, rfields)) => {
260+
if lpath != rpath {
261+
return false;
262+
}
263+
let mut rfields = rfields.iter();
264+
let mut rfield = match rfields.next() {
265+
Some(x) => x,
266+
None => return true,
267+
};
268+
'outer: for lfield in lfields {
269+
loop {
270+
if lfield.0 < rfield.0 {
271+
continue 'outer;
272+
} else if lfield.0 > rfield.0 {
273+
rfield = match rfields.next() {
274+
Some(x) => x,
275+
None => return true,
276+
};
277+
} else if !lfield.1.can_also_match(&rfield.1) {
278+
return false;
279+
} else {
280+
rfield = match rfields.next() {
281+
Some(x) => x,
282+
None => return true,
283+
};
284+
continue 'outer;
285+
}
286+
}
287+
}
288+
true
289+
},
290+
(Self::Sequence(lpath, lpats, lwild_idx), Self::Sequence(rpath, rpats, rwild_idx)) => {
291+
if lpath != rpath {
292+
return false;
293+
}
294+
295+
let (lpats_start, lpats_end) = lwild_idx
296+
.or(*rwild_idx)
297+
.map_or((&**lpats, [].as_slice()), |idx| lpats.split_at(idx));
298+
let (rpats_start, rpats_end) = rwild_idx
299+
.or(*lwild_idx)
300+
.map_or((&**rpats, [].as_slice()), |idx| rpats.split_at(idx));
301+
302+
lpats_start
303+
.iter()
304+
.zip(rpats_start.iter())
305+
.all(|(lpat, rpat)| lpat.can_also_match(rpat))
306+
// `lpats_end` and `rpats_end` lengths may be disjointed, so start from the end and ignore any
307+
// extras.
308+
&& lpats_end
309+
.iter()
310+
.rev()
311+
.zip(rpats_end.iter().rev())
312+
.all(|(lpat, rpat)| lpat.can_also_match(rpat))
313+
},
314+
(Self::Path(x), Self::Path(y)) => x == y,
315+
(Self::LitStr(x), Self::LitStr(y)) => x == y,
316+
(Self::LitBytes(x), Self::LitBytes(y)) => x == y,
317+
(Self::LitInt(x), Self::LitInt(y)) => x == y,
318+
(Self::LitBool(x), Self::LitBool(y)) => x == y,
319+
(Self::Range(x), Self::Range(y)) => x.overlaps(y),
320+
(Self::Range(range), Self::LitInt(x)) | (Self::LitInt(x), Self::Range(range)) => range.contains(*x),
321+
322+
// Todo: Lit* with Path, Range with Path, LitBytes with Sequence
323+
_ => true,
324+
}
325+
}
326+
}
327+
95328
fn pat_contains_local(pat: &Pat<'_>, id: HirId) -> bool {
96329
let mut result = false;
97330
pat.walk_short(|p| {

tests/ui/match_same_arms2.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,17 @@ fn main() {
174174
Some(2) => 2,
175175
_ => 1,
176176
};
177+
178+
enum Foo {
179+
X(u32),
180+
Y(u32),
181+
Z(u32),
182+
}
183+
184+
let _ = match Foo::X(0) {
185+
Foo::X(0) => 1,
186+
Foo::X(_) | Foo::Y(_) | Foo::Z(0) => 2,
187+
Foo::Z(_) => 1,
188+
_ => 0,
189+
};
177190
}

0 commit comments

Comments
 (0)