Skip to content

Commit 8c78b70

Browse files
committed
[AArch64] Eliminate Common Subexpression of CSEL by Reassociation
If we have a CSEL instruction that depends on the flags set by a (SUBS x c) instruction, and the true and/or false expression is (add (add x y) -c), we can reassociate the latter expression to (add (SUBS x c) y) and save one instruction.

Proof for the basic transformation: https://alive2.llvm.org/ce/z/-337Pb

We can extend this transformation to slightly different constants. For example, if we have (add (add x y) -(c-1)) and the comparison x <u c, we can rewrite the comparison as x <=u c-1 so that the separate comparison instruction is eliminated, too. Similarly, we can transform (x == 0) to (x <u 1).

Proofs for the transformations that alter the constants: https://alive2.llvm.org/ce/z/3nVqgR

Fixes #119606.
1 parent c190e18 commit 8c78b70

File tree

2 files changed

+179
-82
lines changed

2 files changed

+179
-82
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24838,6 +24838,122 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
2483824838
return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
2483924839
}
2484024840

24841+
// Reassociate the true/false expressions of a CSEL instruction to obtain a
24842+
// common subexpression with the comparison instruction. For example, change
24843+
// (CSEL (ADD (ADD x y) -c) f LO (SUBS x c)) to
24844+
// (CSEL (ADD (SUBS x c) y) f LO (SUBS x c)) such that (SUBS x c) is a common
24845+
// subexpression.
24846+
static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
24847+
SDValue SubsNode = N->getOperand(3);
24848+
if (SubsNode.getOpcode() != AArch64ISD::SUBS || !SubsNode.hasOneUse())
24849+
return SDValue();
24850+
auto *CmpOpConst = dyn_cast<ConstantSDNode>(SubsNode.getOperand(1));
24851+
if (!CmpOpConst)
24852+
return SDValue();
24853+
24854+
auto CC = static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
24855+
bool IsEquality = CC == AArch64CC::EQ || CC == AArch64CC::NE;
24856+
if (IsEquality && !CmpOpConst->isZero())
24857+
return SDValue();
24858+
24859+
SDValue CmpOpOther = SubsNode.getOperand(0);
24860+
EVT VT = N->getValueType(0);
24861+
24862+
auto Reassociate = [&](SDValue Op, APInt ExpectedConst, SDValue NewCmp) {
24863+
if (Op.getOpcode() != ISD::ADD)
24864+
return SDValue();
24865+
auto *AddOpConst = dyn_cast<ConstantSDNode>(Op.getOperand(1));
24866+
if (!AddOpConst)
24867+
return SDValue();
24868+
if (AddOpConst->getAPIntValue() != ExpectedConst)
24869+
return SDValue();
24870+
if (Op.getOperand(0).getOpcode() != ISD::ADD ||
24871+
!Op.getOperand(0).hasOneUse())
24872+
return SDValue();
24873+
SDValue X = Op.getOperand(0).getOperand(0);
24874+
SDValue Y = Op.getOperand(0).getOperand(1);
24875+
if (X != CmpOpOther)
24876+
std::swap(X, Y);
24877+
if (X != CmpOpOther)
24878+
return SDValue();
24879+
SDNodeFlags Flags;
24880+
if (Op.getOperand(0).getNode()->getFlags().hasNoUnsignedWrap())
24881+
Flags.setNoUnsignedWrap(true);
24882+
return DAG.getNode(ISD::ADD, SDLoc(Op), VT, NewCmp.getValue(0), Y, Flags);
24883+
};
24884+
24885+
auto Fold = [&](APInt NewCmpConst, AArch64CC::CondCode NewCC) {
24886+
SDValue NewCmp = DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode),
24887+
DAG.getVTList(VT, MVT_CC), CmpOpOther,
24888+
DAG.getConstant(NewCmpConst, SDLoc(CmpOpConst),
24889+
CmpOpConst->getValueType(0)));
24890+
24891+
APInt ExpectedConst = -NewCmpConst;
24892+
SDValue TValReassoc = Reassociate(N->getOperand(0), ExpectedConst, NewCmp);
24893+
SDValue FValReassoc = Reassociate(N->getOperand(1), ExpectedConst, NewCmp);
24894+
if (!TValReassoc && !FValReassoc)
24895+
return SDValue();
24896+
if (TValReassoc)
24897+
DAG.ReplaceAllUsesWith(N->getOperand(0), TValReassoc);
24898+
else
24899+
TValReassoc = N->getOperand(0);
24900+
if (FValReassoc)
24901+
DAG.ReplaceAllUsesWith(N->getOperand(1), FValReassoc);
24902+
else
24903+
FValReassoc = N->getOperand(1);
24904+
return DAG.getNode(AArch64ISD::CSEL, SDLoc(N), VT, TValReassoc, FValReassoc,
24905+
DAG.getConstant(NewCC, SDLoc(N->getOperand(2)), MVT_CC),
24906+
NewCmp.getValue(1));
24907+
};
24908+
24909+
// First, try to eliminate the compare instruction by searching for a
24910+
// subtraction with the same constant.
24911+
if (!IsEquality) // Not useful for equalities
24912+
if (SDValue R = Fold(CmpOpConst->getAPIntValue(), CC))
24913+
return R;
24914+
24915+
// Next, search for a subtraction with a slightly different constant. By
24916+
// adjusting the condition code, we can still eliminate the compare
24917+
// instruction. Adjusting the constant is only valid if it does not result
24918+
// in signed/unsigned wrap for signed/unsigned comparisons, respectively.
24919+
// Since such comparisons are trivially true/false, we should not encounter
24920+
// them here but check for them nevertheless to be on the safe side.
24921+
auto CheckedFold = [&](bool Check, APInt NewCmpConst,
24922+
AArch64CC::CondCode NewCC) {
24923+
return Check ? Fold(NewCmpConst, NewCC) : SDValue();
24924+
};
24925+
switch (CC) {
24926+
case AArch64CC::EQ:
24927+
case AArch64CC::LS:
24928+
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
24929+
CmpOpConst->getAPIntValue() + 1, AArch64CC::LO);
24930+
case AArch64CC::NE:
24931+
case AArch64CC::HI:
24932+
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
24933+
CmpOpConst->getAPIntValue() + 1, AArch64CC::HS);
24934+
case AArch64CC::LO:
24935+
return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
24936+
CmpOpConst->getAPIntValue() - 1, AArch64CC::LS);
24937+
case AArch64CC::HS:
24938+
return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
24939+
CmpOpConst->getAPIntValue() - 1, AArch64CC::HI);
24940+
case AArch64CC::LT:
24941+
return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
24942+
CmpOpConst->getAPIntValue() - 1, AArch64CC::LE);
24943+
case AArch64CC::LE:
24944+
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
24945+
CmpOpConst->getAPIntValue() + 1, AArch64CC::LT);
24946+
case AArch64CC::GT:
24947+
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
24948+
CmpOpConst->getAPIntValue() + 1, AArch64CC::GE);
24949+
case AArch64CC::GE:
24950+
return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
24951+
CmpOpConst->getAPIntValue() - 1, AArch64CC::GT);
24952+
default:
24953+
return SDValue();
24954+
}
24955+
}
24956+
2484124957
// Optimize CSEL instructions
2484224958
static SDValue performCSELCombine(SDNode *N,
2484324959
TargetLowering::DAGCombinerInfo &DCI,
@@ -24849,6 +24965,11 @@ static SDValue performCSELCombine(SDNode *N,
2484924965
if (SDValue R = foldCSELOfCSEL(N, DAG))
2485024966
return R;
2485124967

24968+
// Try to reassociate the true/false expressions so that we can do CSE with
24969+
// a SUBS instruction used to perform the comparison.
24970+
if (SDValue R = reassociateCSELOperandsForCSE(N, DAG))
24971+
return R;
24972+
2485224973
// CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
2485324974
// CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
2485424975
if (SDValue Folded = foldCSELofCTTZ(N, DAG))

0 commit comments

Comments
 (0)