
InstSimplify: support floating-point equivalences #115152


Merged: artagnon merged 4 commits into llvm:main from is-equiv-fp on Nov 15, 2024

Conversation

artagnon
Contributor

@artagnon artagnon commented Nov 6, 2024

Since cd16b07 (IR: introduce CmpInst::isEquivalence), there is now an isEquivalence routine in CmpInst that we can use to determine equivalence in simplifySelectWithICmpEq. Implement this, extending the code from integer-equalities to integer and floating-point equivalences.
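
As background for why the description speaks of "equivalences" rather than plain equalities, here is a minimal standalone C++ sketch (it does not use LLVM). It illustrates two cases where a floating-point comparison can succeed although the operands are not interchangeable: signed zeros, and NaN under an unordered predicate. The helper names `bitsOf`, `orderedEqual`, `unorderedOrEqual` and `sameBits` are made up for this illustration.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Reinterpret the bits of a float as an integer, for a bitwise identity check.
inline std::uint32_t bitsOf(float F) {
  std::uint32_t U;
  std::memcpy(&U, &F, sizeof U);
  return U;
}

// Models `fcmp oeq`: false whenever either operand is NaN.
inline bool orderedEqual(float A, float B) { return A == B; }

// Models `fcmp ueq`: true whenever either operand is NaN.
inline bool unorderedOrEqual(float A, float B) {
  return std::isnan(A) || std::isnan(B) || A == B;
}

// A conservative stand-in for "no operation can tell A and B apart".
inline bool sameBits(float A, float B) { return bitsOf(A) == bitsOf(B); }
```

With these helpers, `orderedEqual(0.0f, -0.0f)` holds while `sameBits(0.0f, -0.0f)` does not, and `unorderedOrEqual(NaN, 2.0f)` holds although the operands are clearly different. Comparing against a non-zero constant with an ordered predicate avoids both cases.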

@artagnon artagnon requested a review from nikic as a code owner November 6, 2024 12:10
@llvmbot llvmbot added llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Nov 6, 2024
@llvmbot
Member

llvmbot commented Nov 6, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Ramkumar Ramachandra (artagnon)

Changes

Since cd16b07 (IR: introduce CmpInst::isEquivalence), there is now an isEquivalence routine in CmpInst that we can use to determine equivalence in simplifySelectWithICmpEq. Implement this, extending the code from integer-equalities to integer and floating-point equivalences.


Full diff: https://github.com/llvm/llvm-project/pull/115152.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+28-32)
  • (added) llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll (+156)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2cb2612bf611e3..198707c5667c8c 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4617,10 +4617,10 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
 
 /// Try to simplify a select instruction when its condition operand is an
 /// integer equality comparison.
-static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS,
-                                       Value *TrueVal, Value *FalseVal,
-                                       const SimplifyQuery &Q,
-                                       unsigned MaxRecurse) {
+static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS,
+                                            Value *TrueVal, Value *FalseVal,
+                                            const SimplifyQuery &Q,
+                                            unsigned MaxRecurse) {
   if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(),
                              /* AllowRefinement */ false,
                              /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
@@ -4635,23 +4635,21 @@ static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS,
 
 /// Try to simplify a select instruction when its condition operand is an
 /// integer comparison.
-static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
-                                         Value *FalseVal,
-                                         const SimplifyQuery &Q,
-                                         unsigned MaxRecurse) {
+static Value *simplifySelectWithCmpCond(Value *CondVal, Value *TrueVal,
+                                        Value *FalseVal, const SimplifyQuery &Q,
+                                        unsigned MaxRecurse) {
   ICmpInst::Predicate Pred;
   Value *CmpLHS, *CmpRHS;
-  if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
+  if (!match(CondVal, m_Cmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
     return nullptr;
+  auto *CI = cast<CmpInst>(CondVal);
 
   if (Value *V = simplifyCmpSelOfMaxMin(CmpLHS, CmpRHS, Pred, TrueVal, FalseVal))
     return V;
 
-  // Canonicalize ne to eq predicate.
-  if (Pred == ICmpInst::ICMP_NE) {
-    Pred = ICmpInst::ICMP_EQ;
+  // Canonicalize the equivalence, of which equality is a subset.
+  if (CI->isEquivalence(/*Invert=*/true))
     std::swap(TrueVal, FalseVal);
-  }
 
   // Check for integer min/max with a limit constant:
   // X > MIN_INT ? X : MIN_INT --> X
@@ -4659,9 +4657,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   if (TrueVal->getType()->isIntOrIntVectorTy()) {
     Value *X, *Y;
     SelectPatternFlavor SPF =
-        matchDecomposedSelectPattern(cast<ICmpInst>(CondVal), TrueVal, FalseVal,
-                                     X, Y)
-            .Flavor;
+        matchDecomposedSelectPattern(CI, TrueVal, FalseVal, X, Y).Flavor;
     if (SelectPatternResult::isMinOrMax(SPF) && Pred == getMinMaxPred(SPF)) {
       APInt LimitC = getMinMaxLimit(getInverseMinMaxFlavor(SPF),
                                     X->getType()->getScalarSizeInBits());
@@ -4670,7 +4666,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     }
   }
 
-  if (Pred == ICmpInst::ICMP_EQ && match(CmpRHS, m_Zero())) {
+  if (CI->isEquality() && match(CmpRHS, m_Zero())) {
     Value *X;
     const APInt *Y;
     if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y))))
@@ -4698,7 +4694,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     // (ShAmt == 0) ? X : fshl(X, X, ShAmt) --> fshl(X, X, ShAmt)
     // (ShAmt == 0) ? X : fshr(X, X, ShAmt) --> fshr(X, X, ShAmt)
     if (match(FalseVal, isRotate) && TrueVal == X && CmpLHS == ShAmt &&
-        Pred == ICmpInst::ICMP_EQ)
+        CI->isEquality())
       return FalseVal;
 
     // X == 0 ? abs(X) : -abs(X) --> -abs(X)
@@ -4720,12 +4716,12 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   // If we have a scalar equality comparison, then we know the value in one of
   // the arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
-  if (Pred == ICmpInst::ICMP_EQ) {
-    if (Value *V = simplifySelectWithICmpEq(CmpLHS, CmpRHS, TrueVal, FalseVal,
-                                            Q, MaxRecurse))
+  if (CI->isEquivalence() || CI->isEquivalence(/*Invert=*/true)) {
+    if (Value *V = simplifySelectWithEquivalence(CmpLHS, CmpRHS, TrueVal,
+                                                 FalseVal, Q, MaxRecurse))
       return V;
-    if (Value *V = simplifySelectWithICmpEq(CmpRHS, CmpLHS, TrueVal, FalseVal,
-                                            Q, MaxRecurse))
+    if (Value *V = simplifySelectWithEquivalence(CmpRHS, CmpLHS, TrueVal,
+                                                 FalseVal, Q, MaxRecurse))
       return V;
 
     Value *X;
@@ -4734,11 +4730,11 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
         match(CmpRHS, m_Zero())) {
       // (X | Y) == 0 implies X == 0 and Y == 0.
-      if (Value *V = simplifySelectWithICmpEq(X, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
-      if (Value *V = simplifySelectWithICmpEq(Y, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
     }
 
@@ -4746,11 +4742,11 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) &&
         match(CmpRHS, m_AllOnes())) {
       // (X & Y) == -1 implies X == -1 and Y == -1.
-      if (Value *V = simplifySelectWithICmpEq(X, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
-      if (Value *V = simplifySelectWithICmpEq(Y, CmpRHS, TrueVal, FalseVal, Q,
-                                              MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
+                                                   Q, MaxRecurse))
         return V;
     }
   }
@@ -4952,7 +4948,7 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
   }
 
   if (Value *V =
-          simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
+          simplifySelectWithCmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
     return V;
 
   if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal, Q))
diff --git a/llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll b/llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll
new file mode 100644
index 00000000000000..a59139246b00a6
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/select-equivalence-fp.ll
@@ -0,0 +1,156 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
+
+define float @select_fcmp_fsub_oeq(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_oeq(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 0.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_oeq_zero(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_oeq_zero(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp oeq float [[Y:%.*]], 0.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float [[FADD]], float 2.000000e+00
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp oeq float %y, 0.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 2.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_ueq(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_ueq(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp ueq float [[Y:%.*]], 2.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float [[FADD]], float 0.000000e+00
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp ueq float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 0.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_ueq_nnan(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_ueq_nnan(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp nnan ueq float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 0.
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_une(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_une(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp une float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 0., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_une_zero(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_une_zero(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp une float [[Y:%.*]], 0.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float 2.000000e+00, float [[FADD]]
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp une float %y, 0.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 2., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_one(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_one(
+; CHECK-NEXT:    [[FCMP:%.*]] = fcmp one float [[Y:%.*]], 2.000000e+00
+; CHECK-NEXT:    [[FADD:%.*]] = fsub float [[Y]], 2.000000e+00
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[FCMP]], float 0.000000e+00, float [[FADD]]
+; CHECK-NEXT:    ret float [[SEL]]
+;
+  %fcmp = fcmp one float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 0., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fsub_one_nnan(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fsub_one_nnan(
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %fcmp = fcmp nnan one float %y, 2.
+  %fadd = fsub float %y, 2.
+  %sel = select i1 %fcmp, float 0., float %fadd
+  ret float %sel
+}
+
+define float @select_fcmp_fadd(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fadd(
+; CHECK-NEXT:    ret float 4.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 2.
+  %fadd = fadd float %y, 2.
+  %sel = select i1 %fcmp, float %fadd, float 4.
+  ret float %sel
+}
+
+define <2 x float> @select_fcmp_fadd_vec(<2 x float> %x, <2 x float> %y) {
+; CHECK-LABEL: @select_fcmp_fadd_vec(
+; CHECK-NEXT:    ret <2 x float> <float 4.000000e+00, float 4.000000e+00>
+;
+  %fcmp = fcmp oeq <2 x float> %y, <float 2., float 2.>
+  %fadd = fadd <2 x float> %y, <float 2., float 2.>
+  %sel = select <2 x i1> %fcmp, <2 x float> %fadd, <2 x float> <float 4., float 4.>
+  ret <2 x float> %sel
+}
+
+
+define float @select_fcmp_fdiv(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_fdiv(
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 2.
+  %fdiv = fdiv float %y, 2.
+  %sel = select i1 %fcmp, float %fdiv, float 1.
+  ret float %sel
+}
+
+define float @select_fcmp_frem(float %x, float %y) {
+; CHECK-LABEL: @select_fcmp_frem(
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %fcmp = fcmp oeq float %y, 3.
+  %frem = frem float %y, 2.
+  %sel = select i1 %fcmp, float %frem, float 1.
+  ret float %sel
+}
+
+define <2 x float> @select_fcmp_insertelement(<2 x float> %x, <2 x float> %y) {
+; CHECK-LABEL: @select_fcmp_insertelement(
+; CHECK-NEXT:    ret <2 x float> <float 4.000000e+00, float 2.000000e+00>
+;
+  %fcmp = fcmp oeq <2 x float> %y, <float 2., float 2.>
+  %insert = insertelement <2 x float> %y, float 4., i64 0
+  %sel = select <2 x i1> %fcmp, <2 x float> %insert, <2 x float> <float 4., float 2.>
+  ret <2 x float> %sel
+}
+
+define <4 x float> @select_fcmp_shufflevector_select(<4 x float> %x, <4 x float> %y) {
+; CHECK-LABEL: @select_fcmp_shufflevector_select(
+; CHECK-NEXT:    ret <4 x float> <float poison, float 2.000000e+00, float poison, float 2.000000e+00>
+;
+  %fcmp = fcmp oeq <4 x float> %y, <float 2., float 2., float 2., float 2.>
+  %shuffle = shufflevector <4 x float> %y, <4 x float> poison, <4 x i32> <i32 4, i32 1, i32 6, i32 3>
+  %sel = select <4 x i1> %fcmp, <4 x float> %shuffle, <4 x float> <float poison, float 2., float poison, float 2.>
+  ret <4 x float> %sel
+}
+
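
The expected results in the test file above suggest a simple rule for when a comparison counts as an equivalence. The following is a hypothetical toy model in standalone C++, not LLVM's actual `CmpInst::isEquivalence` implementation; the names `Pred`, `ToyCmp`, `inversePred` and `isEquivalenceModel` are invented here, and the rule is only inferred from which tests fold.

```cpp
// Toy predicate set; only the cases exercised by the tests are modelled.
enum class Pred { ICmpEq, ICmpNe, FCmpOeq, FCmpUne, FCmpUeq, FCmpOne, Other };

// Assumed summary of a comparison, for illustration only.
struct ToyCmp {
  Pred P;
  bool NoNaNs;                 // models the `nnan` flag on the fcmp
  bool HasNonZeroConstOperand; // models "compared against a non-zero constant"
};

// Logical negation of a predicate (oeq <-> une, ueq <-> one, eq <-> ne).
inline Pred inversePred(Pred P) {
  switch (P) {
  case Pred::ICmpEq:  return Pred::ICmpNe;
  case Pred::ICmpNe:  return Pred::ICmpEq;
  case Pred::FCmpOeq: return Pred::FCmpUne;
  case Pred::FCmpUne: return Pred::FCmpOeq;
  case Pred::FCmpUeq: return Pred::FCmpOne;
  case Pred::FCmpOne: return Pred::FCmpUeq;
  default:            return Pred::Other;
  }
}

// True if a successful comparison implies the operands are interchangeable.
inline bool isEquivalenceModel(const ToyCmp &C, bool Invert = false) {
  const Pred P = Invert ? inversePred(C.P) : C.P;
  if (P == Pred::ICmpEq)
    return true; // integer equality is always an equivalence
  if (P == Pred::FCmpUeq && !C.NoNaNs)
    return false; // ueq is also true for NaN operands
  if (P == Pred::FCmpOeq || P == Pred::FCmpUeq)
    return C.HasNonZeroConstOperand; // rules out +0.0 == -0.0
  return false;
}
```

Under this model, `oeq` against `2.0` is an equivalence, `oeq` against `0.0` is not, and `une` or `nnan one` become equivalences once inverted, which matches the folds the tests expect.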

@llvmbot
Member

llvmbot commented Nov 6, 2024

@llvm/pr-subscribers-llvm-analysis


Contributor

@nikic nikic left a comment


I think this is missing tests for the non-refining code path. Try something along the lines of folding (x == 1) ? y : (x * y) to x * y.

Related to that, I think we need to be extra careful about derefinement here. For integers, the only derefinement we really have to worry about is related to undef and poison, which the existing code already handles. But for FP we also have to worry about non-determinism relating to libcalls and NaN behavior.

For example, if we have something like x == C ? sin(C) : sin(x), then we should generally not fold that to sin(x), because we cannot guarantee that sin(x) at runtime folds to the same sin(C) that the host libm produced.

We should be able to avoid this by passing the non-determinism flag to the constant folding API in simplifyWithOpReplaced, for the non-refining case.
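
For the integer case, the fold suggested above, `(x == 1) ? y : (x * y)` to `x * y`, can be checked exhaustively over a small type. The sketch below is plain C++ with invented helper names; it ignores undef and poison, which exist only at the IR level, so it only illustrates the value-level argument.

```cpp
#include <cstdint>

// The select as written: (x == 1) ? y : (x * y).
inline std::int32_t selectForm(std::int8_t X, std::int8_t Y) {
  return X == 1 ? Y : X * Y; // operands promote to int, so no overflow
}

// The proposed replacement: x * y.
inline std::int32_t foldedForm(std::int8_t X, std::int8_t Y) { return X * Y; }

// Compare both forms over every pair of 8-bit inputs.
inline bool foldHoldsForAllInt8() {
  for (int X = -128; X <= 127; ++X)
    for (int Y = -128; Y <= 127; ++Y) {
      const auto X8 = static_cast<std::int8_t>(X);
      const auto Y8 = static_cast<std::int8_t>(Y);
      if (selectForm(X8, Y8) != foldedForm(X8, Y8))
        return false;
    }
  return true;
}
```

The two forms agree because the only input on which the select takes the `y` arm is `x == 1`, where `x * y` is `y` as well.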

In preparation to extend select folding to include floating-points
using CmpInst::isEquivalence, cover it with tests first.
Since cd16b07 (IR: introduce CmpInst::isEquivalence), there is now an
isEquivalence routine in CmpInst that we can use to determine
equivalence in simplifySelectWithICmpEq. Implement this, extending
the code from integer-equalities to integer and floating-point
equivalences.
@artagnon
Contributor Author

For example, if we have something like x == C ? sin(C) : sin(x), then we should generally not fold that to sin(x), because we cannot guarantee that sin(x) at runtime folds to the same sin(C) that the host libm produced.

We should be able to avoid this by passing the non-determinism flag to the constant folding API in simplifyWithOpReplaced, for the non-refining case.

Thanks for catching this subtle point. I do think x == C ? sin(x) : sin(C) should fold, as I've demonstrated in the test case.

Contributor

@nikic nikic left a comment


LGTM

@artagnon artagnon merged commit 94eebf7 into llvm:main Nov 15, 2024
8 checks passed
@artagnon artagnon deleted the is-equiv-fp branch November 15, 2024 20:06
@nikic
Contributor

nikic commented Dec 16, 2024

Alive reports a verification failure for one of the tests: https://alive2.llvm.org/ce/z/eHm-rS

----------------------------------------
define half @src(half %x, half %y) {
#0:
  %fcmp = fcmp oeq half %x, 0x3c00
  %fmul = fmul half %y, %x
  %sel = select i1 %fcmp, half %y, half %fmul
  ret half %sel
}
=>
define half @tgt(half %x, half %y) {
#0:
  %fmul = fmul half %y, %x
  ret half %fmul
}
Transformation doesn't verify!

ERROR: Value mismatch

Example:
half %x = #x3c00 (1)
half %y = #xfc80 (SNaN)

Source:
i1 %fcmp = #x1 (1)
half %fmul = #x7e00 (QNaN)
half %sel = #xfc80 (SNaN)

Target:
half %fmul = #x7c80 (SNaN)
Source value: #xfc80 (SNaN)
Target value: #x7c80 (SNaN)

The problem is that the fmul may return a different NaN value than just passing the value through.
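
The same effect can be sketched outside of Alive2 with single-precision floats in standalone C++. This is only an illustration under assumptions: the helper names are invented, `volatile` is used to discourage the compiler from folding `y * 1.0` itself, and the bits of the product are platform dependent, so the sketch does not claim a particular result for them.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

inline std::uint32_t toBits(float F) {
  std::uint32_t U;
  std::memcpy(&U, &F, sizeof U);
  return U;
}

inline float fromBits(std::uint32_t U) {
  float F;
  std::memcpy(&F, &U, sizeof F);
  return F;
}

// A negative signaling NaN, analogous to #xfc80 in the half example.
constexpr std::uint32_t SNaNBits = 0xFFA00000u;

// Source pattern: (x == 1.0) ? y : y * x. The select passes y through as is.
inline float selectForm(float X, float Y) {
  volatile float Product = Y * X; // volatile discourages folding y * 1.0
  return X == 1.0f ? Y : Product;
}

// Target pattern: y * x, unconditionally.
inline float foldedForm(float X, float Y) {
  volatile float VX = X;
  volatile float Product = Y * VX;
  return Product;
}
```

Calling `selectForm(1.0f, fromBits(SNaNBits))` is expected to return the input bit pattern unchanged on common 64-bit targets, whereas `foldedForm` returns whatever NaN the multiplication produces, which may differ in its quiet bit or sign.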

cc @nunoplopes

@nunoplopes
Member

Yep, the optimization needs the 'nnan' flag. @jcranmer-intel to confirm.

@nikic
Contributor

nikic commented Dec 16, 2024

I was wondering why this gets folded at all, given that this is the AllowRefinement=false case. I think the reason is this code:

    // id op x -> x, x op id -> x
    if (NewOps[0] == ConstantExpr::getBinOpIdentity(Opcode, I->getType()))
      return NewOps[1];
    if (NewOps[1] == ConstantExpr::getBinOpIdentity(Opcode, I->getType(),
                                                    /* RHS */ true))
      return NewOps[0];

We assume that folding x op id to x is non-refining, which is not strictly true for FP.

@nikic
Contributor

nikic commented Dec 16, 2024

I've put up #120098 to fix this.

nikic added a commit that referenced this pull request Dec 17, 2024
If x is NaN, then fmul (x, 1) may produce a different NaN value.

Our float semantics explicitly permit folding fmul (x, 1) to x, but we
can't do this when we're replacing a select input, as selects are
supposed to preserve the exact bitwise value.

Fixes #115152 (comment).