Skip to content

[DAG] visitEXTRACT_SUBVECTOR - change fold helper methods to take operands instead of EXTRACT_SUBVECTOR node. NFC. #138279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 2 commits on May 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 41 additions & 61 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25100,26 +25100,26 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {

// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
// if the subvector can be sourced for free.
static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
V.getOperand(1).getValueType() == SubVT &&
V.getConstantOperandAPInt(2) == Index) {
return V.getOperand(1);
}
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
if (V.getOpcode() == ISD::CONCAT_VECTORS &&
V.getOperand(0).getValueType() == SubVT &&
(IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
(Index % SubVT.getVectorMinNumElements()) == 0) {
uint64_t SubIdx = Index / SubVT.getVectorMinNumElements();
return V.getOperand(SubIdx);
}
return SDValue();
}

static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp,
unsigned Index, const SDLoc &DL,
SelectionDAG &DAG,
bool LegalOperations) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDValue BinOp = Extract->getOperand(0);
unsigned BinOpcode = BinOp.getOpcode();
if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
return SDValue();
Expand All @@ -25128,9 +25128,6 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
return SDValue();

SDValue Index = Extract->getOperand(1);
EVT SubVT = Extract->getValueType(0);
if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
return SDValue();

Expand All @@ -25146,29 +25143,25 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
// We are inserting both operands of the wide binop only to extract back
// to the narrow vector size. Eliminate all of the insert/extract:
// ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
BinOp->getFlags());
return DAG.getNode(BinOpcode, DL, SubVT, Sub0, Sub1, BinOp->getFlags());
}

/// If we are extracting a subvector produced by a wide binary operator try
/// to use a narrow binary operator and/or avoid concatenation and extraction.
static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
const SDLoc &DL, SelectionDAG &DAG,
bool LegalOperations) {
// TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
// some of these bailouts with other transforms.

if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
if (SDValue V = narrowInsertExtractVectorBinOp(VT, Src, Index, DL, DAG,
LegalOperations))
return V;

// The extract index must be a constant, so we can map it to a concat operand.
auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
if (!ExtractIndexC)
return SDValue();

// We are looking for an optionally bitcasted wide vector binary operator
// feeding an extract subvector.
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
SDValue BinOp = peekThroughBitcasts(Src);
unsigned BOpcode = BinOp.getOpcode();
if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
return SDValue();
Expand All @@ -25190,9 +25183,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
if (!WideBVT.isFixedLengthVector())
return SDValue();

EVT VT = Extract->getValueType(0);
unsigned ExtractIndex = ExtractIndexC->getZExtValue();
assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
assert((Index % VT.getVectorNumElements()) == 0 &&
"Extract index is not a multiple of the vector length.");

// Bail out if this is not a proper multiple width extraction.
Expand All @@ -25219,12 +25210,11 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
// for concat ops. The narrow binop alone makes this transform profitable.
// We can't just reuse the original extract index operand because we may have
// bitcasted.
unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
unsigned ConcatOpNum = Index / VT.getVectorNumElements();
unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
BinOp.hasOneUse() && Src->hasOneUse()) {
// extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
SDLoc DL(Extract);
SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
BinOp.getOperand(0), NewExtIndex);
Expand Down Expand Up @@ -25264,7 +25254,6 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
// extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
// extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
// extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
SDLoc DL(Extract);
SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
: DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
Expand All @@ -25284,24 +25273,24 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
/// If we are extracting a subvector from a wide vector load, convert to a
/// narrow load to eliminate the extraction:
/// (extract_subvector (load wide vector)) --> (load narrow vector)
static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
SelectionDAG &DAG) {
static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index,
const SDLoc &DL, SelectionDAG &DAG) {
// TODO: Add support for big-endian. The offset calculation must be adjusted.
if (DAG.getDataLayout().isBigEndian())
return SDValue();

auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
auto *Ld = dyn_cast<LoadSDNode>(Src);
if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple())
return SDValue();

// Allow targets to opt-out.
EVT VT = Extract->getValueType(0);

// We can only create byte sized loads.
if (!VT.isByteSized())
return SDValue();

unsigned Index = Extract->getConstantOperandVal(1);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, VT))
return SDValue();

unsigned NumElts = VT.getVectorMinNumElements();
// A fixed length vector being extracted from a scalable vector
// may not be any *smaller* than the scalable one.
Expand All @@ -25319,7 +25308,6 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
if (Offset.isFixed())
ByteOffset = Offset.getFixedValue();

const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT, ByteOffset))
return SDValue();

Expand Down Expand Up @@ -25350,23 +25338,18 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
/// iff it is legal and profitable to do so. Notably, the trimmed mask
/// (containing only the elements that are extracted)
/// must reference at most two subvectors.
static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
unsigned Index,
const SDLoc &DL,
SelectionDAG &DAG,
const TargetLowering &TLI,
bool LegalOperations) {
assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
"Must only be called on EXTRACT_SUBVECTOR's");

SDValue N0 = N->getOperand(0);

// Only deal with non-scalable vectors.
EVT NarrowVT = N->getValueType(0);
EVT WideVT = N0.getValueType();
EVT WideVT = Src.getValueType();
if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
return SDValue();

// The operand must be a shufflevector.
auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Src);
if (!WideShuffleVector)
return SDValue();

Expand All @@ -25375,13 +25358,13 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
return SDValue();

// And the narrow shufflevector that we'll form must be legal.
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (LegalOperations &&
!TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
return SDValue();

uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
int NumEltsExtracted = NarrowVT.getVectorNumElements();
assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
assert((Index % NumEltsExtracted) == 0 &&
"Extract index is not a multiple of the output vector length.");

int WideNumElts = WideVT.getVectorNumElements();
Expand All @@ -25392,8 +25375,7 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
DemandedSubvectors;

// Try to decode the wide mask into narrow mask from at most two subvectors.
for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
NumEltsExtracted)) {
for (int M : WideShuffleVector->getMask().slice(Index, NumEltsExtracted)) {
assert((M >= -1) && (M < (2 * WideNumElts)) &&
"Out-of-bounds shuffle mask?");

Expand Down Expand Up @@ -25476,8 +25458,6 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
!TLI.isShuffleMaskLegal(NewMask, NarrowVT))
return SDValue();

SDLoc DL(N);

SmallVector<SDValue, 2> NewOps;
for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
&DemandedSubvector : DemandedSubvectors) {
Expand Down Expand Up @@ -25507,9 +25487,8 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
if (V.isUndef())
return DAG.getUNDEF(NVT);

if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DL, DAG))
return NarrowLoad;
if (SDValue NarrowLoad = narrowExtractedVectorLoad(NVT, V, ExtIdx, DL, DAG))
return NarrowLoad;

// Combine an extract of an extract into a single extract_subvector.
// ext (ext X, C), 0 --> ext X, C
Expand Down Expand Up @@ -25631,9 +25610,13 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
}
}

if (SDValue V =
foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
return V;
if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector(
NVT, V, ExtIdx, DL, DAG, LegalOperations))
return Shuffle;

if (SDValue NarrowBOp =
narrowExtractedVectorBinOp(NVT, V, ExtIdx, DL, DAG, LegalOperations))
return NarrowBOp;

V = peekThroughBitcasts(V);

Expand Down Expand Up @@ -25694,9 +25677,6 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
}
}

if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
return NarrowBOp;

// If only EXTRACT_SUBVECTOR nodes use the source vector we can
// simplify it based on the (valid) extractions.
if (!V.getValueType().isScalableVector() &&
Expand Down