Skip to content

[InstCombine] Modify foldSelectICmpEq to only handle more useful and simple cases. #121672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 67 additions & 80 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1821,8 +1821,12 @@ static Instruction *foldSelectWithExtremeEqCond(Value *CmpLHS, Value *CmpRHS,
return new ICmpInst(Pred, CmpLHS, B);
}

// Fold (X Op0 Y) == 0 ? (X Op1 Y) : (X Op2 Y)
// -> (X Op2 Y)
// By proving that `(X Op1 Y) == (X Op2 Y)` in the context of `(X Op0 Y) == 0`.
static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
InstCombinerImpl &IC) {

ICmpInst::Predicate Pred = ICI->getPredicate();
if (!ICmpInst::isEquality(Pred))
return nullptr;
Expand All @@ -1835,96 +1839,79 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
if (Pred == ICmpInst::ICMP_NE)
std::swap(TrueVal, FalseVal);

if (Instruction *Res =
foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal))
return Res;
if (auto *R = foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal))
return R;

// Transform (X == C) ? X : Y -> (X == C) ? C : Y
// specific handling for Bitwise operation.
// x&y -> (x|y) ^ (x^y) or (x|y) & ~(x^y)
// x|y -> (x&y) | (x^y) or (x&y) ^ (x^y)
// x^y -> (x|y) ^ (x&y) or (x|y) & ~(x&y)
Value *X, *Y;
if (!match(CmpLHS, m_BitwiseLogic(m_Value(X), m_Value(Y))) ||
!match(TrueVal, m_c_BitwiseLogic(m_Specific(X), m_Specific(Y))))
return nullptr;

const unsigned AndOps = Instruction::And, OrOps = Instruction::Or,
XorOps = Instruction::Xor, NoOps = 0;
enum NotMask { None = 0, NotInner, NotRHS };

auto matchFalseVal = [&](unsigned OuterOpc, unsigned InnerOpc,
unsigned NotMask) {
auto matchInner = m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y));
if (OuterOpc == NoOps)
return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner);

if (NotMask == NotInner) {
return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner),
m_Specific(CmpRHS)));
} else if (NotMask == NotRHS) {
return match(FalseVal, m_c_BinOp(OuterOpc, matchInner,
m_NotForbidPoison(m_Specific(CmpRHS))));
} else {
return match(FalseVal,
m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS)));
}
};

// (X&Y)==C ? X|Y : X^Y -> (X^Y)|C : X^Y or (X^Y)^ C : X^Y
// (X&Y)==C ? X^Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y
if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
// (X&Y)==C ? X|Y : (X^Y)|C -> (X^Y)|C : (X^Y)|C -> (X^Y)|C
// (X&Y)==C ? X|Y : (X^Y)^C -> (X^Y)^C : (X^Y)^C -> (X^Y)^C
if (matchFalseVal(OrOps, XorOps, None) ||
matchFalseVal(XorOps, XorOps, None))
return IC.replaceInstUsesWith(SI, FalseVal);
} else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
// (X&Y)==C ? X^Y : (X|Y)^ C -> (X|Y)^ C : (X|Y)^ C -> (X|Y)^ C
// (X&Y)==C ? X^Y : (X|Y)&~C -> (X|Y)&~C : (X|Y)&~C -> (X|Y)&~C
if (matchFalseVal(XorOps, OrOps, None) ||
matchFalseVal(AndOps, OrOps, NotRHS))
if (match(CmpRHS, m_Zero())) {
// (X & Y) == 0 ? X |/^/+ Y : X |/^/+ Y -> X |/^/+ Y (false arm)
// `(X & Y) == 0` implies no common bits which means:
// `X ^ Y == X | Y == X + Y`
// https://alive2.llvm.org/ce/z/jjcduh
if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
auto MatchAddOrXor =
m_CombineOr(m_c_Add(m_Specific(X), m_Specific(Y)),
m_CombineOr(m_c_Or(m_Specific(X), m_Specific(Y)),
m_c_Xor(m_Specific(X), m_Specific(Y))));
if (match(TrueVal, MatchAddOrXor) && match(FalseVal, MatchAddOrXor))
return IC.replaceInstUsesWith(SI, FalseVal);
}
}

// (X|Y)==C ? X&Y : X^Y -> (X^Y)^C : X^Y or ~(X^Y)&C : X^Y
// (X|Y)==C ? X^Y : X&Y -> (X&Y)^C : X&Y or ~(X&Y)&C : X&Y
if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y)))) {
if (match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y)))) {
// (X|Y)==C ? X&Y: (X^Y)^C -> (X^Y)^C: (X^Y)^C -> (X^Y)^C
// (X|Y)==C ? X&Y:~(X^Y)&C ->~(X^Y)&C:~(X^Y)&C -> ~(X^Y)&C
if (matchFalseVal(XorOps, XorOps, None) ||
matchFalseVal(AndOps, XorOps, NotInner))
return IC.replaceInstUsesWith(SI, FalseVal);
} else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
// (X|Y)==C ? X^Y : (X&Y)^C -> (X&Y)^C : (X&Y)^C -> (X&Y)^C
// (X|Y)==C ? X^Y :~(X&Y)&C -> ~(X&Y)&C :~(X&Y)&C -> ~(X&Y)&C
if (matchFalseVal(XorOps, AndOps, None) ||
matchFalseVal(AndOps, AndOps, NotInner))
return IC.replaceInstUsesWith(SI, FalseVal);
}
}
// (X | Y) == 0 ? X Op0 Y : X Op1 Y -> X Op1 Y
// For any `Op0` and `Op1` that are zero when `X` and `Y` are zero.
// https://alive2.llvm.org/ce/z/azHzBW
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In conjunction with my suggestion for the eq case, a principled way to handle this would be to extend

// select((X | Y) == 0 ? X : 0) --> 0 (commuted 2 ways)
if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
match(CmpRHS, m_Zero())) {
// (X | Y) == 0 implies X == 0 and Y == 0.
if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
Q, MaxRecurse))
return V;
if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
Q, MaxRecurse))
return V;
}
and simplifyWithOpReplaced to support multiple replacements at the same time, instead of trying to replace X and Y with 0 individually. (Whether this is feasible depends on the compile-time impact of course.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for the (X | Y) == 0 case, but not for the (X & Y) == 0 case.

I think the best way forward is to add that simplification, with a follow-up to replace (X & Y) == 0 ? X ^/+ Y : ... with (X & Y) == 0 ? X | Y : .... I don't think the full fold of (X & Y) == 0 ? X | Y : X ^/|/+ Y really exists, but it would be just a single line to add if we decide it's worthwhile.

if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
(match(TrueVal, m_c_BinOp(m_Specific(X), m_Specific(Y))) ||
// In true arm we can also accept just `0`.
match(TrueVal, m_Zero())) &&
match(FalseVal, m_c_BinOp(m_Specific(X), m_Specific(Y)))) {
auto IsOpcZeroWithZeros = [](Value *V) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
return false;
switch (I->getOpcode()) {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
case Instruction::Mul:
case Instruction::Add:
case Instruction::Sub:
case Instruction::Shl:
case Instruction::AShr:
case Instruction::LShr:
return true;
default:
return false;
}
};

// (X^Y)==C ? X&Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y
// (X^Y)==C ? X|Y : X&Y -> (X&Y)|C : X&Y or (X&Y)^ C : X&Y
if (match(CmpLHS, m_Xor(m_Value(X), m_Value(Y)))) {
if ((match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y))))) {
// (X^Y)==C ? X&Y : (X|Y)^C -> (X|Y)^C
// (X^Y)==C ? X&Y : (X|Y)&~C -> (X|Y)&~C
if (matchFalseVal(XorOps, OrOps, None) ||
matchFalseVal(AndOps, OrOps, NotRHS))
return IC.replaceInstUsesWith(SI, FalseVal);
} else if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
// (X^Y)==C ? (X|Y) : (X&Y)|C -> (X&Y)|C
// (X^Y)==C ? (X|Y) : (X&Y)^C -> (X&Y)^C
if (matchFalseVal(OrOps, AndOps, None) ||
matchFalseVal(XorOps, AndOps, None))
if ((match(TrueVal, m_Zero()) || IsOpcZeroWithZeros(TrueVal)) &&
IsOpcZeroWithZeros(FalseVal))
return IC.replaceInstUsesWith(SI, FalseVal);
}
}
// (X == Y) ? X | Y : X & Y
// (X == Y) ? X & Y : X | Y
// If `X == Y` then `X == Y == X | Y == X & Y`.
// NB: `X == Y` is canonicalization of `(X ^ Y) == 0`.
// https://alive2.llvm.org/ce/z/SJskbz
X = CmpLHS;
Y = CmpRHS;
auto MatchOrAnd = m_CombineOr(m_c_Or(m_Specific(X), m_Specific(Y)),
m_c_And(m_Specific(X), m_Specific(Y)));
if (match(FalseVal, MatchOrAnd) &&
// In the true arm we can also just match `X` or `Y`.
(match(TrueVal, MatchOrAnd) || match(TrueVal, m_Specific(X)) ||
match(TrueVal, m_Specific(Y)))) {
// Can't preserve `or disjoint` here so rebuild.
auto *BO = dyn_cast<BinaryOperator>(FalseVal);
if (!BO)
return nullptr;

return IC.replaceInstUsesWith(
SI, IC.Builder.CreateBinOp(BO->getOpcode(), BO->getOperand(0),
BO->getOperand(1)));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we handle this case by changing

/// Try to simplify a select instruction when its condition operand is an
/// integer equality or floating-point equivalence comparison.
static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS,
Value *TrueVal, Value *FalseVal,
const SimplifyQuery &Q,
unsigned MaxRecurse) {
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(),
/* AllowRefinement */ false,
/* DropFlags */ nullptr, MaxRecurse) == TrueVal)
return FalseVal;
if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ true,
/* DropFlags */ nullptr, MaxRecurse) == FalseVal)
return FalseVal;
return nullptr;
}
to compare both simplified values instead? Currently we simplify one arm and compare the result against the other arm in its original, unsimplified form. If we compared both simplified values, I think we should be able to handle this pattern without any specialized code.

return nullptr;
}

Expand Down
Loading
Loading