Skip to content

Commit 0902904

Browse files
authored
[InstCombine] Fold max/min when incrementing/decrementing by 1 (#142466)
Add the following folds for integer min max folding in InstCombine: - (X > Y) ? X : (Y - 1) ==> MIN(X, Y - 1) - (X < Y) ? X : (Y + 1) ==> MAX(X, Y + 1) These are safe when overflow corresponding to the sign of the comparison is poison. (proof https://alive2.llvm.org/ce/z/oj5iiI). The most common of these patterns is likely the minimum case which occurs in some internal library code when clamping an integer index to a range (The maximum cases are included for completeness). Here is a simplified example: int clampToWidth(int idx, int width) { if (idx >= width) return width - 1; return idx; } https://cuda.godbolt.org/z/nhPzWrc3W
1 parent e74d834 commit 0902904

File tree

2 files changed

+303
-0
lines changed

2 files changed

+303
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,62 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
565565
return nullptr;
566566
}
567567

568+
/// Try to fold a select to a min/max intrinsic. Many cases are already handled
569+
/// by matchDecomposedSelectPattern but here we handle the cases where more
570+
/// extensive modification of the IR is required.
571+
static Value *foldSelectICmpMinMax(const ICmpInst *Cmp, Value *TVal,
572+
Value *FVal,
573+
InstCombiner::BuilderTy &Builder,
574+
const SimplifyQuery &SQ) {
575+
const Value *CmpLHS = Cmp->getOperand(0);
576+
const Value *CmpRHS = Cmp->getOperand(1);
577+
ICmpInst::Predicate Pred = Cmp->getPredicate();
578+
579+
// (X > Y) ? X : (Y - 1) ==> MIN(X, Y - 1)
580+
// (X < Y) ? X : (Y + 1) ==> MAX(X, Y + 1)
581+
// This transformation is valid when overflow corresponding to the sign of
582+
// the comparison is poison and we must drop the non-matching overflow flag.
583+
if (CmpRHS == TVal) {
584+
std::swap(CmpLHS, CmpRHS);
585+
Pred = CmpInst::getSwappedPredicate(Pred);
586+
}
587+
588+
// TODO: consider handling 'or disjoint' as well, though these would need to
589+
// be converted to 'add' instructions.
590+
if (!(CmpLHS == TVal && isa<Instruction>(FVal)))
591+
return nullptr;
592+
593+
if (Pred == CmpInst::ICMP_SGT &&
594+
match(FVal, m_NSWAdd(m_Specific(CmpRHS), m_One()))) {
595+
cast<Instruction>(FVal)->setHasNoUnsignedWrap(false);
596+
return Builder.CreateBinaryIntrinsic(Intrinsic::smax, TVal, FVal);
597+
}
598+
599+
if (Pred == CmpInst::ICMP_SLT &&
600+
match(FVal, m_NSWAdd(m_Specific(CmpRHS), m_AllOnes()))) {
601+
cast<Instruction>(FVal)->setHasNoUnsignedWrap(false);
602+
return Builder.CreateBinaryIntrinsic(Intrinsic::smin, TVal, FVal);
603+
}
604+
605+
if (Pred == CmpInst::ICMP_UGT &&
606+
match(FVal, m_NUWAdd(m_Specific(CmpRHS), m_One()))) {
607+
cast<Instruction>(FVal)->setHasNoSignedWrap(false);
608+
return Builder.CreateBinaryIntrinsic(Intrinsic::umax, TVal, FVal);
609+
}
610+
611+
// Note: We must use isKnownNonZero here because "sub nuw %x, 1" will be
612+
// canonicalized to "add %x, -1" discarding the nuw flag.
613+
if (Pred == CmpInst::ICMP_ULT &&
614+
match(FVal, m_Add(m_Specific(CmpRHS), m_AllOnes())) &&
615+
isKnownNonZero(CmpRHS, SQ)) {
616+
cast<Instruction>(FVal)->setHasNoSignedWrap(false);
617+
cast<Instruction>(FVal)->setHasNoUnsignedWrap(false);
618+
return Builder.CreateBinaryIntrinsic(Intrinsic::umin, TVal, FVal);
619+
}
620+
621+
return nullptr;
622+
}
623+
568624
/// We want to turn:
569625
/// (select (icmp eq (and X, Y), 0), (and (lshr X, Z), 1), 1)
570626
/// into:
@@ -1940,6 +1996,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
19401996
return &SI;
19411997
}
19421998

1999+
if (Value *V = foldSelectICmpMinMax(ICI, TrueVal, FalseVal, Builder, SQ))
2000+
return replaceInstUsesWith(SI, V);
2001+
19432002
if (Instruction *V =
19442003
foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder))
19452004
return V;

llvm/test/Transforms/InstCombine/minmax-fold.ll

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,3 +1598,247 @@ define <2 x i32> @test_umax_smax_vec_neg(<2 x i32> %x) {
15981598
%umax = call <2 x i32> @llvm.umax.v2i32(<2 x i32> %smax, <2 x i32> <i32 1, i32 10>)
15991599
ret <2 x i32> %umax
16001600
}
1601+
1602+
define i32 @test_smin_sub1_nsw(i32 %x, i32 %w) {
1603+
; CHECK-LABEL: @test_smin_sub1_nsw(
1604+
; CHECK-NEXT: [[SUB:%.*]] = add nsw i32 [[W:%.*]], -1
1605+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 [[SUB]])
1606+
; CHECK-NEXT: ret i32 [[R]]
1607+
;
1608+
%cmp = icmp slt i32 %x, %w
1609+
%sub = add nsw i32 %w, -1
1610+
%r = select i1 %cmp, i32 %x, i32 %sub
1611+
ret i32 %r
1612+
}
1613+
1614+
define i32 @test_smax_add1_nsw(i32 %x, i32 %w) {
1615+
; CHECK-LABEL: @test_smax_add1_nsw(
1616+
; CHECK-NEXT: [[X2:%.*]] = add nsw i32 [[W:%.*]], 1
1617+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 [[X2]])
1618+
; CHECK-NEXT: ret i32 [[R]]
1619+
;
1620+
%cmp = icmp sgt i32 %x, %w
1621+
%add = add nsw i32 %w, 1
1622+
%r = select i1 %cmp, i32 %x, i32 %add
1623+
ret i32 %r
1624+
}
1625+
1626+
define i32 @test_umax_add1_nuw(i32 %x, i32 %w) {
1627+
; CHECK-LABEL: @test_umax_add1_nuw(
1628+
; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[W:%.*]], 1
1629+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[ADD]])
1630+
; CHECK-NEXT: ret i32 [[R]]
1631+
;
1632+
%cmp = icmp ugt i32 %x, %w
1633+
%add = add nuw i32 %w, 1
1634+
%r = select i1 %cmp, i32 %x, i32 %add
1635+
ret i32 %r
1636+
}
1637+
1638+
define i32 @test_umin_sub1_nuw(i32 %x, i32 range(i32 1, 0) %w) {
1639+
; CHECK-LABEL: @test_umin_sub1_nuw(
1640+
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[W:%.*]], -1
1641+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.umin.i32(i32 [[X:%.*]], i32 [[SUB]])
1642+
; CHECK-NEXT: ret i32 [[R]]
1643+
;
1644+
%cmp = icmp ult i32 %x, %w
1645+
%sub = add i32 %w, -1
1646+
%r = select i1 %cmp, i32 %x, i32 %sub
1647+
ret i32 %r
1648+
}
1649+
1650+
define i32 @test_smin_sub1_nsw_swapped(i32 %x, i32 %w) {
1651+
; CHECK-LABEL: @test_smin_sub1_nsw_swapped(
1652+
; CHECK-NEXT: [[SUB:%.*]] = add nsw i32 [[W:%.*]], -1
1653+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 [[SUB]])
1654+
; CHECK-NEXT: ret i32 [[R]]
1655+
;
1656+
%cmp = icmp sgt i32 %w, %x
1657+
%sub = add nsw i32 %w, -1
1658+
%r = select i1 %cmp, i32 %x, i32 %sub
1659+
ret i32 %r
1660+
}
1661+
1662+
define i32 @test_smax_add1_nsw_swapped(i32 %x, i32 %w) {
1663+
; CHECK-LABEL: @test_smax_add1_nsw_swapped(
1664+
; CHECK-NEXT: [[X2:%.*]] = add nsw i32 [[W:%.*]], 1
1665+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 [[X2]])
1666+
; CHECK-NEXT: ret i32 [[R]]
1667+
;
1668+
%cmp = icmp slt i32 %w, %x
1669+
%add = add nsw i32 %w, 1
1670+
%r = select i1 %cmp, i32 %x, i32 %add
1671+
ret i32 %r
1672+
}
1673+
1674+
define i32 @test_umax_add1_nuw_swapped(i32 %x, i32 %w) {
1675+
; CHECK-LABEL: @test_umax_add1_nuw_swapped(
1676+
; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[W:%.*]], 1
1677+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[ADD]])
1678+
; CHECK-NEXT: ret i32 [[R]]
1679+
;
1680+
%cmp = icmp ult i32 %w, %x
1681+
%add = add nuw i32 %w, 1
1682+
%r = select i1 %cmp, i32 %x, i32 %add
1683+
ret i32 %r
1684+
}
1685+
1686+
define i32 @test_umin_sub1_nuw_swapped(i32 %x, i32 range(i32 1, 0) %w) {
1687+
; CHECK-LABEL: @test_umin_sub1_nuw_swapped(
1688+
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[W:%.*]], -1
1689+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.umin.i32(i32 [[X:%.*]], i32 [[SUB]])
1690+
; CHECK-NEXT: ret i32 [[R]]
1691+
;
1692+
%cmp = icmp ugt i32 %w, %x
1693+
%sub = add i32 %w, -1
1694+
%r = select i1 %cmp, i32 %x, i32 %sub
1695+
ret i32 %r
1696+
}
1697+
1698+
define <2 x i16> @test_smin_sub1_nsw_vec(<2 x i16> %x, <2 x i16> %w) {
1699+
; CHECK-LABEL: @test_smin_sub1_nsw_vec(
1700+
; CHECK-NEXT: [[SUB:%.*]] = add nsw <2 x i16> [[W:%.*]], splat (i16 -1)
1701+
; CHECK-NEXT: [[R:%.*]] = call <2 x i16> @llvm.smin.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[SUB]])
1702+
; CHECK-NEXT: ret <2 x i16> [[R]]
1703+
;
1704+
%cmp = icmp slt <2 x i16> %x, %w
1705+
%sub = add nsw <2 x i16> %w, splat (i16 -1)
1706+
%r = select <2 x i1> %cmp, <2 x i16> %x, <2 x i16> %sub
1707+
ret <2 x i16> %r
1708+
}
1709+
1710+
define <2 x i16> @test_smax_add1_nsw_vec(<2 x i16> %x, <2 x i16> %w) {
1711+
; CHECK-LABEL: @test_smax_add1_nsw_vec(
1712+
; CHECK-NEXT: [[ADD:%.*]] = add nsw <2 x i16> [[W:%.*]], splat (i16 1)
1713+
; CHECK-NEXT: [[R:%.*]] = call <2 x i16> @llvm.smax.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[ADD]])
1714+
; CHECK-NEXT: ret <2 x i16> [[R]]
1715+
;
1716+
%cmp = icmp sgt <2 x i16> %x, %w
1717+
%add = add nsw <2 x i16> %w, splat (i16 1)
1718+
%r = select <2 x i1> %cmp, <2 x i16> %x, <2 x i16> %add
1719+
ret <2 x i16> %r
1720+
}
1721+
1722+
define <2 x i16> @test_umax_add1_nuw_vec(<2 x i16> %x, <2 x i16> %w) {
1723+
; CHECK-LABEL: @test_umax_add1_nuw_vec(
1724+
; CHECK-NEXT: [[ADD:%.*]] = add nuw <2 x i16> [[W:%.*]], splat (i16 1)
1725+
; CHECK-NEXT: [[R:%.*]] = call <2 x i16> @llvm.umax.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[ADD]])
1726+
; CHECK-NEXT: ret <2 x i16> [[R]]
1727+
;
1728+
%cmp = icmp ugt <2 x i16> %x, %w
1729+
%add = add nuw <2 x i16> %w, splat (i16 1)
1730+
%r = select <2 x i1> %cmp, <2 x i16> %x, <2 x i16> %add
1731+
ret <2 x i16> %r
1732+
}
1733+
1734+
define <2 x i16> @test_umin_sub1_nuw_vec(<2 x i16> %x, <2 x i16> range(i16 1, 0) %w) {
1735+
; CHECK-LABEL: @test_umin_sub1_nuw_vec(
1736+
; CHECK-NEXT: [[SUB:%.*]] = add <2 x i16> [[W:%.*]], splat (i16 -1)
1737+
; CHECK-NEXT: [[R:%.*]] = call <2 x i16> @llvm.umin.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[SUB]])
1738+
; CHECK-NEXT: ret <2 x i16> [[R]]
1739+
;
1740+
%cmp = icmp ult <2 x i16> %x, %w
1741+
%sub = add <2 x i16> %w, splat (i16 -1)
1742+
%r = select <2 x i1> %cmp, <2 x i16> %x, <2 x i16> %sub
1743+
ret <2 x i16> %r
1744+
}
1745+
1746+
1747+
define i32 @test_smin_sub1_nsw_drop_flags(i32 %x, i32 %w) {
1748+
; CHECK-LABEL: @test_smin_sub1_nsw_drop_flags(
1749+
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[X:%.*]], [[W:%.*]]
1750+
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 -1
1751+
; CHECK-NEXT: ret i32 [[R]]
1752+
;
1753+
%cmp = icmp slt i32 %x, %w
1754+
%sub = add nsw nuw i32 %w, -1
1755+
%r = select i1 %cmp, i32 %x, i32 %sub
1756+
ret i32 %r
1757+
}
1758+
1759+
define i32 @test_smax_add1_nsw_drop_flags(i32 %x, i32 %w) {
1760+
; CHECK-LABEL: @test_smax_add1_nsw_drop_flags(
1761+
; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[W:%.*]], 1
1762+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 [[ADD]])
1763+
; CHECK-NEXT: ret i32 [[R]]
1764+
;
1765+
%cmp = icmp sgt i32 %x, %w
1766+
%add = add nsw nuw i32 %w, 1
1767+
%r = select i1 %cmp, i32 %x, i32 %add
1768+
ret i32 %r
1769+
}
1770+
1771+
define i32 @test_umax_add1_nuw_drop_flags(i32 %x, i32 %w) {
1772+
; CHECK-LABEL: @test_umax_add1_nuw_drop_flags(
1773+
; CHECK-NEXT: [[ADD:%.*]] = add nuw i32 [[W:%.*]], 1
1774+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[ADD]])
1775+
; CHECK-NEXT: ret i32 [[R]]
1776+
;
1777+
%cmp = icmp ugt i32 %x, %w
1778+
%add = add nuw nsw i32 %w, 1
1779+
%r = select i1 %cmp, i32 %x, i32 %add
1780+
ret i32 %r
1781+
}
1782+
1783+
define i32 @test_umin_sub1_nuw_drop_flags(i32 %x, i32 range(i32 1, 0) %w) {
1784+
; CHECK-LABEL: @test_umin_sub1_nuw_drop_flags(
1785+
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[W:%.*]], -1
1786+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.umin.i32(i32 [[X:%.*]], i32 [[SUB]])
1787+
; CHECK-NEXT: ret i32 [[R]]
1788+
;
1789+
%cmp = icmp ult i32 %x, %w
1790+
%sub = add nsw i32 %w, -1
1791+
%r = select i1 %cmp, i32 %x, i32 %sub
1792+
ret i32 %r
1793+
}
1794+
1795+
;; Confirm we don't crash on these cases.
1796+
define i32 @test_smin_or_neg1_nsw(i32 %x, i32 %w) {
1797+
; CHECK-LABEL: @test_smin_or_neg1_nsw(
1798+
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[X:%.*]], [[W:%.*]]
1799+
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 -1
1800+
; CHECK-NEXT: ret i32 [[R]]
1801+
;
1802+
%cmp = icmp slt i32 %x, %w
1803+
%sub = or disjoint i32 %w, -1
1804+
%r = select i1 %cmp, i32 %x, i32 %sub
1805+
ret i32 %r
1806+
}
1807+
1808+
define i32 @test_smax_or_1_nsw(i32 %x, i32 %w) {
1809+
; CHECK-LABEL: @test_smax_or_1_nsw(
1810+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], [[W:%.*]]
1811+
; CHECK-NEXT: [[ADD:%.*]] = or disjoint i32 [[W]], 1
1812+
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 [[ADD]]
1813+
; CHECK-NEXT: ret i32 [[R]]
1814+
;
1815+
%cmp = icmp sgt i32 %x, %w
1816+
%add = or disjoint i32 %w, 1
1817+
%r = select i1 %cmp, i32 %x, i32 %add
1818+
ret i32 %r
1819+
}
1820+
1821+
define i32 @test_umax_or_1_nuw(i32 %x, i32 %w) {
1822+
; CHECK-LABEL: @test_umax_or_1_nuw(
1823+
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], [[W:%.*]]
1824+
; CHECK-NEXT: [[ADD:%.*]] = or disjoint i32 [[W]], 1
1825+
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 [[ADD]]
1826+
; CHECK-NEXT: ret i32 [[R]]
1827+
;
1828+
%cmp = icmp ugt i32 %x, %w
1829+
%add = or disjoint i32 %w, 1
1830+
%r = select i1 %cmp, i32 %x, i32 %add
1831+
ret i32 %r
1832+
}
1833+
1834+
define i32 @test_umin_or_neg1_nuw(i32 %x, i32 range(i32 1, 0) %w) {
1835+
; CHECK-LABEL: @test_umin_or_neg1_nuw(
1836+
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], [[W:%.*]]
1837+
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 -1
1838+
; CHECK-NEXT: ret i32 [[R]]
1839+
;
1840+
%cmp = icmp ult i32 %x, %w
1841+
%sub = or disjoint i32 %w, -1
1842+
%r = select i1 %cmp, i32 %x, i32 %sub
1843+
ret i32 %r
1844+
}

0 commit comments

Comments
 (0)