
[InstCombine] Fold (op x, ({z,s}ext (icmp eq x, C))) to select #89020


Closed
wants to merge 2 commits

Conversation

goldsteinn
Contributor

@goldsteinn goldsteinn commented Apr 17, 2024

  • [InstCombine] Add tests for folding (op x, ({z,s}ext (icmp eq x, C))); NFC
  • [InstCombine] Fold (op x, ({z,s}ext (icmp eq x, C))) to select

This is a follow-up to #88579.
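The transform in the title can be illustrated with a quick Python model (hypothetical names, not part of the patch): for the `add`/`zext` case on i8, `add x, (zext (x == C))` is equivalent to `select (x == C), C + 1, x`, because the extended bool contributes 1 exactly when `x == C`.

```python
# Model of the fold on 8-bit integers:
#   (add x, (zext (icmp eq x, C)))
#     -> (select (icmp eq x, C), (add C, 1), (add x, 0))
MASK = 0xFF  # i8 wraparound

def original(x: int, c: int) -> int:
    # add x, (zext (x == c)): the zext is 1 only when x == c
    return (x + (1 if x == c else 0)) & MASK

def folded(x: int, c: int) -> int:
    # Select form: the true arm constant-folds to c + 1,
    # the false arm (add x, 0) simplifies to x.
    return (c + 1) & MASK if x == c else x

# Exhaustive check over all i8 values of x and C.
assert all(original(x, c) == folded(x, c)
           for x in range(256) for c in range(256))
```

The same reasoning extends to the other binary operators and intrinsics the patch handles, provided `(op x, 0)` simplifies.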

@llvmbot
Member

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: None (goldsteinn)

Changes
  • [InstCombine] Add tests for folding (op x, ({z,s}ext (icmp eq x, C))); NFC
  • [InstCombine] Fold (op x, ({z,s}ext (icmp eq x, C))) to select
  • [InstSimplify] Add basic simplification support for {s,u}shl_sat

Patch is 20.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89020.diff

12 Files Affected:

  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+16)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+16)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+9)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+3)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+3)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+9)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+83)
  • (modified) llvm/test/Transforms/InstCombine/apint-shift.ll (+1-8)
  • (added) llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll (+187)
  • (modified) llvm/test/Transforms/InstCombine/freeze-integer-intrinsics.ll (+2-4)
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48af..04cc1278006ba1 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1512,6 +1512,8 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::usub_with_overflow:
   case Intrinsic::smul_with_overflow:
   case Intrinsic::umul_with_overflow:
+  case Intrinsic::sshl_sat:
+  case Intrinsic::ushl_sat:
   case Intrinsic::sadd_sat:
   case Intrinsic::uadd_sat:
   case Intrinsic::ssub_sat:
@@ -2818,6 +2820,20 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
       };
       return ConstantStruct::get(cast<StructType>(Ty), Ops);
     }
+    case Intrinsic::sshl_sat:
+    case Intrinsic::ushl_sat:
+      // This is the same as for binary ops - poison propagates.
+      // TODO: Poison handling should be consolidated.
+      if (isa<PoisonValue>(Operands[0]) || isa<PoisonValue>(Operands[1]))
+        return PoisonValue::get(Ty);
+      if (!C0 && !C1)
+        return UndefValue::get(Ty);
+      if (!C0 || !C1)
+        return Constant::getNullValue(Ty);
+      if (IntrinsicID == Intrinsic::ushl_sat)
+        return ConstantInt::get(Ty, C0->ushl_sat(*C1));
+      else
+        return ConstantInt::get(Ty, C0->sshl_sat(*C1));
     case Intrinsic::uadd_sat:
     case Intrinsic::sadd_sat:
       // This is the same as for binary ops - poison propagates.
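The `APInt` saturating-shift helpers used in this hunk clamp on overflow rather than wrapping. A rough Python model of the i8 semantics (a sketch, not the actual APInt implementation; it assumes shift amounts below the bit width, since larger amounts are poison for these intrinsics at the IR level):

```python
BITS = 8
UMAX = (1 << BITS) - 1          # 255
SMAX = (1 << (BITS - 1)) - 1    # 127
SMIN = -(1 << (BITS - 1))       # -128

def ushl_sat(a: int, b: int) -> int:
    # Unsigned saturating shift: clamp to UINT_MAX if any set bit is shifted out.
    assert 0 <= a <= UMAX and 0 <= b < BITS
    r = (a << b) & UMAX
    return r if (r >> b) == a else UMAX

def sshl_sat(a: int, b: int) -> int:
    # Signed saturating shift: clamp to [INT_MIN, INT_MAX] on overflow.
    assert SMIN <= a <= SMAX and 0 <= b < BITS
    return max(SMIN, min(SMAX, a << b))
```

For example, `ushl_sat(0x80, 1)` loses the top bit and saturates to 255, while `sshl_sat(1, 7)` overflows the signed range and saturates to 127.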
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 8955de6375dec4..e6085d17e8f838 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6534,6 +6534,22 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
     if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1))
       return Constant::getNullValue(ReturnType);
     break;
+  case Intrinsic::ushl_sat:
+    // ushl_sat(0, X) -> 0
+    // ushl_sat(UINT_MAX, X) -> UINT_MAX
+    // ushl_sat(X, 0) -> X
+    if (match(Op0, m_Zero()) || match(Op0, m_AllOnes()) || match(Op1, m_Zero()))
+      return Op0;
+    break;
+  case Intrinsic::sshl_sat:
+    // sshl_sat(0, X) -> 0
+    // sshl_sat(INT_MAX, X) -> INT_MAX
+    // sshl_sat(INT_MIN, X) -> INT_MIN
+    // sshl_sat(X, 0) -> X
+    if (match(Op0, m_Zero()) || match(Op0, m_MaxSignedValue()) ||
+        match(Op0, m_SignMask()) || match(Op1, m_Zero()))
+      return Op0;
+    break;
   case Intrinsic::uadd_sat:
     // sat(MAX + X) -> MAX
     // sat(X + MAX) -> MAX
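Each identity listed in the comments above can be sanity-checked exhaustively with a small i8 model (hypothetical helpers, assuming in-range shift amounts):

```python
BITS, UMAX, SMAX, SMIN = 8, 255, 127, -128

def ushl_sat(a, b):
    # Unsigned saturating shift on i8.
    r = (a << b) & UMAX
    return r if (r >> b) == a else UMAX

def sshl_sat(a, b):
    # Signed saturating shift on i8.
    return max(SMIN, min(SMAX, a << b))

for sh in range(BITS):
    assert ushl_sat(0, sh) == 0          # ushl_sat(0, X) -> 0
    assert ushl_sat(UMAX, sh) == UMAX    # ushl_sat(UINT_MAX, X) -> UINT_MAX
    assert sshl_sat(0, sh) == 0          # sshl_sat(0, X) -> 0
    assert sshl_sat(SMAX, sh) == SMAX    # sshl_sat(INT_MAX, X) -> INT_MAX
    assert sshl_sat(SMIN, sh) == SMIN    # sshl_sat(INT_MIN, X) -> INT_MIN
for v in range(256):
    assert ushl_sat(v, 0) == v           # ushl_sat(X, 0) -> X
for v in range(SMIN, SMAX + 1):
    assert sshl_sat(v, 0) == v           # sshl_sat(X, 0) -> X
```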
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index c59b867b10e7d1..caf507133f8a1f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1458,6 +1458,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // (A*B)+(A*C) -> A*(B+C) etc
   if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
@@ -2076,6 +2079,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
   // If this is a 'B = x-(-A)', change to B = x+A.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index d311690be64f16..6df7bce5962a60 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2275,6 +2275,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // See if we can simplify any instructions used by the instruction whose sole
   // purpose is to compute bits we don't care about.
   if (SimplifyDemandedInstructionBits(I))
@@ -3444,6 +3447,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // See if we can simplify any instructions used by the instruction whose sole
   // purpose is to compute bits we don't care about.
   if (SimplifyDemandedInstructionBits(I))
@@ -4579,6 +4585,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *NewXor = foldXorToXor(I, Builder))
     return NewXor;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // (A&B)^(A&C) -> A&(B^C) etc
   if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 60e4be883f513b..c0d78f81400c67 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1462,6 +1462,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI);
   if (!II) return visitCallBase(CI);
 
+  if (Value *R = foldOpOfXWithXEqC(II, SQ.getWithInstruction(&CI)))
+    return replaceInstUsesWith(CI, R);
+
   // For atomic unordered mem intrinsics if len is not a positive or
   // not a multiple of element size then behavior is undefined.
   if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index b9ad3a74007929..7708231f4e3a10 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -754,6 +754,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned);
 
+  Value *foldOpOfXWithXEqC(Value *Op, const SimplifyQuery &SQ);
   bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock);
   void tryToSinkInstructionDbgValues(
       Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 48372381a0d1cd..d38e559bb2236c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -204,6 +204,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 95aa2119e2d88b..4a0dae63189030 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1020,6 +1020,9 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
     return V;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, Q))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   unsigned BitWidth = Ty->getScalarSizeInBits();
@@ -1256,6 +1259,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
   if (Instruction *R = commonShiftTransforms(I))
     return R;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   Value *X;
@@ -1591,6 +1597,9 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
   if (Instruction *R = commonShiftTransforms(I))
     return R;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   unsigned BitWidth = Ty->getScalarSizeInBits();
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5a144cc7378962..d936f79ff32735 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4734,6 +4734,89 @@ void InstCombinerImpl::tryToSinkInstructionDbgValues(
   }
 }
 
+// If we have:
+//  `(op X, (zext/sext (icmp eq X, C)))`
+// We can transform it to:
+//  `(select (icmp eq X, C), (op C, (zext/sext 1)), (op X, 0))`
+// We do so if the `zext/sext` is one use and `(op X, 0)` simplifies.
+Value *InstCombinerImpl::foldOpOfXWithXEqC(Value *Op, const SimplifyQuery &SQ) {
+  Value *Cond;
+  Constant *C, *ExtC;
+
+  // match `(op X, (zext/sext (icmp eq X, C)))` and see if `(op X, 0)`
+  // simplifies.
+  // If we match and simplify, store the `icmp` in `Cond`, `(zext/sext C)` in
+  // `ExtC`.
+  auto MatchXWithXEqC = [&](Value *Op0, Value *Op1) -> Value * {
+    if (match(Op0, m_OneUse(m_ZExtOrSExt(m_Value(Cond))))) {
+      ICmpInst::Predicate Pred;
+      if (!match(Cond, m_ICmp(Pred, m_Specific(Op1), m_ImmConstant(C))) ||
+          Pred != ICmpInst::ICMP_EQ)
+        return nullptr;
+
+      ExtC = isa<SExtInst>(Op0) ? ConstantInt::getAllOnesValue(C->getType())
+                                : ConstantInt::get(C->getType(), 1);
+      return simplifyWithOpReplaced(Op, Op0,
+                                    Constant::getNullValue(Op1->getType()), SQ,
+                                    /*AllowRefinement=*/true);
+    }
+    return nullptr;
+  };
+
+  Value *SimpleOp = nullptr, *ConstOp = nullptr;
+  if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
+    switch (BO->getOpcode()) {
+      // Potential TODO: For all of these, if Op1 is the compare, the compare
+      // must be true and we could replace Op0 with C (otherwise immediate UB).
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+      return nullptr;
+    default:
+      break;
+    }
+
+    // Try X is Op0
+    if ((SimpleOp = MatchXWithXEqC(BO->getOperand(0), BO->getOperand(1))))
+      ConstOp = Builder.CreateBinOp(BO->getOpcode(), ExtC, C);
+    // Try X is Op1
+    else if ((SimpleOp = MatchXWithXEqC(BO->getOperand(1), BO->getOperand(0))))
+      ConstOp = Builder.CreateBinOp(BO->getOpcode(), C, ExtC);
+  } else if (auto *II = dyn_cast<IntrinsicInst>(Op)) {
+    switch (II->getIntrinsicID()) {
+    default:
+      return nullptr;
+    case Intrinsic::sshl_sat:
+    case Intrinsic::ushl_sat:
+    case Intrinsic::umax:
+    case Intrinsic::umin:
+    case Intrinsic::smax:
+    case Intrinsic::smin:
+    case Intrinsic::uadd_sat:
+    case Intrinsic::usub_sat:
+    case Intrinsic::sadd_sat:
+    case Intrinsic::ssub_sat:
+      // Try X is Op0
+      if ((SimpleOp =
+               MatchXWithXEqC(II->getArgOperand(0), II->getArgOperand(1))))
+        ConstOp = Builder.CreateBinaryIntrinsic(II->getIntrinsicID(), ExtC, C);
+      // Try X is Op1
+      else if ((SimpleOp =
+                    MatchXWithXEqC(II->getArgOperand(1), II->getArgOperand(0))))
+        ConstOp = Builder.CreateBinaryIntrinsic(II->getIntrinsicID(), C, ExtC);
+      break;
+    }
+  }
+
+  assert((SimpleOp == nullptr) == (ConstOp == nullptr) &&
+         "Simplified Op and Constant Op are de-synced!");
+  if (SimpleOp == nullptr)
+    return nullptr;
+
+  return Builder.CreateSelect(Cond, ConstOp, SimpleOp);
+}
+
 void InstCombinerImpl::tryToSinkInstructionDbgVariableRecords(
     Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,
     BasicBlock *DestBlock,
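As a concrete sketch of what `foldOpOfXWithXEqC` produces (a Python model, not the actual C++ path), take the `shl` case with the extended compare on the left: `shl (zext (x == 3)), x` becomes `select (x == 3), 8, 0`, since the true arm constant-folds `shl 1, 3` and the false arm `shl 0, x` simplifies to `0`.

```python
MASK = 0xFF

def original(x: int) -> int:
    # shl (zext (icmp eq x, 3)), x -- well-defined only for shift amounts < 8
    ext = 1 if x == 3 else 0
    return (ext << x) & MASK

def folded(x: int) -> int:
    # select (icmp eq x, 3), (shl 1, 3), (shl 0, x):
    # true arm constant-folds to 8, false arm simplifies to 0.
    return 8 if x == 3 else 0

# Equivalent wherever the original is well-defined (shift amount < bit width);
# for x >= 8 the original is poison, so returning 0 is a legal refinement
# (hence AllowRefinement=true in simplifyWithOpReplaced).
assert all(original(x) == folded(x) for x in range(8))
```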
diff --git a/llvm/test/Transforms/InstCombine/apint-shift.ll b/llvm/test/Transforms/InstCombine/apint-shift.ll
index 05c3db70ce1ca9..f508939b733217 100644
--- a/llvm/test/Transforms/InstCombine/apint-shift.ll
+++ b/llvm/test/Transforms/InstCombine/apint-shift.ll
@@ -564,14 +564,7 @@ define i40 @test26(i40 %A) {
 ; https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=9880
 define i177 @ossfuzz_9880(i177 %X) {
 ; CHECK-LABEL: @ossfuzz_9880(
-; CHECK-NEXT:    [[A:%.*]] = alloca i177, align 8
-; CHECK-NEXT:    [[L1:%.*]] = load i177, ptr [[A]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i177 [[L1]], -1
-; CHECK-NEXT:    [[B5_NEG:%.*]] = sext i1 [[TMP1]] to i177
-; CHECK-NEXT:    [[B14:%.*]] = add i177 [[L1]], [[B5_NEG]]
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i177 [[B14]], -1
-; CHECK-NEXT:    [[B1:%.*]] = zext i1 [[TMP2]] to i177
-; CHECK-NEXT:    ret i177 [[B1]]
+; CHECK-NEXT:    ret i177 0
 ;
   %A = alloca i177
   %L1 = load i177, ptr %A
diff --git a/llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll b/llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll
new file mode 100644
index 00000000000000..84ae405a15d943
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll
@@ -0,0 +1,187 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare void @use.i8(i8)
+define i8 @fold_add_zext_eq_0(i8 %x) {
+; CHECK-LABEL: @fold_add_zext_eq_0(
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 1)
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 0
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = add i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_add_sext_eq_0(i8 %x) {
+; CHECK-LABEL: @fold_add_sext_eq_0(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 -1, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 0
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = add i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_add_zext_eq_0_fail_multiuse_exp(i8 %x) {
+; CHECK-LABEL: @fold_add_zext_eq_0_fail_multiuse_exp(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = add i8 [[X_EQ_EXT]], [[X]]
+; CHECK-NEXT:    call void @use.i8(i8 [[X_EQ_EXT]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 0
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = add i8 %x, %x_eq_ext
+  call void @use.i8(i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_mul_sext_eq_12(i8 %x) {
+; CHECK-LABEL: @fold_mul_sext_eq_12(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 12
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 -12, i8 0
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 12
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = mul i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_mul_sext_eq_12_fail_multiuse(i8 %x) {
+; CHECK-LABEL: @fold_mul_sext_eq_12_fail_multiuse(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 12
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = sext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = mul i8 [[X_EQ_EXT]], [[X]]
+; CHECK-NEXT:    call void @use.i8(i8 [[X_EQ_EXT]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 12
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = mul i8 %x, %x_eq_ext
+  call void @use.i8(i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_shl_zext_eq_3_rhs(i8 %x) {
+; CHECK-LABEL: @fold_shl_zext_eq_3_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 6, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = shl i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_shl_zext_eq_3_lhs(i8 %x) {
+; CHECK-LABEL: @fold_shl_zext_eq_3_lhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 8, i8 0
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = shl i8 %x_eq_ext, %x
+  ret i8 %r
+}
+
+define <2 x i8> @fold_lshr_sext_eq_15_5_lhs(<2 x i8> %x) {
+; CHECK-LABEL: @fold_lshr_sext_eq_15_5_lhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq <2 x i8> [[X:%.*]], <i8 15, i8 5>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[X_EQ]], <2 x i8> <i8 poison, i8 7>, <2 x i8> zeroinitializer
+; CHECK-NEXT:    ret <2 x i8> [[R]]
+;
+  %x_eq = icmp eq <2 x i8> %x, <i8 15, i8 5>
+  %x_eq_ext = sext <2 x i1> %x_eq to <2 x i8>
+  %r = lshr <2 x i8> %x_eq_ext, %x
+  ret <2 x i8> %r
+}
+
+define <2 x i8> @fold_lshr_sext_eq_15_poison_rhs(<2 x i8> %x) {
+; CHECK-LABEL: @fold_lshr_sext_eq_15_poison_rhs(
+; CHECK-NEXT:    ret <2 x i8> [[X:%.*]]
+;
+  %x_eq = icmp eq <2 x i8> %x, <i8 15, i8 poison>
+  %x_eq_ext = sext <2 x i1> %x_eq to <2 x i8>
+  %r = lshr <2 x i8> %x, %x_eq_ext
+  ret <2 x i8> %r
+}
+
+define i8 @fold_umax_zext_eq_9(i8 %x) {
+; CHECK-LABEL: @fold_umax_zext_eq_9(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 9
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 -1, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 9
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = call i8 @llvm.umax.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_sshl_sat_sext_eq_3_rhs(i8 %x) {
+; CHECK-LABEL: @fold_sshl_sat_sext_eq_3_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 127, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = call i8 @llvm.sshl.sat.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_ushl_sat_zext_eq_3_lhs(i8 %x) {
+; CHECK-LABEL: @fold_ushl_sat_zext_eq_3_lhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 8, i8 0
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = call i8 @llvm.ushl.sat.i8(i8 %x_eq_ext, i8 %x)
+  ret i8 %r
+}
+
+define i8 @fold_uadd_sat_zext_eq_3_rhs(i8 %x) {
+; CHECK-LABEL: @fold_uadd_sat_zext_eq_3_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 4, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = call i8 @llvm.uadd.sat.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_ssub_sat_sext_eq_99_lhs_fail(i8 %x) {
+; CHECK-LABEL: @fold_ssub_sat_sext_eq_99_lhs_fail(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 99
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = sext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[X_EQ_EXT]], i8 [[X]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 99
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = call i8 @llvm.ssub.sat.i8(i8 %x_eq_ext, i8 %x)
+  ret i8 %r
+}
+
+define i8 @fold_ssub_sat_zext_eq_99_rhs(i8 %x) {
+; CHECK-LABEL: @fold_ssub_sat_zext_eq_99_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 99
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 98, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 99
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = call i8 @llvm.ssub.sat.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
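The `fold_uadd_sat_zext_eq_3_rhs` expectation above can be checked numerically (Python sketch with hypothetical helper names):

```python
UMAX = 255

def uadd_sat(a: int, b: int) -> int:
    # Unsigned saturating add on i8.
    return min(UMAX, a + b)

def original(x: int) -> int:
    # uadd_sat(x, zext(x == 3))
    return uadd_sat(x, 1 if x == 3 else 0)

def folded(x: int) -> int:
    # select (x == 3), uadd_sat(3, 1), x  ==  select (x == 3), 4, x
    return 4 if x == 3 else x

assert all(original(x) == folded(x) for x in range(256))
```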
diff --git a/llvm/test/Transforms/InstCombine/freeze-integer-intrinsics.ll b/llvm/test/Transforms/InstCombine/freeze-integer-intrinsics.ll
index 105bd28fb052e8..99720339b69834 100644
--- a/llvm/test/Transforms/InstCombine/freeze-integer-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/freeze-integer-intrinsics.ll
@@ -396,8 +396,7 @@ define <2 x i32> @sshl_sat_v2i32_unsafe_constant_vector(<2 x i32> %arg0) {
 
 define <vscale x 2 x i32> @ushl_sat_v2i32_scalable_zeroinitializer(<vscale x 2 x i32> %arg0) {
 ; CHECK-LABEL: @ushl_sat_v2i32_scalable_zeroinitializer(
-; CHECK-NEXT:    [[CALL:%.*]] = call <vscale x 2 x i32> @llvm.ushl.sat.nxv2i32(<vscale x 2 x i32> [[ARG0:%.*]], <vscale x 2 x i32> zeroinitializer)
-; CHECK-NEXT:    [[FREEZE:%.*]] = freeze <vscale x...
[truncated]

@goldsteinn goldsteinn changed the title from "goldsteinn/op icmp ext x" to "[InstCombine] Fold (op x, ({z,s}ext (icmp eq x, C))) to select" on Apr 17, 2024

github-actions bot commented Apr 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Apr 17, 2024
// We can transform it to:
// `(select (icmp eq X, C), (op C, (zext/sext 1)), (op X, 0))`
// We do so if the `zext/sext` is one use and `(op X, 0)` simplifies.
Value *InstCombinerImpl::foldOpOfXWithXEqC(Value *Op, const SimplifyQuery &SQ) {
Member

I'd like to handle `add X, (zext/sext (icmp eq/ne X, C))` first. The current version looks like an over-generalization.

Member

If I remember correctly, similar patterns like `X * (X != 0) -> X` exist already. Can they be folded here too?

Contributor Author

The `X != C` case isn't the same. We can handle it if `(op x, 1)` simplifies, which is a lot rarer than `(op x, 0)`.

Contributor Author

I'd like to handle `add X, (zext/sext (icmp eq/ne X, C))` first. The current version looks like an over-generalization.

I can see the argument that the intrinsics complicate the situation a bit and can drop those if you feel strongly, but IMO the rest of the code is basically the same whether it's just `add` or any op.

Contributor

FWIW, I think the additional complexity is actually quite significant. Doing this for just add is something like:

  if (match(I, m_c_Add(m_Value(X),
                       m_ZExt(m_CombineAnd(
                           m_Value(Cond),
                           m_ICmp(Pred, m_Deferred(X), m_ImmConstant(C)))))) &&
      Pred == ICmpInst::ICMP_EQ) 
    return Builder.CreateSelect(
        Cond, Builder.CreateAdd(C, ConstantInt::get(C->getType(), 1)), X);

Generalizing it to other operations is a hundred lines of code here, plus a dozen call sites elsewhere.

Contributor Author

Okay, let me test which operators get hit. If it's just `add` I'll drop all this; otherwise I think it's worth the generalization.

Contributor Author

Okay, created #93840

Contributor

@nikic nikic left a comment

Why does this PR contain simplification support for shl_sat?!

@goldsteinn
Contributor Author

Why does this PR contain simplification support for shl_sat?!

I'll split.

@dtcxzyw
Member

dtcxzyw commented May 12, 2024

Reverse ping.

@goldsteinn
Contributor Author

Reverse ping.

Was waiting for #88579 to get in before rebasing this.

I can rebase/split shl stuff if you want it to proceed concurrently though.

`(op x, ({z,s}ext (icmp eq x, C)))` is either `(op C, ({z,s}ext 1))`
or `(op x, 0)`.

If both possibilities simplify (i.e. constant fold for the former, and
either constant fold or reduce to just `x` for the latter), fold to:

`(select (icmp eq x, C), (op C, ({z,s}ext 1)), (op x, 0))`.

This is easier to analyze and should get roughly the same or better
codegen (in most cases something like `({z,s}ext (icmp))` lowers
similarly to a `select`).
@goldsteinn goldsteinn force-pushed the goldsteinn/op-icmp-ext-x branch from 85c31f9 to 631fd95 on May 28, 2024 18:27
@goldsteinn
Contributor Author

@dtcxzyw can you re-run this? If we don't hit any non-add patterns I'll close this and post nikic's simplification.

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request May 28, 2024
@dtcxzyw
Member

dtcxzyw commented May 29, 2024

@dtcxzyw can you re-run this? If we don't hit any non-add patterns I'll close this and post nikic's simplification.

Confirmed that most of the cases are add patterns, except bench/spdlog/optimized/async.cpp.ll. I cannot tell which pattern is folded based on the IR diff.

goldsteinn added a commit to goldsteinn/llvm-project that referenced this pull request May 30, 2024
We can convert this to a select based on the `(icmp eq X, C)`, then
constant fold the addition: the true arm being `(add C, (sext/zext 1))`
and the false arm being `(add X, 0)`, e.g.

    - `(select (icmp eq X, C), (add C, (sext/zext 1)), (add X, 0))`.

This is essentially a specialization of the only case that seems to
actually show up from llvm#89020
goldsteinn added a commit to goldsteinn/llvm-project that referenced this pull request Jun 1, 2024
goldsteinn added a commit that referenced this pull request Jun 1, 2024
We can convert this to a select based on the `(icmp eq X, C)`, then
constant fold the addition: the true arm being `(add C, (sext/zext 1))`
and the false arm being `(add X, 0)`, e.g.

    - `(select (icmp eq X, C), (add C, (sext/zext 1)), (add X, 0))`.

This is essentially a specialization of the only case that seems to
actually show up from #89020

Closes #93840
@goldsteinn
Contributor Author

Closing because we pushed 0310f7f to handle the add case.

@goldsteinn goldsteinn closed this Jun 4, 2024