Skip to content

[X86][CodeGen] Support lowering for CCMP/CTEST #91747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1801,11 +1801,8 @@ void DAGCombiner::Run(CombineLevel AtLevel) {

if (N->getNumValues() == RV->getNumValues())
DAG.ReplaceAllUsesWith(N, RV.getNode());
else {
assert(N->getValueType(0) == RV.getValueType() &&
N->getNumValues() == 1 && "Type mismatch");
Comment on lines -1805 to -1806
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can help to catch unexpected nodes. Do we need to remove both?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need. After changing SUB to CCMP, neither of them satisfy. For this case, N->getValueType(1) == RV.getValueType() && N->getNumValues() == 2

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need. After changing SUB to CCMP, neither of them satisfy.

This means that there is a problem in combining SUB to CCMP. The assert here is to catch such kind of errors.
See my other comment.

else
DAG.ReplaceAllUsesWith(N, &RV);
}

// Push the new node and any users onto the worklist. Omit this if the
// new node is the EntryToken (e.g. if a store managed to get optimized
Expand Down
37 changes: 28 additions & 9 deletions llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1553,11 +1553,16 @@ void X86DAGToDAGISel::PostprocessISelDAG() {
switch (Opc) {
default:
continue;
// TESTrr+ANDrr/rm -> TESTrr/TESTmr
// ANDrr/rm + TESTrr+ -> TESTrr/TESTmr
case X86::TEST8rr:
case X86::TEST16rr:
case X86::TEST32rr:
case X86::TEST64rr: {
case X86::TEST64rr:
// ANDrr/rm + CTESTrr -> CTESTrr/CTESTmr
case X86::CTEST8rr:
case X86::CTEST16rr:
case X86::CTEST32rr:
case X86::CTEST64rr: {
auto &Op0 = N->getOperand(0);
if (Op0 != N->getOperand(1) || !Op0->hasNUsesOfValue(2, Op0.getResNo()) ||
!Op0.isMachineOpcode())
Expand All @@ -1575,8 +1580,11 @@ void X86DAGToDAGISel::PostprocessISelDAG() {
CASE_ND(AND64rr) {
if (And->hasAnyUseOfValue(1))
continue;
MachineSDNode *Test = CurDAG->getMachineNode(
Opc, SDLoc(N), MVT::i32, And.getOperand(0), And.getOperand(1));
SmallVector<SDValue> Ops(N->op_values());
Ops[0] = And.getOperand(0);
Ops[1] = And.getOperand(1);
MachineSDNode *Test =
CurDAG->getMachineNode(Opc, SDLoc(N), MVT::i32, Ops);
ReplaceUses(N, Test);
MadeChange = true;
continue;
Expand All @@ -1588,8 +1596,9 @@ void X86DAGToDAGISel::PostprocessISelDAG() {
if (And->hasAnyUseOfValue(1))
continue;
unsigned NewOpc;
bool IsCTESTCC = X86::isCTESTCC(Opc);
#define FROM_TO(A, B) \
CASE_ND(A) NewOpc = X86::B; \
CASE_ND(A) NewOpc = IsCTESTCC ? X86::C##B : X86::B; \
break;
switch (And.getMachineOpcode()) {
FROM_TO(AND8rm, TEST8mr);
Expand All @@ -1600,10 +1609,20 @@ void X86DAGToDAGISel::PostprocessISelDAG() {
#undef FROM_TO
#undef CASE_ND
// Need to swap the memory and register operand.
SDValue Ops[] = {And.getOperand(1), And.getOperand(2),
And.getOperand(3), And.getOperand(4),
And.getOperand(5), And.getOperand(0),
And.getOperand(6) /* Chain */};
SmallVector<SDValue> Ops = {And.getOperand(1), And.getOperand(2),
And.getOperand(3), And.getOperand(4),
And.getOperand(5), And.getOperand(0)};
// CC, Cflags.
if (IsCTESTCC) {
Ops.push_back(N->getOperand(2));
Ops.push_back(N->getOperand(3));
}
// Chain of memory load
Ops.push_back(And.getOperand(6));
// Glue
if (IsCTESTCC)
Ops.push_back(N->getOperand(4));

MachineSDNode *Test = CurDAG->getMachineNode(
NewOpc, SDLoc(N), MVT::i32, MVT::Other, Ops);
CurDAG->setNodeMemRefs(
Expand Down
174 changes: 171 additions & 3 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ static cl::opt<int> BrMergingBaseCostThresh(
"to never merge branches."),
cl::Hidden);

static cl::opt<int> BrMergingCcmpBias(
"x86-br-merging-ccmp-bias", cl::init(6),
cl::desc("Increases 'x86-br-merging-base-cost' in cases that the target "
"supports conditional compare instructions."),
cl::Hidden);

static cl::opt<int> BrMergingLikelyBias(
"x86-br-merging-likely-bias", cl::init(0),
cl::desc("Increases 'x86-br-merging-base-cost' in cases that it is likely "
Expand Down Expand Up @@ -3412,6 +3418,9 @@ X86TargetLowering::getJumpConditionMergingParams(Instruction::BinaryOps Opc,
const Value *Rhs) const {
using namespace llvm::PatternMatch;
int BaseCost = BrMergingBaseCostThresh.getValue();
// With CCMP, branches can be merged in a more efficient way.
if (BaseCost >= 0 && Subtarget.hasCCMP())
BaseCost += BrMergingCcmpBias;
// a == b && a == c is a fast pattern on x86.
ICmpInst::Predicate Pred;
if (BaseCost >= 0 && Opc == Instruction::And &&
Expand Down Expand Up @@ -33970,6 +33979,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(TESTUI)
NODE_NAME_CASE(FP80_ADD)
NODE_NAME_CASE(STRICT_FP80_ADD)
NODE_NAME_CASE(CCMP)
NODE_NAME_CASE(CTEST)
}
return nullptr;
#undef NODE_NAME_CASE
Expand Down Expand Up @@ -49217,6 +49228,147 @@ static SDValue combineBMILogicOp(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static SDValue combineX86SubCmpForFlags(SDNode *N, SDValue Flag,
SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &ST) {
// cmp(setcc(cc, X), 0)
// brcond ne
// ->
// X
// brcond cc

// sub(setcc(cc, X), 1)
// brcond ne
// ->
// X
// brcond ~cc
//
// if only flag has users

SDValue SetCC = N->getOperand(0);

// TODO: Remove the check hasCCMP() and update the non-APX tests.
if (!ST.hasCCMP() || SetCC.getOpcode() != X86ISD::SETCC || !Flag.hasOneUse())
return SDValue();

// Check the only user of flag is `brcond ne`.
SDNode *BrCond = *Flag->uses().begin();
if (BrCond->getOpcode() != X86ISD::BRCOND)
return SDValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of checking the users of the node, you should call this function from combineBrCond.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there such a limitation? Starting from the SUB/CMP seems more intuitive to me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but it is the usual approach.

unsigned CondNo = 2;
if (static_cast<X86::CondCode>(BrCond->getConstantOperandVal(CondNo)) !=
X86::COND_NE)
return SDValue();

SDValue X = SetCC.getOperand(1);
// Replace API is called manually here b/c the number of results may change.
DAG.ReplaceAllUsesOfValueWith(Flag, X);

SDValue CCN = SetCC.getOperand(0);
X86::CondCode CC =
static_cast<X86::CondCode>(CCN->getAsAPIntVal().getSExtValue());
X86::CondCode OppositeCC = X86::GetOppositeBranchCondition(CC);
// Update CC for the consumer of the flag.
// The old CC is `ne`. Hence, when comparing the result with 0, we are
// checking if the second condition evaluates to true. When comparing the
// result with 1, we are checking uf the second condition evaluates to false.
SmallVector<SDValue> Ops(BrCond->op_values());
if (isNullConstant(N->getOperand(1)))
Ops[CondNo] = CCN;
else if (isOneConstant(N->getOperand(1)))
Ops[CondNo] = DAG.getTargetConstant(OppositeCC, SDLoc(BrCond), MVT::i8);
else
llvm_unreachable("expect constant 0 or 1");

SDValue NewBrCond =
DAG.getNode(X86ISD::BRCOND, SDLoc(BrCond), BrCond->getValueType(0), Ops);
// Avoid self-assign error b/c CC1 can be `e/ne`.
if (BrCond != NewBrCond.getNode())
DCI.CombineTo(BrCond, NewBrCond);
return X;
}

static SDValue combineAndOrForCcmpCtest(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &ST) {
// and/or(setcc(cc0, flag0), setcc(cc1, sub (X, Y)))
// ->
// setcc(cc1, ccmp(X, Y, ~cflags/cflags, cc0/~cc0, flag0))

// and/or(setcc(cc0, flag0), setcc(cc1, cmp (X, 0)))
// ->
// setcc(cc1, ctest(X, X, ~cflags/cflags, cc0/~cc0, flag0))
//
// where cflags is determined by cc1.

if (!ST.hasCCMP())
return SDValue();

SDValue SetCC0 = N->getOperand(0);
SDValue SetCC1 = N->getOperand(1);
if (SetCC0.getOpcode() != X86ISD::SETCC ||
SetCC1.getOpcode() != X86ISD::SETCC)
return SDValue();

auto GetCombineToOpc = [&](SDValue V) -> unsigned {
SDValue Op = V.getOperand(1);
unsigned Opc = Op.getOpcode();
if (Opc == X86ISD::SUB)
return X86ISD::CCMP;
if (Opc == X86ISD::CMP && isNullConstant(Op.getOperand(1)))
return X86ISD::CTEST;
return 0U;
};

unsigned NewOpc = 0;

// AND/OR is commutable. Canonicalize the operands to make SETCC with SUB/CMP
// appear on the right.
if (!(NewOpc = GetCombineToOpc(SetCC1))) {
std::swap(SetCC0, SetCC1);
if (!(NewOpc = GetCombineToOpc(SetCC1)))
return SDValue();
}

X86::CondCode CC0 =
static_cast<X86::CondCode>(SetCC0.getConstantOperandVal(0));
// CCMP/CTEST is not conditional when the source condition is COND_P/COND_NP.
if (CC0 == X86::COND_P || CC0 == X86::COND_NP)
return SDValue();

bool IsOR = N->getOpcode() == ISD::OR;

// CMP/TEST is executed and updates the EFLAGS normally only when SrcCC
// evaluates to true. So we need to inverse CC0 as SrcCC when the logic
// operator is OR. Similar for CC1.
SDValue SrcCC =
IsOR ? DAG.getTargetConstant(X86::GetOppositeBranchCondition(CC0),
SDLoc(SetCC0.getOperand(0)), MVT::i8)
: SetCC0.getOperand(0);
SDValue CC1N = SetCC1.getOperand(0);
X86::CondCode CC1 =
static_cast<X86::CondCode>(CC1N->getAsAPIntVal().getSExtValue());
X86::CondCode OppositeCC1 = X86::GetOppositeBranchCondition(CC1);
X86::CondCode CFlagsCC = IsOR ? CC1 : OppositeCC1;
SDLoc DL(N);
SDValue CFlags = DAG.getTargetConstant(
X86::getCCMPCondFlagsFromCondCode(CFlagsCC), DL, MVT::i8);
SDValue Sub = SetCC1.getOperand(1);

// Replace any uses of the old flag produced by SUB/CMP with the new one
// produced by CCMP/CTEST.
SDValue CCMP = (NewOpc == X86ISD::CCMP)
? DAG.getNode(X86ISD::CCMP, DL, MVT::i32,
{Sub.getOperand(0), Sub.getOperand(1),
CFlags, SrcCC, SetCC0.getOperand(1)})
: DAG.getNode(X86ISD::CTEST, DL, MVT::i32,
{Sub.getOperand(0), Sub.getOperand(0),
CFlags, SrcCC, SetCC0.getOperand(1)});

return DAG.getNode(X86ISD::SETCC, DL, MVT::i8, {CC1N, CCMP});
}

static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
Expand Down Expand Up @@ -49300,6 +49452,9 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
}
}

if (SDValue SetCC = combineAndOrForCcmpCtest(N, DAG, DCI, Subtarget))
return SetCC;

if (SDValue V = combineScalarAndWithMaskSetcc(N, DAG, Subtarget))
return V;

Expand Down Expand Up @@ -50085,6 +50240,9 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
}
}

if (SDValue SetCC = combineAndOrForCcmpCtest(N, DAG, DCI, Subtarget))
return SetCC;

if (SDValue R = combineBitOpWithMOVMSK(N, DAG))
return R;

Expand Down Expand Up @@ -54606,6 +54764,7 @@ static bool onlyZeroFlagUsed(SDValue Flags) {
}

static SDValue combineCMP(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
// Only handle test patterns.
if (!isNullConstant(N->getOperand(1)))
Expand All @@ -54620,6 +54779,10 @@ static SDValue combineCMP(SDNode *N, SelectionDAG &DAG,
EVT VT = Op.getValueType();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

if (SDValue CMP =
combineX86SubCmpForFlags(N, SDValue(N, 0), DAG, DCI, Subtarget))
return CMP;

// If we have a constant logical shift that's only used in a comparison
// against zero turn it into an equivalent AND. This allows turning it into
// a TEST instruction later.
Expand Down Expand Up @@ -54748,7 +54911,8 @@ static SDValue combineCMP(SDNode *N, SelectionDAG &DAG,
}

static SDValue combineX86AddSub(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &ST) {
assert((X86ISD::ADD == N->getOpcode() || X86ISD::SUB == N->getOpcode()) &&
"Expected X86ISD::ADD or X86ISD::SUB");

Expand All @@ -54759,6 +54923,10 @@ static SDValue combineX86AddSub(SDNode *N, SelectionDAG &DAG,
bool IsSub = X86ISD::SUB == N->getOpcode();
unsigned GenericOpc = IsSub ? ISD::SUB : ISD::ADD;

if (IsSub && isOneConstant(N->getOperand(1)) && !N->hasAnyUseOfValue(0))
if (SDValue CMP = combineX86SubCmpForFlags(N, SDValue(N, 1), DAG, DCI, ST))
return CMP;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use DCI.CombineTo accepting two SDValues, the first one of which should be the first (unused) result of SUB. The SUB will be removed later on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this DCI.CombineTo look like? I am trying replace the second result of SUB (old flag) with the result of CCMP (new flag). Not understand why you suggest the first parameter is the first (unused) result of SUB.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUB has two results, CMP has one result, which corresponds to the second result of the SUB.
Something like return DCI.CombineTo(N, SDValue(N, 0), CMP) would make the removed assert not trigger.
But I see you've found another way of silencing it.


// If we don't use the flag result, simplify back to a generic ADD/SUB.
if (!N->hasAnyUseOfValue(1)) {
SDValue Res = DAG.getNode(GenericOpc, DL, VT, LHS, RHS);
Expand Down Expand Up @@ -57058,11 +57226,11 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::BLENDV: return combineSelect(N, DAG, DCI, Subtarget);
case ISD::BITCAST: return combineBitcast(N, DAG, DCI, Subtarget);
case X86ISD::CMOV: return combineCMov(N, DAG, DCI, Subtarget);
case X86ISD::CMP: return combineCMP(N, DAG, Subtarget);
case X86ISD::CMP: return combineCMP(N, DAG, DCI, Subtarget);
case ISD::ADD: return combineAdd(N, DAG, DCI, Subtarget);
case ISD::SUB: return combineSub(N, DAG, DCI, Subtarget);
case X86ISD::ADD:
case X86ISD::SUB: return combineX86AddSub(N, DAG, DCI);
case X86ISD::SUB: return combineX86AddSub(N, DAG, DCI, Subtarget);
case X86ISD::SBB: return combineSBB(N, DAG);
case X86ISD::ADC: return combineADC(N, DAG, DCI);
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,10 @@ namespace llvm {
// Perform an FP80 add after changing precision control in FPCW.
FP80_ADD,

// Conditional compare instructions
CCMP,
CTEST,

/// X86 strict FP compare instructions.
STRICT_FCMP = ISD::FIRST_TARGET_STRICTFP_OPCODE,
STRICT_FCMPS,
Expand Down
Loading
Loading