Skip to content

[InstCombine] Missed optimization for select a%2==0, (a/2*2)*(a/2*2), 0 #92658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ void computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownBits &Known);
void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
unsigned Depth, const SimplifyQuery &Q);

void computeKnownBitsFromCond(const Value *V, Value *Cond, KnownBits &Known,
unsigned Depth, const SimplifyQuery &SQ,
bool Invert);

/// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
const KnownBits &KnownLHS,
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,13 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
return llvm::computeKnownBits(V, Depth, SQ.getWithInstruction(CxtI));
}

void computeKnownBitsFromCond(const Value *V, Value *Cmp, KnownBits &Known,
unsigned Depth, const Instruction *CxtI,
bool Invert) const {
llvm::computeKnownBitsFromCond(V, Cmp, Known, Depth,
SQ.getWithInstruction(CxtI), Invert);
}

bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
unsigned Depth = 0,
const Instruction *CxtI = nullptr) {
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
}

static void computeKnownBitsFromCond(const Value *V, Value *Cond,
KnownBits &Known, unsigned Depth,
const SimplifyQuery &SQ, bool Invert) {
void llvm::computeKnownBitsFromCond(const Value *V, Value *Cond,
KnownBits &Known, unsigned Depth,
const SimplifyQuery &SQ, bool Invert) {
Value *A, *B;
if (Depth < MaxAnalysisRecursionDepth &&
match(Cond, m_LogicalOp(m_Value(A), m_Value(B)))) {
Expand Down
63 changes: 63 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,62 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}

/// Attempts to fold (AND %A constant) --> %A
/// if all bits that are zero in the negated constant
/// are also zero in A's known zero bits.
static Value *foldAndMaskPattern(Value *V, Value *Cmp, SelectInst &SI,
InstCombinerImpl &IC, unsigned Depth = 0) {

Value *A;
const APInt *MaskedConstant;

if (match(V, m_And(m_Value(A), m_APInt(MaskedConstant))) &&
isGuaranteedNotToBeUndef(A)) {
KnownBits Known = IC.computeKnownBits(A, 0, &SI);
IC.computeKnownBitsFromCond(A, Cmp, Known, 0, &SI, false);
if ((~(*MaskedConstant)).isSubsetOf(Known.Zero))
return A;
}

auto *I = dyn_cast<Instruction>(V);
if (!I || !isSafeToSpeculativelyExecute(I) || Depth >= 2)
return nullptr;

bool Changed = false;
for (unsigned i = 0; i < I->getNumOperands(); ++i) {
llvm::Value *Operand = I->getOperand(i);

if (std::any_of(Operand->user_begin(), Operand->user_end(),
[I](const User *User) { return User != I; }))
break;

Value *NewOp = foldAndMaskPattern(Operand, Cmp, SI, IC, Depth + 1);
if (NewOp) {
IC.replaceOperand(*I, i, NewOp);
Changed = true;
}
}

return Changed ? I : nullptr;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I think this would be better as a recursive function.

The base case being matching m_c_And(m_Value(A), m_APInt(Mask)) and then you can try to simplfy operands of binops/etc... you find along the way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your comments, I had thought about something similar but wasn't entirely sure about what approach to go with. Will work on this.

/// Attmpts to fold expressions in both branches of a select instruction
/// based on KnownBits implied by the condition
// static Instruction *foldSelectWithIcmpEqAndPattern(Value *TVal, Value *FVal,
// Value *CondVal,
// SelectInst &SI,
// InstCombinerImpl &IC) {
// if (TVal->hasOneUse())
// if (Value *newTrueOp = simplifyAndMaskPattern(TVal, CondVal, SI, IC))
// return IC.replaceOperand(SI, 1, newTrueOp);

// if (FVal->hasOneUse())
// if (Value *newFalseOp = simplifyAndMaskPattern(FVal, CondVal, SI, IC))
// return IC.replaceOperand(SI, 2, newFalseOp);

// return nullptr;
// }

/// Fold the following code sequence:
/// \code
/// int a = ctlz(x & -x);
Expand Down Expand Up @@ -4110,5 +4166,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}

// Attempts to recursively identify and fold (AND A constant) --> A
// in the true branch of the select if all bits
// that are zero in the negated constant are also zero in A's known zero bits.
if (TrueVal->hasOneUse())
if (Value *newTrueOp = foldAndMaskPattern(TrueVal, CondVal, SI, *this))
return replaceOperand(SI, 1, newTrueOp);

return nullptr;
}
121 changes: 121 additions & 0 deletions llvm/test/Transforms/InstCombine/select-known-bits.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b) {
; CHECK-LABEL: define i8 @select_icmp_eq_mul_and(
; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[A]], [[A]]
; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B]]
; CHECK-NEXT: ret i8 [[RETVAL]]
;
%1 = and i8 %a, 1
%cmp = icmp eq i8 %1, 0
%div = and i8 %a, -2
%mul = mul i8 %div, %div
%retval = select i1 %cmp, i8 %mul, i8 %b
ret i8 %retval
}

define i8 @select_icmp_eq_mul_and_inv(i8 noundef %a, i8 %b) {
; CHECK-LABEL: define i8 @select_icmp_eq_mul_and_inv(
; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A]], 1
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[A]], [[A]]
; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[MUL]], i8 [[B]]
; CHECK-NEXT: ret i8 [[RETVAL]]
;
%1 = and i8 %a, 1
%cmp = icmp eq i8 %1, 1
%div = and i8 %a, -2
%mul = mul i8 %div, %div
%retval = select i1 %cmp, i8 %b, i8 %mul
ret i8 %retval
}

define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b) {
; CHECK-LABEL: define i8 @select_icmp_eq_and(
; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[A]], i8 [[B]]
; CHECK-NEXT: ret i8 [[RETVAL]]
;
%1 = and i8 %a, 1
%cmp = icmp eq i8 %1, 0
%div = and i8 %a, -2
%retval = select i1 %cmp, i8 %div, i8 %b
ret i8 %retval
}

define i8 @select_icmp_eq_and_inv(i8 noundef %a, i8 %b) {
; CHECK-LABEL: define i8 @select_icmp_eq_and_inv(
; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A]], 1
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[A]], i8 [[B]]
; CHECK-NEXT: ret i8 [[RETVAL]]
;
%1 = and i8 %a, 1
%cmp = icmp eq i8 %1, 1
%div = and i8 %a, -2
%retval = select i1 %cmp, i8 %b, i8 %div
ret i8 %retval
}

;negative test
define i8 @select_icmp_eq_and_undef(i8 %a, i8 %b) {
; CHECK-LABEL: define i8 @select_icmp_eq_and_undef(
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: [[DIV:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
; CHECK-NEXT: ret i8 [[RETVAL]]
;
%1 = and i8 %a, 1
%cmp = icmp eq i8 %1, 0
%div = and i8 %a, -2
%retval = select i1 %cmp, i8 %div, i8 %b
ret i8 %retval
}

;negative test
define i8 @select_icmp_eq_and_diff(i8 noundef %a, i8 %b, i8 %c) {
; CHECK-LABEL: define i8 @select_icmp_eq_and_diff(
; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: [[DIV:%.*]] = and i8 [[C]], -2
; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
; CHECK-NEXT: ret i8 [[RETVAL]]
;
%1 = and i8 %a, 1
%cmp = icmp eq i8 %1, 0
%div = and i8 %c, -2
%retval = select i1 %cmp, i8 %div, i8 %b
ret i8 %retval
}

;negative test
define i8 @select_icmp_eq_mul_and_extra_use(i8 noundef %a, i8 %b) {
; CHECK-LABEL: define i8 @select_icmp_eq_mul_and_extra_use(
; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
; CHECK-NEXT: [[DIV:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[DIV]], [[DIV]]
; CHECK-NEXT: [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B]]
; CHECK-NEXT: [[SUM:%.*]] = add i8 [[MUL]], [[RETVAL]]
; CHECK-NEXT: ret i8 [[SUM]]
;
%1 = and i8 %a, 1
%cmp = icmp eq i8 %1, 0
%div = and i8 %a, -2
%mul = mul i8 %div, %div
%retval = select i1 %cmp, i8 %mul, i8 %b
%sum = add i8 %mul, %retval
ret i8 %sum
}
11 changes: 4 additions & 7 deletions llvm/test/Transforms/InstCombine/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2989,9 +2989,8 @@ define i8 @select_replacement_loop3(i32 noundef %x) {

define i16 @select_replacement_loop4(i16 noundef %p_12) {
; CHECK-LABEL: @select_replacement_loop4(
; CHECK-NEXT: [[AND1:%.*]] = and i16 [[P_12:%.*]], 1
; CHECK-NEXT: [[CMP21:%.*]] = icmp ult i16 [[P_12]], 2
; CHECK-NEXT: [[AND3:%.*]] = select i1 [[CMP21]], i16 [[AND1]], i16 0
; CHECK-NEXT: [[CMP21:%.*]] = icmp ult i16 [[P_12:%.*]], 2
; CHECK-NEXT: [[AND3:%.*]] = select i1 [[CMP21]], i16 [[P_12]], i16 0
; CHECK-NEXT: ret i16 [[AND3]]
;
%cmp1 = icmp ult i16 %p_12, 2
Expand Down Expand Up @@ -4671,8 +4670,7 @@ define i8 @select_knownbits_simplify(i8 noundef %x) {
; CHECK-LABEL: @select_knownbits_simplify(
; CHECK-NEXT: [[X_LO:%.*]] = and i8 [[X:%.*]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X_LO]], 0
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], -2
; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i8 [[AND]], i8 0
; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i8 [[X]], i8 0
; CHECK-NEXT: ret i8 [[RES]]
;
%x.lo = and i8 %x, 1
Expand All @@ -4686,8 +4684,7 @@ define i8 @select_knownbits_simplify_nested(i8 noundef %x) {
; CHECK-LABEL: @select_knownbits_simplify_nested(
; CHECK-NEXT: [[X_LO:%.*]] = and i8 [[X:%.*]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X_LO]], 0
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], -2
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[AND]], [[AND]]
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X]], [[X]]
; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 0
; CHECK-NEXT: ret i8 [[RES]]
;
Expand Down
Loading