Skip to content

[RISCV] Use vrsub for select of add and sub of the same operands #123400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 83 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1535,7 +1535,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::UDIV, ISD::SREM,
ISD::UREM, ISD::INSERT_VECTOR_ELT,
ISD::ABS, ISD::CTPOP,
ISD::VECTOR_SHUFFLE});
ISD::VECTOR_SHUFFLE, ISD::VSELECT});

if (Subtarget.hasVendorXTHeadMemPair())
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
if (Subtarget.useRVVForFixedLengthVectors())
Expand Down Expand Up @@ -16874,6 +16875,53 @@ static SDValue useInversedSetcc(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static bool matchSelectAddSub(SDValue TrueVal, SDValue FalseVal, bool &SwapCC) {
if (!TrueVal.hasOneUse() || !FalseVal.hasOneUse())
return false;

SwapCC = false;
if (TrueVal.getOpcode() == ISD::SUB && FalseVal.getOpcode() == ISD::ADD) {
std::swap(TrueVal, FalseVal);
SwapCC = true;
}

if (TrueVal.getOpcode() != ISD::ADD || FalseVal.getOpcode() != ISD::SUB)
return false;

SDValue A = FalseVal.getOperand(0);
SDValue B = FalseVal.getOperand(1);
// Add is commutative, so check both orders
return ((TrueVal.getOperand(0) == A && TrueVal.getOperand(1) == B) ||
(TrueVal.getOperand(1) == A && TrueVal.getOperand(0) == B));
}

/// Convert vselect CC, (add a, b), (sub a, b) to add a, (vselect CC, -b, b).
/// This allows us match a vadd.vv fed by a masked vrsub, which reduces
/// register pressure over the add followed by masked vsub sequence.
static SDValue performVSELECTCombine(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
SDValue CC = N->getOperand(0);
SDValue TrueVal = N->getOperand(1);
SDValue FalseVal = N->getOperand(2);

bool SwapCC;
if (!matchSelectAddSub(TrueVal, FalseVal, SwapCC))
return SDValue();

SDValue Sub = SwapCC ? TrueVal : FalseVal;
SDValue A = Sub.getOperand(0);
SDValue B = Sub.getOperand(1);

// Arrange the select such that we can match a masked
// vrsub.vi to perform the conditional negate
SDValue NegB = DAG.getNegative(B, DL, VT);
if (!SwapCC)
CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
SDValue NewB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
return DAG.getNode(ISD::ADD, DL, VT, A, NewB);
}

static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (SDValue Folded = foldSelectOfCTTZOrCTLZ(N, DAG))
Expand Down Expand Up @@ -17153,20 +17201,48 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
}

/// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
/// during the combine phase before type legalization, and relies on
/// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
/// for the source mask.
static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
const unsigned ElementSize = VT.getScalarSizeInBits();
const unsigned NumElts = VT.getVectorNumElements();
SDValue V1 = N->getOperand(0);
SDValue V2 = N->getOperand(1);
ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
MVT XLenVT = Subtarget.getXLenVT();

// Recognized a disguised select of add/sub.
bool SwapCC;
if (ShuffleVectorInst::isSelectMask(Mask, NumElts) &&
matchSelectAddSub(V1, V2, SwapCC)) {
SDValue Sub = SwapCC ? V1 : V2;
SDValue A = Sub.getOperand(0);
SDValue B = Sub.getOperand(1);

SmallVector<SDValue> MaskVals;
for (int MaskIndex : Mask) {
bool SelectMaskVal = (MaskIndex < (int)NumElts);
MaskVals.push_back(DAG.getConstant(SelectMaskVal, DL, XLenVT));
}
assert(MaskVals.size() == NumElts && "Unexpected select-like shuffle");
EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
SDValue CC = DAG.getBuildVector(MaskVT, DL, MaskVals);

// Arrange the select such that we can match a masked
// vrsub.vi to perform the conditional negate
SDValue NegB = DAG.getNegative(B, DL, VT);
if (!SwapCC)
CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
SDValue NewB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
return DAG.getNode(ISD::ADD, DL, VT, A, NewB);
}

// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
// during the combine phase before type legalization, and relies on
// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
// for the source mask.
if (TLI.isTypeLegal(VT) || ElementSize <= Subtarget.getELen() ||
!isPowerOf2_64(ElementSize) || VT.getVectorNumElements() % 2 != 0 ||
VT.isFloatingPoint() || TLI.isShuffleMaskLegal(Mask, VT))
Expand All @@ -17183,7 +17259,6 @@ static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
return DAG.getBitcast(VT, Res);
}


static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {

Expand Down Expand Up @@ -17857,6 +17932,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
return performSELECTCombine(N, DAG, Subtarget);
case ISD::VSELECT:
return performVSELECTCombine(N, DAG);
case RISCVISD::CZERO_EQZ:
case RISCVISD::CZERO_NEZ: {
SDValue Val = N->getOperand(0);
Expand Down
Loading
Loading