Skip to content

[AArch64] Eliminate Common Subexpression of CSEL by Reassociation #121350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24838,6 +24838,122 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
}

// Reassociate the true/false expressions of a CSEL instruction to obtain a
// common subexpression with the comparison instruction. For example, change
// (CSEL (ADD (ADD x y) -c) f LO (SUBS x c)) to
// (CSEL (ADD (SUBS x c) y) f LO (SUBS x c)) such that (SUBS x c) is a common
// subexpression.
static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
SDValue SubsNode = N->getOperand(3);
if (SubsNode.getOpcode() != AArch64ISD::SUBS || !SubsNode.hasOneUse())
return SDValue();
auto *CmpOpConst = dyn_cast<ConstantSDNode>(SubsNode.getOperand(1));
if (!CmpOpConst)
return SDValue();

SDValue CmpOpOther = SubsNode.getOperand(0);
EVT VT = N->getValueType(0);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we handle eq and ne in the same way?

define i32 @src(i32 %x0, i32 %x1) {
  %cmp = icmp eq i32 %x1, 7
  %add = add i32 %x0, %x1
  %sub = sub i32 %add, 7
  %ret = select i1 %cmp, i32 0, i32 %sub
  ret i32 %ret
}

define i32 @tgt(i32 %x0, i32 %x1) {
  %cmp = icmp eq i32 %x1, 7
  %add = sub i32 %x1, 7
  %sub = add i32 %add, %x0
  %ret = select i1 %cmp, i32 0, i32 %sub
  ret i32 %ret
}

(I guess for all of those the constant isn't actually necessary and the basic pattern is that we can reassociate an add/sub to allow it to be shared with a sub. FYI There is a related but different patch to that (swapping conditions to allow sharing) added in #121412 (mostly for handling constants, for non-constants it already exists in a couple of places in llvm already). I understand this is getting further away from your motivating case. If we can fix the APInt issues then the follow on can easily be added in a later patch if that is simpler. Sometimes it is better to have smaller patches).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can also handle eq and ne. I've also implemented this.

I agree that smaller patches are sometimes better. I'll take a look into non-constant operands, which would probably require some more restructuring, and start working on a follow-up patch. For this PR, it is probably easier to stick to constants to keep the PR concise.

// Get the operand that can be reassociated with the SUBS instruction.
auto GetReassociationOp = [&](SDValue Op, APInt ExpectedConst) {
if (Op.getOpcode() != ISD::ADD)
return SDValue();
if (Op.getOperand(0).getOpcode() != ISD::ADD ||
!Op.getOperand(0).hasOneUse())
return SDValue();
SDValue X = Op.getOperand(0).getOperand(0);
SDValue Y = Op.getOperand(0).getOperand(1);
if (X != CmpOpOther)
std::swap(X, Y);
if (X != CmpOpOther)
return SDValue();
auto *AddOpConst = dyn_cast<ConstantSDNode>(Op.getOperand(1));
if (!AddOpConst || AddOpConst->getAPIntValue() != ExpectedConst)
return SDValue();
return Y;
};

// Try the reassociation using the given constant and condition code.
auto Fold = [&](APInt NewCmpConst, AArch64CC::CondCode NewCC) {
APInt ExpectedConst = -NewCmpConst;
SDValue TReassocOp = GetReassociationOp(N->getOperand(0), ExpectedConst);
SDValue FReassocOp = GetReassociationOp(N->getOperand(1), ExpectedConst);
if (!TReassocOp && !FReassocOp)
return SDValue();

SDValue NewCmp = DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode),
DAG.getVTList(VT, MVT_CC), CmpOpOther,
DAG.getConstant(NewCmpConst, SDLoc(CmpOpConst),
CmpOpConst->getValueType(0)));

auto Reassociate = [&](SDValue ReassocOp, unsigned OpNum) {
if (!ReassocOp)
return N->getOperand(OpNum);
SDValue Res = DAG.getNode(ISD::ADD, SDLoc(N->getOperand(OpNum)), VT,
NewCmp.getValue(0), ReassocOp);
DAG.ReplaceAllUsesWith(N->getOperand(OpNum), Res);
return Res;
};

SDValue TValReassoc = Reassociate(TReassocOp, 0);
SDValue FValReassoc = Reassociate(FReassocOp, 1);
return DAG.getNode(AArch64ISD::CSEL, SDLoc(N), VT, TValReassoc, FValReassoc,
DAG.getConstant(NewCC, SDLoc(N->getOperand(2)), MVT_CC),
NewCmp.getValue(1));
};

auto CC = static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));

// First, try to eliminate the compare instruction by searching for a
// subtraction with the same constant.
if (SDValue R = Fold(CmpOpConst->getAPIntValue(), CC))
return R;

if ((CC == AArch64CC::EQ || CC == AArch64CC::NE) && !CmpOpConst->isZero())
return SDValue();

// Next, search for a subtraction with a slightly different constant. By
// adjusting the condition code, we can still eliminate the compare
// instruction. Adjusting the constant is only valid if it does not result
// in signed/unsigned wrap for signed/unsigned comparisons, respectively.
// Since such comparisons are trivially true/false, we should not encounter
// them here but check for them nevertheless to be on the safe side.
auto CheckedFold = [&](bool Check, APInt NewCmpConst,
AArch64CC::CondCode NewCC) {
return Check ? Fold(NewCmpConst, NewCC) : SDValue();
};
switch (CC) {
case AArch64CC::EQ:
case AArch64CC::LS:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::LO);
case AArch64CC::NE:
case AArch64CC::HI:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::HS);
case AArch64CC::LO:
return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::LS);
case AArch64CC::HS:
return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::HI);
case AArch64CC::LT:
return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::LE);
case AArch64CC::LE:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::LT);
case AArch64CC::GT:
return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
CmpOpConst->getAPIntValue() + 1, AArch64CC::GE);
case AArch64CC::GE:
return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
CmpOpConst->getAPIntValue() - 1, AArch64CC::GT);
default:
return SDValue();
}
}

// Optimize CSEL instructions
static SDValue performCSELCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
Expand All @@ -24849,6 +24965,11 @@ static SDValue performCSELCombine(SDNode *N,
if (SDValue R = foldCSELOfCSEL(N, DAG))
return R;

// Try to reassociate the true/false expressions so that we can do CSE with
// a SUBS instruction used to perform the comparison.
if (SDValue R = reassociateCSELOperandsForCSE(N, DAG))
return R;

// CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
// CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
if (SDValue Folded = foldCSELofCTTZ(N, DAG))
Expand Down
Loading
Loading