Skip to content

Commit b5dbddd

Browse files
authored
[DAG] visitEXTRACT_SUBVECTOR - change fold helper methods to take operands instead of EXTRACT_SUBVECTOR node. NFC. (#138279)
Call with the individual subvector type, source vector and index operands instead of the original EXTRACT_SUBVECTOR node. Some of these folds still assumed that EXTRACT_SUBVECTOR/INSERT_SUBVECTOR nodes could have variable indices, despite us moving to all constant indices some time ago - all of that code has now been simplified. I've moved the narrowExtractedVectorBinOp call higher up, but it won't affect fold order - it didn't rely on the peekThroughBitcasts call, and worked on BinOps, not BUILD_VECTOR/INSERT_SUBVECTOR nodes. Prep work to make it easier for more of these folds to work through BITCAST nodes.
1 parent a635bbf commit b5dbddd

File tree

1 file changed

+41
-61
lines changed

1 file changed

+41
-61
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 41 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -25100,26 +25100,26 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
2510025100

2510125101
// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
2510225102
// if the subvector can be sourced for free.
25103-
static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
25103+
static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
2510425104
if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
25105-
V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
25105+
V.getOperand(1).getValueType() == SubVT &&
25106+
V.getConstantOperandAPInt(2) == Index) {
2510625107
return V.getOperand(1);
2510725108
}
25108-
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
25109-
if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
25109+
if (V.getOpcode() == ISD::CONCAT_VECTORS &&
2511025110
V.getOperand(0).getValueType() == SubVT &&
25111-
(IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
25112-
uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
25111+
(Index % SubVT.getVectorMinNumElements()) == 0) {
25112+
uint64_t SubIdx = Index / SubVT.getVectorMinNumElements();
2511325113
return V.getOperand(SubIdx);
2511425114
}
2511525115
return SDValue();
2511625116
}
2511725117

25118-
static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
25118+
static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp,
25119+
unsigned Index, const SDLoc &DL,
2511925120
SelectionDAG &DAG,
2512025121
bool LegalOperations) {
2512125122
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25122-
SDValue BinOp = Extract->getOperand(0);
2512325123
unsigned BinOpcode = BinOp.getOpcode();
2512425124
if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
2512525125
return SDValue();
@@ -25128,9 +25128,6 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
2512825128
SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
2512925129
if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
2513025130
return SDValue();
25131-
25132-
SDValue Index = Extract->getOperand(1);
25133-
EVT SubVT = Extract->getValueType(0);
2513425131
if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
2513525132
return SDValue();
2513625133

@@ -25146,29 +25143,25 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
2514625143
// We are inserting both operands of the wide binop only to extract back
2514725144
// to the narrow vector size. Eliminate all of the insert/extract:
2514825145
// ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
25149-
return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
25150-
BinOp->getFlags());
25146+
return DAG.getNode(BinOpcode, DL, SubVT, Sub0, Sub1, BinOp->getFlags());
2515125147
}
2515225148

2515325149
/// If we are extracting a subvector produced by a wide binary operator try
2515425150
/// to use a narrow binary operator and/or avoid concatenation and extraction.
25155-
static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
25151+
static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
25152+
const SDLoc &DL, SelectionDAG &DAG,
2515625153
bool LegalOperations) {
2515725154
// TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
2515825155
// some of these bailouts with other transforms.
2515925156

25160-
if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
25157+
if (SDValue V = narrowInsertExtractVectorBinOp(VT, Src, Index, DL, DAG,
25158+
LegalOperations))
2516125159
return V;
2516225160

25163-
// The extract index must be a constant, so we can map it to a concat operand.
25164-
auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
25165-
if (!ExtractIndexC)
25166-
return SDValue();
25167-
2516825161
// We are looking for an optionally bitcasted wide vector binary operator
2516925162
// feeding an extract subvector.
2517025163
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25171-
SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
25164+
SDValue BinOp = peekThroughBitcasts(Src);
2517225165
unsigned BOpcode = BinOp.getOpcode();
2517325166
if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
2517425167
return SDValue();
@@ -25190,9 +25183,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
2519025183
if (!WideBVT.isFixedLengthVector())
2519125184
return SDValue();
2519225185

25193-
EVT VT = Extract->getValueType(0);
25194-
unsigned ExtractIndex = ExtractIndexC->getZExtValue();
25195-
assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
25186+
assert((Index % VT.getVectorNumElements()) == 0 &&
2519625187
"Extract index is not a multiple of the vector length.");
2519725188

2519825189
// Bail out if this is not a proper multiple width extraction.
@@ -25219,12 +25210,11 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
2521925210
// for concat ops. The narrow binop alone makes this transform profitable.
2522025211
// We can't just reuse the original extract index operand because we may have
2522125212
// bitcasted.
25222-
unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
25213+
unsigned ConcatOpNum = Index / VT.getVectorNumElements();
2522325214
unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
2522425215
if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
25225-
BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
25216+
BinOp.hasOneUse() && Src->hasOneUse()) {
2522625217
// extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
25227-
SDLoc DL(Extract);
2522825218
SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
2522925219
SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
2523025220
BinOp.getOperand(0), NewExtIndex);
@@ -25264,7 +25254,6 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
2526425254
// extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
2526525255
// extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
2526625256
// extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
25267-
SDLoc DL(Extract);
2526825257
SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
2526925258
SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
2527025259
: DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
@@ -25284,24 +25273,24 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
2528425273
/// If we are extracting a subvector from a wide vector load, convert to a
2528525274
/// narrow load to eliminate the extraction:
2528625275
/// (extract_subvector (load wide vector)) --> (load narrow vector)
25287-
static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
25288-
SelectionDAG &DAG) {
25276+
static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index,
25277+
const SDLoc &DL, SelectionDAG &DAG) {
2528925278
// TODO: Add support for big-endian. The offset calculation must be adjusted.
2529025279
if (DAG.getDataLayout().isBigEndian())
2529125280
return SDValue();
2529225281

25293-
auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
25282+
auto *Ld = dyn_cast<LoadSDNode>(Src);
2529425283
if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple())
2529525284
return SDValue();
2529625285

25297-
// Allow targets to opt-out.
25298-
EVT VT = Extract->getValueType(0);
25299-
2530025286
// We can only create byte sized loads.
2530125287
if (!VT.isByteSized())
2530225288
return SDValue();
2530325289

25304-
unsigned Index = Extract->getConstantOperandVal(1);
25290+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25291+
if (!TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, VT))
25292+
return SDValue();
25293+
2530525294
unsigned NumElts = VT.getVectorMinNumElements();
2530625295
// A fixed length vector being extracted from a scalable vector
2530725296
// may not be any *smaller* than the scalable one.
@@ -25319,7 +25308,6 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
2531925308
if (Offset.isFixed())
2532025309
ByteOffset = Offset.getFixedValue();
2532125310

25322-
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2532325311
if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT, ByteOffset))
2532425312
return SDValue();
2532525313

@@ -25350,23 +25338,18 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
2535025338
/// iff it is legal and profitable to do so. Notably, the trimmed mask
2535125339
/// (containing only the elements that are extracted)
2535225340
/// must reference at most two subvectors.
25353-
static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
25341+
static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
25342+
unsigned Index,
25343+
const SDLoc &DL,
2535425344
SelectionDAG &DAG,
25355-
const TargetLowering &TLI,
2535625345
bool LegalOperations) {
25357-
assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
25358-
"Must only be called on EXTRACT_SUBVECTOR's");
25359-
25360-
SDValue N0 = N->getOperand(0);
25361-
2536225346
// Only deal with non-scalable vectors.
25363-
EVT NarrowVT = N->getValueType(0);
25364-
EVT WideVT = N0.getValueType();
25347+
EVT WideVT = Src.getValueType();
2536525348
if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
2536625349
return SDValue();
2536725350

2536825351
// The operand must be a shufflevector.
25369-
auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
25352+
auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Src);
2537025353
if (!WideShuffleVector)
2537125354
return SDValue();
2537225355

@@ -25375,13 +25358,13 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
2537525358
return SDValue();
2537625359

2537725360
// And the narrow shufflevector that we'll form must be legal.
25361+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2537825362
if (LegalOperations &&
2537925363
!TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
2538025364
return SDValue();
2538125365

25382-
uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
2538325366
int NumEltsExtracted = NarrowVT.getVectorNumElements();
25384-
assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
25367+
assert((Index % NumEltsExtracted) == 0 &&
2538525368
"Extract index is not a multiple of the output vector length.");
2538625369

2538725370
int WideNumElts = WideVT.getVectorNumElements();
@@ -25392,8 +25375,7 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
2539225375
DemandedSubvectors;
2539325376

2539425377
// Try to decode the wide mask into narrow mask from at most two subvectors.
25395-
for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
25396-
NumEltsExtracted)) {
25378+
for (int M : WideShuffleVector->getMask().slice(Index, NumEltsExtracted)) {
2539725379
assert((M >= -1) && (M < (2 * WideNumElts)) &&
2539825380
"Out-of-bounds shuffle mask?");
2539925381

@@ -25476,8 +25458,6 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
2547625458
!TLI.isShuffleMaskLegal(NewMask, NarrowVT))
2547725459
return SDValue();
2547825460

25479-
SDLoc DL(N);
25480-
2548125461
SmallVector<SDValue, 2> NewOps;
2548225462
for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
2548325463
&DemandedSubvector : DemandedSubvectors) {
@@ -25507,9 +25487,8 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
2550725487
if (V.isUndef())
2550825488
return DAG.getUNDEF(NVT);
2550925489

25510-
if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
25511-
if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DL, DAG))
25512-
return NarrowLoad;
25490+
if (SDValue NarrowLoad = narrowExtractedVectorLoad(NVT, V, ExtIdx, DL, DAG))
25491+
return NarrowLoad;
2551325492

2551425493
// Combine an extract of an extract into a single extract_subvector.
2551525494
// ext (ext X, C), 0 --> ext X, C
@@ -25631,9 +25610,13 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
2563125610
}
2563225611
}
2563325612

25634-
if (SDValue V =
25635-
foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
25636-
return V;
25613+
if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector(
25614+
NVT, V, ExtIdx, DL, DAG, LegalOperations))
25615+
return Shuffle;
25616+
25617+
if (SDValue NarrowBOp =
25618+
narrowExtractedVectorBinOp(NVT, V, ExtIdx, DL, DAG, LegalOperations))
25619+
return NarrowBOp;
2563725620

2563825621
V = peekThroughBitcasts(V);
2563925622

@@ -25694,9 +25677,6 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
2569425677
}
2569525678
}
2569625679

25697-
if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
25698-
return NarrowBOp;
25699-
2570025680
// If only EXTRACT_SUBVECTOR nodes use the source vector we can
2570125681
// simplify it based on the (valid) extractions.
2570225682
if (!V.getValueType().isScalableVector() &&

0 commit comments

Comments
 (0)