Skip to content

Commit 8037046

Browse files
preameslukel97
andauthored
[DAG] Add wrappers for insert_vector_elt and extract_vector_elt [nfc] (#139141)
As with the recently added subvector variants, provide the unsigned index operand to simplify a bunch of code. --------- Co-authored-by: Luke Lau <[email protected]>
1 parent f058333 commit 8037046

File tree

3 files changed

+40
-42
lines changed

3 files changed

+40
-42
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,22 @@ class SelectionDAG {
924924
/// Example: shuffle A, B, <0,5,2,7> -> shuffle B, A, <4,1,6,3>
925925
SDValue getCommutedVectorShuffle(const ShuffleVectorSDNode &SV);
926926

927+
/// Extract element at \p Idx from \p Vec. See EXTRACT_VECTOR_ELT
928+
/// description for result type handling.
929+
SDValue getExtractVectorElt(const SDLoc &DL, EVT VT, SDValue Vec,
930+
unsigned Idx) {
931+
return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Vec,
932+
getVectorIdxConstant(Idx, DL));
933+
}
934+
935+
/// Insert \p Elt into \p Vec at offset \p Idx. See INSERT_VECTOR_ELT
936+
/// description for element type handling.
937+
SDValue getInsertVectorElt(const SDLoc &DL, SDValue Vec, SDValue Elt,
938+
unsigned Idx) {
939+
return getNode(ISD::INSERT_VECTOR_ELT, DL, Vec.getValueType(), Vec, Elt,
940+
getVectorIdxConstant(Idx, DL));
941+
}
942+
927943
/// Insert \p SubVec at the \p Idx element of \p Vec.
928944
SDValue getInsertSubvector(const SDLoc &DL, SDValue Vec, SDValue SubVec,
929945
unsigned Idx) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3244,8 +3244,7 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
32443244
if (LegalSVT.bitsLT(SVT))
32453245
return SDValue();
32463246
}
3247-
return getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), LegalSVT, SrcVector,
3248-
getVectorIdxConstant(SplatIdx, SDLoc(V)));
3247+
return getExtractVectorElt(SDLoc(V), LegalSVT, SrcVector, SplatIdx);
32493248
}
32503249
return SDValue();
32513250
}
@@ -7557,11 +7556,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
75577556
// elements.
75587557
if (N2C && N1.getOpcode() == ISD::CONCAT_VECTORS &&
75597558
N1.getOperand(0).getValueType().isFixedLengthVector()) {
7560-
unsigned Factor =
7561-
N1.getOperand(0).getValueType().getVectorNumElements();
7562-
return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
7563-
N1.getOperand(N2C->getZExtValue() / Factor),
7564-
getVectorIdxConstant(N2C->getZExtValue() % Factor, DL));
7559+
unsigned Factor = N1.getOperand(0).getValueType().getVectorNumElements();
7560+
return getExtractVectorElt(DL, VT,
7561+
N1.getOperand(N2C->getZExtValue() / Factor),
7562+
N2C->getZExtValue() % Factor);
75657563
}
75667564

75677565
// EXTRACT_VECTOR_ELT of BUILD_VECTOR or SPLAT_VECTOR is often formed while
@@ -8624,8 +8622,7 @@ static SDValue getMemsetStores(SelectionDAG &DAG, const SDLoc &dl,
86248622
// Target which can combine store(extractelement VectorTy, Idx) can get
86258623
// the smaller value for free.
86268624
SDValue TailValue = DAG.getNode(ISD::BITCAST, dl, SVT, MemSetValue);
8627-
Value = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, TailValue,
8628-
DAG.getVectorIdxConstant(Index, dl));
8625+
Value = DAG.getExtractVectorElt(dl, VT, TailValue, Index);
86298626
} else
86308627
Value = getMemsetValue(Src, VT, DAG, dl);
86318628
}
@@ -12775,8 +12772,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
1277512772

1277612773
// A vector operand; extract a single element.
1277712774
EVT OperandEltVT = OperandVT.getVectorElementType();
12778-
Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT,
12779-
Operand, getVectorIdxConstant(i, dl));
12775+
Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i);
1278012776
}
1278112777

1278212778
SDValue EltOp = getNode(N->getOpcode(), dl, {EltVT, EltVT1}, Operands);
@@ -12810,8 +12806,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
1281012806
if (OperandVT.isVector()) {
1281112807
// A vector operand; extract a single element.
1281212808
EVT OperandEltVT = OperandVT.getVectorElementType();
12813-
Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT,
12814-
Operand, getVectorIdxConstant(i, dl));
12809+
Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i);
1281512810
} else {
1281612811
// A scalar operand; just use it as is.
1281712812
Operands[j] = Operand;
@@ -13090,8 +13085,7 @@ void SelectionDAG::ExtractVectorElements(SDValue Op,
1309013085
EltVT = VT.getVectorElementType();
1309113086
SDLoc SL(Op);
1309213087
for (unsigned i = Start, e = Start + Count; i != e; ++i) {
13093-
Args.push_back(getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Op,
13094-
getVectorIdxConstant(i, SL)));
13088+
Args.push_back(getExtractVectorElt(SL, EltVT, Op, i));
1309513089
}
1309613090
}
1309713091

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3805,8 +3805,7 @@ static SDValue lowerBuildVectorViaDominantValues(SDValue Op, SelectionDAG &DAG,
38053805
if (V.isUndef() || !Processed.insert(V).second)
38063806
continue;
38073807
if (ValueCounts[V] == 1) {
3808-
Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V,
3809-
DAG.getVectorIdxConstant(OpIdx.index(), DL));
3808+
Vec = DAG.getInsertVectorElt(DL, Vec, V, OpIdx.index());
38103809
} else {
38113810
// Blend in all instances of this value using a VSELECT, using a
38123811
// mask where each bit signals whether that element is the one
@@ -3963,10 +3962,9 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
39633962
if (ViaIntVT == MVT::i32)
39643963
SplatValue = SignExtend64<32>(SplatValue);
39653964

3966-
SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ViaVecVT,
3967-
DAG.getUNDEF(ViaVecVT),
3968-
DAG.getSignedConstant(SplatValue, DL, XLenVT),
3969-
DAG.getVectorIdxConstant(0, DL));
3965+
SDValue Vec = DAG.getInsertVectorElt(
3966+
DL, DAG.getUNDEF(ViaVecVT),
3967+
DAG.getSignedConstant(SplatValue, DL, XLenVT), 0);
39703968
if (ViaVecLen != 1)
39713969
Vec = DAG.getExtractSubvector(DL, MVT::getVectorVT(ViaIntVT, 1), Vec, 0);
39723970
return DAG.getBitcast(VT, Vec);
@@ -7180,9 +7178,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
71807178
EVT BVT = EVT::getVectorVT(*DAG.getContext(), Op0VT, 1);
71817179
if (!isTypeLegal(BVT))
71827180
return SDValue();
7183-
return DAG.getBitcast(VT, DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, BVT,
7184-
DAG.getUNDEF(BVT), Op0,
7185-
DAG.getVectorIdxConstant(0, DL)));
7181+
return DAG.getBitcast(
7182+
VT, DAG.getInsertVectorElt(DL, DAG.getUNDEF(BVT), Op0, 0));
71867183
}
71877184
return SDValue();
71887185
}
@@ -7194,8 +7191,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
71947191
if (!isTypeLegal(BVT))
71957192
return SDValue();
71967193
SDValue BVec = DAG.getBitcast(BVT, Op0);
7197-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
7198-
DAG.getVectorIdxConstant(0, DL));
7194+
return DAG.getExtractVectorElt(DL, VT, BVec, 0);
71997195
}
72007196
return SDValue();
72017197
}
@@ -9916,8 +9912,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
99169912

99179913
if (!EltVT.isInteger()) {
99189914
// Floating-point extracts are handled in TableGen.
9919-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec,
9920-
DAG.getVectorIdxConstant(0, DL));
9915+
return DAG.getExtractVectorElt(DL, EltVT, Vec, 0);
99219916
}
99229917

99239918
SDValue Elt0 = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
@@ -10321,8 +10316,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
1032110316
return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
1032210317
}
1032310318
case Intrinsic::riscv_vfmv_f_s:
10324-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
10325-
Op.getOperand(1), DAG.getVectorIdxConstant(0, DL));
10319+
return DAG.getExtractVectorElt(DL, Op.getValueType(), Op.getOperand(1), 0);
1032610320
case Intrinsic::riscv_vmv_v_x:
1032710321
return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
1032810322
Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
@@ -10856,8 +10850,7 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
1085610850
SDValue Policy = DAG.getTargetConstant(RISCVVType::TAIL_AGNOSTIC, DL, XLenVT);
1085710851
SDValue Ops[] = {PassThru, Vec, InitialValue, Mask, VL, Policy};
1085810852
SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, Ops);
10859-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
10860-
DAG.getVectorIdxConstant(0, DL));
10853+
return DAG.getExtractVectorElt(DL, ResVT, Reduction, 0);
1086110854
}
1086210855

1086310856
SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
@@ -10902,8 +10895,7 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
1090210895
case ISD::UMIN:
1090310896
case ISD::SMAX:
1090410897
case ISD::SMIN:
10905-
StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
10906-
DAG.getVectorIdxConstant(0, DL));
10898+
StartV = DAG.getExtractVectorElt(DL, VecEltVT, Vec, 0);
1090710899
}
1090810900
return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec,
1090910901
Mask, VL, DL, DAG, Subtarget);
@@ -10934,9 +10926,7 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT,
1093410926
case ISD::VECREDUCE_FMAXIMUM:
1093510927
case ISD::VECREDUCE_FMIN:
1093610928
case ISD::VECREDUCE_FMAX: {
10937-
SDValue Front =
10938-
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0),
10939-
DAG.getVectorIdxConstant(0, DL));
10929+
SDValue Front = DAG.getExtractVectorElt(DL, EltVT, Op.getOperand(0), 0);
1094010930
unsigned RVVOpc =
1094110931
(Opcode == ISD::VECREDUCE_FMIN || Opcode == ISD::VECREDUCE_FMINIMUM)
1094210932
? RISCVISD::VECREDUCE_FMIN_VL
@@ -14055,8 +14045,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
1405514045
EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1);
1405614046
if (isTypeLegal(BVT)) {
1405714047
SDValue BVec = DAG.getBitcast(BVT, Op0);
14058-
Results.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
14059-
DAG.getVectorIdxConstant(0, DL)));
14048+
Results.push_back(DAG.getExtractVectorElt(DL, VT, BVec, 0));
1406014049
}
1406114050
}
1406214051
break;
@@ -18202,12 +18191,11 @@ static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
1820218191
if (ConcatVT.getVectorElementType() != InVal.getValueType())
1820318192
return SDValue();
1820418193
unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
18205-
SDValue NewIdx = DAG.getVectorIdxConstant(Elt % ConcatNumElts, DL);
18194+
unsigned NewIdx = Elt % ConcatNumElts;
1820618195

1820718196
unsigned ConcatOpIdx = Elt / ConcatNumElts;
1820818197
SDValue ConcatOp = InVec.getOperand(ConcatOpIdx);
18209-
ConcatOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ConcatVT,
18210-
ConcatOp, InVal, NewIdx);
18198+
ConcatOp = DAG.getInsertVectorElt(DL, ConcatOp, InVal, NewIdx);
1821118199

1821218200
SmallVector<SDValue> ConcatOps(InVec->ops());
1821318201
ConcatOps[ConcatOpIdx] = ConcatOp;

0 commit comments

Comments
 (0)