|
1 | 1 | use clippy_utils::diagnostics::span_lint_and_then;
|
2 | 2 | use clippy_utils::source::snippet;
|
3 | 3 | use clippy_utils::{path_to_local, search_same, SpanlessEq, SpanlessHash};
|
4 |
| -use rustc_hir::{Arm, Expr, HirId, HirIdMap, HirIdSet, Pat, PatKind}; |
| 4 | +use rustc_ast::ast::LitKind; |
| 5 | +use rustc_hir::def_id::DefId; |
| 6 | +use rustc_hir::{Arm, Expr, ExprKind, HirId, HirIdMap, HirIdSet, Pat, PatKind, RangeEnd}; |
5 | 7 | use rustc_lint::LateContext;
|
| 8 | +use rustc_span::Symbol; |
6 | 9 | use std::collections::hash_map::Entry;
|
7 | 10 |
|
8 | 11 | use super::MATCH_SAME_ARMS;
|
9 | 12 |
|
10 |
| -pub(crate) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) { |
| 13 | +pub(super) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) { |
11 | 14 | let hash = |&(_, arm): &(usize, &Arm<'_>)| -> u64 {
|
12 | 15 | let mut h = SpanlessHash::new(cx);
|
13 | 16 | h.hash_expr(arm.body);
|
14 | 17 | h.finish()
|
15 | 18 | };
|
16 | 19 |
|
| 20 | + let resolved_pats: Vec<_> = arms.iter().map(|a| ResolvedPat::from_pat(cx, a.pat)).collect(); |
| 21 | + |
| 22 | + // The furthast forwards a pattern can move without semantic changes |
| 23 | + let forwards_blocking_idxs: Vec<_> = resolved_pats |
| 24 | + .iter() |
| 25 | + .enumerate() |
| 26 | + .map(|(i, pat)| { |
| 27 | + resolved_pats[i + 1..] |
| 28 | + .iter() |
| 29 | + .enumerate() |
| 30 | + .find_map(|(j, other)| pat.can_also_match(other).then(|| i + 1 + j)) |
| 31 | + .unwrap_or(resolved_pats.len()) |
| 32 | + }) |
| 33 | + .collect(); |
| 34 | + |
| 35 | + // The furthast backwards a pattern can move without semantic changes |
| 36 | + let backwards_blocking_idxs: Vec<_> = resolved_pats |
| 37 | + .iter() |
| 38 | + .enumerate() |
| 39 | + .map(|(i, pat)| { |
| 40 | + resolved_pats[..i] |
| 41 | + .iter() |
| 42 | + .enumerate() |
| 43 | + .rev() |
| 44 | + .zip(forwards_blocking_idxs[..i].iter().copied().rev()) |
| 45 | + .skip_while(|&(_, forward_block)| forward_block > i) |
| 46 | + .find_map(|((j, other), forward_block)| (forward_block == i || pat.can_also_match(other)).then(|| j)) |
| 47 | + .unwrap_or(0) |
| 48 | + }) |
| 49 | + .collect(); |
| 50 | + |
17 | 51 | let eq = |&(lindex, lhs): &(usize, &Arm<'_>), &(rindex, rhs): &(usize, &Arm<'_>)| -> bool {
|
18 | 52 | let min_index = usize::min(lindex, rindex);
|
19 | 53 | let max_index = usize::max(lindex, rindex);
|
@@ -42,14 +76,16 @@ pub(crate) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) {
|
42 | 76 | }
|
43 | 77 | };
|
44 | 78 | // Arms with a guard are ignored, those can’t always be merged together
|
45 |
| - // This is also the case for arms in-between each there is an arm with a guard |
46 |
| - (min_index..=max_index).all(|index| arms[index].guard.is_none()) |
47 |
| - && SpanlessEq::new(cx) |
48 |
| - .expr_fallback(eq_fallback) |
49 |
| - .eq_expr(lhs.body, rhs.body) |
50 |
| - // these checks could be removed to allow unused bindings |
51 |
| - && bindings_eq(lhs.pat, local_map.keys().copied().collect()) |
52 |
| - && bindings_eq(rhs.pat, local_map.values().copied().collect()) |
| 79 | + // If both arms overlap with an arm in between then these can't be merged either. |
| 80 | + !(backwards_blocking_idxs[max_index] > min_index && forwards_blocking_idxs[min_index] < max_index) |
| 81 | + && lhs.guard.is_none() |
| 82 | + && rhs.guard.is_none() |
| 83 | + && SpanlessEq::new(cx) |
| 84 | + .expr_fallback(eq_fallback) |
| 85 | + .eq_expr(lhs.body, rhs.body) |
| 86 | + // these checks could be removed to allow unused bindings |
| 87 | + && bindings_eq(lhs.pat, local_map.keys().copied().collect()) |
| 88 | + && bindings_eq(rhs.pat, local_map.values().copied().collect()) |
53 | 89 | };
|
54 | 90 |
|
55 | 91 | let indexed_arms: Vec<(usize, &Arm<'_>)> = arms.iter().enumerate().collect();
|
@@ -92,6 +128,203 @@ pub(crate) fn check<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>]) {
|
92 | 128 | }
|
93 | 129 | }
|
94 | 130 |
|
| 131 | +#[derive(Debug)] |
| 132 | +enum ResolvedPat<'hir> { |
| 133 | + Wild, |
| 134 | + Struct(Option<DefId>, Vec<(Symbol, ResolvedPat<'hir>)>), |
| 135 | + Sequence(Option<DefId>, Vec<ResolvedPat<'hir>>, Option<usize>), |
| 136 | + Or(Vec<ResolvedPat<'hir>>), |
| 137 | + Path(Option<DefId>), |
| 138 | + LitStr(Symbol), |
| 139 | + LitBytes(&'hir [u8]), |
| 140 | + LitInt(u128), |
| 141 | + LitBool(bool), |
| 142 | + Range(PatRange), |
| 143 | +} |
| 144 | + |
| 145 | +#[derive(Debug)] |
| 146 | +struct PatRange { |
| 147 | + start: u128, |
| 148 | + end: u128, |
| 149 | + bounds: RangeEnd, |
| 150 | +} |
| 151 | +impl PatRange { |
| 152 | + fn contains(&self, x: u128) -> bool { |
| 153 | + x >= self.start |
| 154 | + && match self.bounds { |
| 155 | + RangeEnd::Included => x <= self.end, |
| 156 | + RangeEnd::Excluded => x < self.end, |
| 157 | + } |
| 158 | + } |
| 159 | + |
| 160 | + fn overlaps(&self, other: &Self) -> bool { |
| 161 | + !(self.is_empty() || other.is_empty()) |
| 162 | + && match self.bounds { |
| 163 | + RangeEnd::Included => self.end >= other.start, |
| 164 | + RangeEnd::Excluded => self.end > other.start, |
| 165 | + } |
| 166 | + && match other.bounds { |
| 167 | + RangeEnd::Included => self.start <= other.end, |
| 168 | + RangeEnd::Excluded => self.start < other.end, |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + fn is_empty(&self) -> bool { |
| 173 | + match self.bounds { |
| 174 | + RangeEnd::Included => false, |
| 175 | + RangeEnd::Excluded => self.start == self.end, |
| 176 | + } |
| 177 | + } |
| 178 | +} |
| 179 | + |
| 180 | +impl<'hir> ResolvedPat<'hir> { |
| 181 | + fn from_pat(cx: &LateContext<'_>, pat: &'hir Pat<'_>) -> Self { |
| 182 | + match pat.kind { |
| 183 | + PatKind::Wild | PatKind::Binding(.., None) => Self::Wild, |
| 184 | + PatKind::Binding(.., Some(pat)) | PatKind::Box(pat) | PatKind::Ref(pat, _) => Self::from_pat(cx, pat), |
| 185 | + PatKind::Struct(ref path, fields, _) => { |
| 186 | + let mut fields: Vec<_> = fields |
| 187 | + .iter() |
| 188 | + .map(|f| (f.ident.name, Self::from_pat(cx, f.pat))) |
| 189 | + .collect(); |
| 190 | + fields.sort_by_key(|&(name, _)| name); |
| 191 | + Self::Struct(cx.qpath_res(path, pat.hir_id).opt_def_id(), fields) |
| 192 | + }, |
| 193 | + PatKind::TupleStruct(ref path, pats, wild_idx) => Self::Sequence( |
| 194 | + cx.qpath_res(path, pat.hir_id).opt_def_id(), |
| 195 | + pats.iter().map(|pat| Self::from_pat(cx, pat)).collect(), |
| 196 | + wild_idx, |
| 197 | + ), |
| 198 | + PatKind::Or(pats) => Self::Or(pats.iter().map(|pat| Self::from_pat(cx, pat)).collect()), |
| 199 | + PatKind::Path(ref path) => Self::Path(cx.qpath_res(path, pat.hir_id).opt_def_id()), |
| 200 | + PatKind::Tuple(pats, wild_idx) => { |
| 201 | + Self::Sequence(None, pats.iter().map(|pat| Self::from_pat(cx, pat)).collect(), wild_idx) |
| 202 | + }, |
| 203 | + PatKind::Lit(e) => match &e.kind { |
| 204 | + ExprKind::Lit(lit) => match lit.node { |
| 205 | + LitKind::Str(sym, _) => Self::LitStr(sym), |
| 206 | + LitKind::ByteStr(ref bytes) => Self::LitBytes(&**bytes), |
| 207 | + LitKind::Byte(val) => Self::LitInt(val.into()), |
| 208 | + LitKind::Char(val) => Self::LitInt(val.into()), |
| 209 | + LitKind::Int(val, _) => Self::LitInt(val), |
| 210 | + LitKind::Bool(val) => Self::LitBool(val), |
| 211 | + LitKind::Float(..) | LitKind::Err(_) => Self::Wild, |
| 212 | + }, |
| 213 | + _ => Self::Wild, |
| 214 | + }, |
| 215 | + PatKind::Range(start, end, bounds) => { |
| 216 | + let start = match start { |
| 217 | + None => 0, |
| 218 | + Some(e) => match &e.kind { |
| 219 | + ExprKind::Lit(lit) => match lit.node { |
| 220 | + LitKind::Int(val, _) => val, |
| 221 | + LitKind::Char(val) => val.into(), |
| 222 | + LitKind::Byte(val) => val.into(), |
| 223 | + _ => return Self::Wild, |
| 224 | + }, |
| 225 | + _ => return Self::Wild, |
| 226 | + }, |
| 227 | + }; |
| 228 | + let (end, bounds) = match end { |
| 229 | + None => (u128::MAX, RangeEnd::Included), |
| 230 | + Some(e) => match &e.kind { |
| 231 | + ExprKind::Lit(lit) => match lit.node { |
| 232 | + LitKind::Int(val, _) => (val, bounds), |
| 233 | + LitKind::Char(val) => (val.into(), bounds), |
| 234 | + LitKind::Byte(val) => (val.into(), bounds), |
| 235 | + _ => return Self::Wild, |
| 236 | + }, |
| 237 | + _ => return Self::Wild, |
| 238 | + }, |
| 239 | + }; |
| 240 | + Self::Range(PatRange { start, end, bounds }) |
| 241 | + }, |
| 242 | + PatKind::Slice(pats, wild, pats2) => Self::Sequence( |
| 243 | + None, |
| 244 | + pats.iter() |
| 245 | + .chain(pats2.iter()) |
| 246 | + .map(|pat| Self::from_pat(cx, pat)) |
| 247 | + .collect(), |
| 248 | + wild.map(|_| pats.len()), |
| 249 | + ), |
| 250 | + } |
| 251 | + } |
| 252 | + |
| 253 | + /// Checks if two patterns overlap in the values they can match assuming they are for the same |
| 254 | + /// type. |
| 255 | + fn can_also_match(&self, other: &Self) -> bool { |
| 256 | + match (self, other) { |
| 257 | + (Self::Wild, _) | (_, Self::Wild) => true, |
| 258 | + (Self::Or(pats), other) | (other, Self::Or(pats)) => pats.iter().any(|pat| pat.can_also_match(other)), |
| 259 | + (Self::Struct(lpath, lfields), Self::Struct(rpath, rfields)) => { |
| 260 | + if lpath != rpath { |
| 261 | + return false; |
| 262 | + } |
| 263 | + let mut rfields = rfields.iter(); |
| 264 | + let mut rfield = match rfields.next() { |
| 265 | + Some(x) => x, |
| 266 | + None => return true, |
| 267 | + }; |
| 268 | + 'outer: for lfield in lfields { |
| 269 | + loop { |
| 270 | + if lfield.0 < rfield.0 { |
| 271 | + continue 'outer; |
| 272 | + } else if lfield.0 > rfield.0 { |
| 273 | + rfield = match rfields.next() { |
| 274 | + Some(x) => x, |
| 275 | + None => return true, |
| 276 | + }; |
| 277 | + } else if !lfield.1.can_also_match(&rfield.1) { |
| 278 | + return false; |
| 279 | + } else { |
| 280 | + rfield = match rfields.next() { |
| 281 | + Some(x) => x, |
| 282 | + None => return true, |
| 283 | + }; |
| 284 | + continue 'outer; |
| 285 | + } |
| 286 | + } |
| 287 | + } |
| 288 | + true |
| 289 | + }, |
| 290 | + (Self::Sequence(lpath, lpats, lwild_idx), Self::Sequence(rpath, rpats, rwild_idx)) => { |
| 291 | + if lpath != rpath { |
| 292 | + return false; |
| 293 | + } |
| 294 | + |
| 295 | + let (lpats_start, lpats_end) = lwild_idx |
| 296 | + .or(*rwild_idx) |
| 297 | + .map_or((&**lpats, [].as_slice()), |idx| lpats.split_at(idx)); |
| 298 | + let (rpats_start, rpats_end) = rwild_idx |
| 299 | + .or(*lwild_idx) |
| 300 | + .map_or((&**rpats, [].as_slice()), |idx| rpats.split_at(idx)); |
| 301 | + |
| 302 | + lpats_start |
| 303 | + .iter() |
| 304 | + .zip(rpats_start.iter()) |
| 305 | + .all(|(lpat, rpat)| lpat.can_also_match(rpat)) |
| 306 | + // `lpats_end` and `rpats_end` lengths may be disjointed, so start from the end and ignore any |
| 307 | + // extras. |
| 308 | + && lpats_end |
| 309 | + .iter() |
| 310 | + .rev() |
| 311 | + .zip(rpats_end.iter().rev()) |
| 312 | + .all(|(lpat, rpat)| lpat.can_also_match(rpat)) |
| 313 | + }, |
| 314 | + (Self::Path(x), Self::Path(y)) => x == y, |
| 315 | + (Self::LitStr(x), Self::LitStr(y)) => x == y, |
| 316 | + (Self::LitBytes(x), Self::LitBytes(y)) => x == y, |
| 317 | + (Self::LitInt(x), Self::LitInt(y)) => x == y, |
| 318 | + (Self::LitBool(x), Self::LitBool(y)) => x == y, |
| 319 | + (Self::Range(x), Self::Range(y)) => x.overlaps(y), |
| 320 | + (Self::Range(range), Self::LitInt(x)) | (Self::LitInt(x), Self::Range(range)) => range.contains(*x), |
| 321 | + |
| 322 | + // Todo: Lit* with Path, Range with Path, LitBytes with Sequence |
| 323 | + _ => true, |
| 324 | + } |
| 325 | + } |
| 326 | +} |
| 327 | + |
95 | 328 | fn pat_contains_local(pat: &Pat<'_>, id: HirId) -> bool {
|
96 | 329 | let mut result = false;
|
97 | 330 | pat.walk_short(|p| {
|
|
0 commit comments