Skip to content

Commit a9ad601

Browse files
authored
[RISCV] Use vrsub for select of add and sub of the same operands (#123400)
If we have a (vselect c, a+b, a-b), we can combine this to a+(vselect c, b, -b). That by itself isn't hugely profitable, but if we reverse the select, we get a form which matches a masked vrsub.vi with zero. The result is that we can use a masked vrsub *before* the add instead of a masked add or sub. This doesn't change the critical path (since we already had the pass through on the masked second op), but does reduce register pressure since a, b, and (a+b) don't need to all be alive at once. In addition to the vselect form, we can also see the same pattern with a vector_shuffle encoding the vselect. I explored canonicalizing these to vselects instead, but that exposes several unrelated missing combines.
1 parent 7293455 commit a9ad601

File tree

2 files changed

+139
-112
lines changed

2 files changed

+139
-112
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,7 +1535,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15351535
ISD::UDIV, ISD::SREM,
15361536
ISD::UREM, ISD::INSERT_VECTOR_ELT,
15371537
ISD::ABS, ISD::CTPOP,
1538-
ISD::VECTOR_SHUFFLE});
1538+
ISD::VECTOR_SHUFFLE, ISD::VSELECT});
1539+
15391540
if (Subtarget.hasVendorXTHeadMemPair())
15401541
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
15411542
if (Subtarget.useRVVForFixedLengthVectors())
@@ -16874,6 +16875,53 @@ static SDValue useInversedSetcc(SDNode *N, SelectionDAG &DAG,
1687416875
return SDValue();
1687516876
}
1687616877

16878+
static bool matchSelectAddSub(SDValue TrueVal, SDValue FalseVal, bool &SwapCC) {
16879+
if (!TrueVal.hasOneUse() || !FalseVal.hasOneUse())
16880+
return false;
16881+
16882+
SwapCC = false;
16883+
if (TrueVal.getOpcode() == ISD::SUB && FalseVal.getOpcode() == ISD::ADD) {
16884+
std::swap(TrueVal, FalseVal);
16885+
SwapCC = true;
16886+
}
16887+
16888+
if (TrueVal.getOpcode() != ISD::ADD || FalseVal.getOpcode() != ISD::SUB)
16889+
return false;
16890+
16891+
SDValue A = FalseVal.getOperand(0);
16892+
SDValue B = FalseVal.getOperand(1);
16893+
// Add is commutative, so check both orders
16894+
return ((TrueVal.getOperand(0) == A && TrueVal.getOperand(1) == B) ||
16895+
(TrueVal.getOperand(1) == A && TrueVal.getOperand(0) == B));
16896+
}
16897+
16898+
/// Convert vselect CC, (add a, b), (sub a, b) to add a, (vselect CC, -b, b).
16899+
/// This allows us match a vadd.vv fed by a masked vrsub, which reduces
16900+
/// register pressure over the add followed by masked vsub sequence.
16901+
static SDValue performVSELECTCombine(SDNode *N, SelectionDAG &DAG) {
16902+
SDLoc DL(N);
16903+
EVT VT = N->getValueType(0);
16904+
SDValue CC = N->getOperand(0);
16905+
SDValue TrueVal = N->getOperand(1);
16906+
SDValue FalseVal = N->getOperand(2);
16907+
16908+
bool SwapCC;
16909+
if (!matchSelectAddSub(TrueVal, FalseVal, SwapCC))
16910+
return SDValue();
16911+
16912+
SDValue Sub = SwapCC ? TrueVal : FalseVal;
16913+
SDValue A = Sub.getOperand(0);
16914+
SDValue B = Sub.getOperand(1);
16915+
16916+
// Arrange the select such that we can match a masked
16917+
// vrsub.vi to perform the conditional negate
16918+
SDValue NegB = DAG.getNegative(B, DL, VT);
16919+
if (!SwapCC)
16920+
CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
16921+
SDValue NewB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
16922+
return DAG.getNode(ISD::ADD, DL, VT, A, NewB);
16923+
}
16924+
1687716925
static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG,
1687816926
const RISCVSubtarget &Subtarget) {
1687916927
if (SDValue Folded = foldSelectOfCTTZOrCTLZ(N, DAG))
@@ -17153,20 +17201,48 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
1715317201
return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
1715417202
}
1715517203

17156-
/// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
17157-
/// during the combine phase before type legalization, and relies on
17158-
/// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
17159-
/// for the source mask.
1716017204
static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
1716117205
const RISCVSubtarget &Subtarget,
1716217206
const RISCVTargetLowering &TLI) {
1716317207
SDLoc DL(N);
1716417208
EVT VT = N->getValueType(0);
1716517209
const unsigned ElementSize = VT.getScalarSizeInBits();
17210+
const unsigned NumElts = VT.getVectorNumElements();
1716617211
SDValue V1 = N->getOperand(0);
1716717212
SDValue V2 = N->getOperand(1);
1716817213
ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
17214+
MVT XLenVT = Subtarget.getXLenVT();
17215+
17216+
// Recognized a disguised select of add/sub.
17217+
bool SwapCC;
17218+
if (ShuffleVectorInst::isSelectMask(Mask, NumElts) &&
17219+
matchSelectAddSub(V1, V2, SwapCC)) {
17220+
SDValue Sub = SwapCC ? V1 : V2;
17221+
SDValue A = Sub.getOperand(0);
17222+
SDValue B = Sub.getOperand(1);
17223+
17224+
SmallVector<SDValue> MaskVals;
17225+
for (int MaskIndex : Mask) {
17226+
bool SelectMaskVal = (MaskIndex < (int)NumElts);
17227+
MaskVals.push_back(DAG.getConstant(SelectMaskVal, DL, XLenVT));
17228+
}
17229+
assert(MaskVals.size() == NumElts && "Unexpected select-like shuffle");
17230+
EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
17231+
SDValue CC = DAG.getBuildVector(MaskVT, DL, MaskVals);
1716917232

17233+
// Arrange the select such that we can match a masked
17234+
// vrsub.vi to perform the conditional negate
17235+
SDValue NegB = DAG.getNegative(B, DL, VT);
17236+
if (!SwapCC)
17237+
CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
17238+
SDValue NewB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
17239+
return DAG.getNode(ISD::ADD, DL, VT, A, NewB);
17240+
}
17241+
17242+
// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
17243+
// during the combine phase before type legalization, and relies on
17244+
// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
17245+
// for the source mask.
1717017246
if (TLI.isTypeLegal(VT) || ElementSize <= Subtarget.getELen() ||
1717117247
!isPowerOf2_64(ElementSize) || VT.getVectorNumElements() % 2 != 0 ||
1717217248
VT.isFloatingPoint() || TLI.isShuffleMaskLegal(Mask, VT))
@@ -17183,7 +17259,6 @@ static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
1718317259
return DAG.getBitcast(VT, Res);
1718417260
}
1718517261

17186-
1718717262
static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
1718817263
const RISCVSubtarget &Subtarget) {
1718917264

@@ -17857,6 +17932,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1785717932
return performTRUNCATECombine(N, DAG, Subtarget);
1785817933
case ISD::SELECT:
1785917934
return performSELECTCombine(N, DAG, Subtarget);
17935+
case ISD::VSELECT:
17936+
return performVSELECTCombine(N, DAG);
1786017937
case RISCVISD::CZERO_EQZ:
1786117938
case RISCVISD::CZERO_NEZ: {
1786217939
SDValue Val = N->getOperand(0);

0 commit comments

Comments
 (0)