Skip to content

Commit 3139775

Browse files
committed
[WIP][X86] combineX86ShufflesRecursively - attempt to combine shuffles with larger types from EXTRACT_SUBVECTOR nodes
This replaces the rather limited combineX86ShuffleChainWithExtract function with handling for EXTRACT_SUBVECTOR nodes as we recurse down the shuffle chain, widening the shuffle mask to accommodate the larger value type. This will mainly help AVX2/AVX512 cases with cross-lane shuffles, but it also helps collapse some cases where the same subvector has gotten reused in multiple lanes. It also exposed missing DemandedElts handling inside ISD::TRUNCATE nodes in ComputeNumSignBits.
1 parent ec936b3 commit 3139775

File tree

96 files changed

+51032
-49917
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+51032
-49917
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5109,7 +5109,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
51095109
case ISD::TRUNCATE: {
51105110
// Check if the sign bits of source go down as far as the truncated value.
51115111
unsigned NumSrcBits = Op.getOperand(0).getScalarValueSizeInBits();
5112-
unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
5112+
unsigned NumSrcSignBits =
5113+
ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
51135114
if (NumSrcSignBits > (NumSrcBits - VTBits))
51145115
return NumSrcSignBits - (NumSrcBits - VTBits);
51155116
break;

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 50 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -39640,13 +39640,6 @@ static bool matchBinaryPermuteShuffle(
3964039640
return false;
3964139641
}
3964239642

39643-
static SDValue combineX86ShuffleChainWithExtract(
39644-
ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39645-
ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39646-
bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39647-
bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39648-
const X86Subtarget &Subtarget);
39649-
3965039643
/// Combine an arbitrary chain of shuffles into a single instruction if
3965139644
/// possible.
3965239645
///
@@ -40191,14 +40184,6 @@ static SDValue combineX86ShuffleChain(
4019140184
return DAG.getBitcast(RootVT, Res);
4019240185
}
4019340186

40194-
// If that failed and either input is extracted then try to combine as a
40195-
// shuffle with the larger type.
40196-
if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40197-
Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40198-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40199-
IsMaskedShuffle, DAG, DL, Subtarget))
40200-
return WideShuffle;
40201-
4020240187
// If we have a dual input lane-crossing shuffle then lower to VPERMV3,
4020340188
// (non-VLX will pad to 512-bit shuffles).
4020440189
if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40364,14 +40349,6 @@ static SDValue combineX86ShuffleChain(
4036440349
return DAG.getBitcast(RootVT, Res);
4036540350
}
4036640351

40367-
// If that failed and either input is extracted then try to combine as a
40368-
// shuffle with the larger type.
40369-
if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40370-
Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40371-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40372-
DAG, DL, Subtarget))
40373-
return WideShuffle;
40374-
4037540352
// If we have a dual input shuffle then lower to VPERMV3,
4037640353
// (non-VLX will pad to 512-bit shuffles)
4037740354
if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40397,148 +40374,6 @@ static SDValue combineX86ShuffleChain(
4039740374
return SDValue();
4039840375
}
4039940376

40400-
// Combine an arbitrary chain of shuffles + extract_subvectors into a single
40401-
// instruction if possible.
40402-
//
40403-
// Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40404-
// type size to attempt to combine:
40405-
// shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40406-
// -->
40407-
// extract_subvector(shuffle(x,y,m2),0)
40408-
static SDValue combineX86ShuffleChainWithExtract(
40409-
ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40410-
ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40411-
bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40412-
bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40413-
const X86Subtarget &Subtarget) {
40414-
unsigned NumMaskElts = BaseMask.size();
40415-
unsigned NumInputs = Inputs.size();
40416-
if (NumInputs == 0)
40417-
return SDValue();
40418-
40419-
unsigned RootSizeInBits = RootVT.getSizeInBits();
40420-
unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40421-
assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40422-
40423-
// Peek through subvectors to find widest legal vector.
40424-
// TODO: Handle ISD::TRUNCATE
40425-
unsigned WideSizeInBits = RootSizeInBits;
40426-
for (SDValue Input : Inputs) {
40427-
Input = peekThroughBitcasts(Input);
40428-
while (1) {
40429-
if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40430-
Input = peekThroughBitcasts(Input.getOperand(0));
40431-
continue;
40432-
}
40433-
if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40434-
Input.getOperand(0).isUndef()) {
40435-
Input = peekThroughBitcasts(Input.getOperand(1));
40436-
continue;
40437-
}
40438-
break;
40439-
}
40440-
if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40441-
WideSizeInBits < Input.getValueSizeInBits())
40442-
WideSizeInBits = Input.getValueSizeInBits();
40443-
}
40444-
40445-
// Bail if we fail to find a source larger than the existing root.
40446-
if (WideSizeInBits <= RootSizeInBits ||
40447-
(WideSizeInBits % RootSizeInBits) != 0)
40448-
return SDValue();
40449-
40450-
// Create new mask for larger type.
40451-
SmallVector<int, 64> WideMask;
40452-
growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
40453-
40454-
// Attempt to peek through inputs and adjust mask when we extract from an
40455-
// upper subvector.
40456-
int AdjustedMasks = 0;
40457-
SmallVector<SDValue, 4> WideInputs(Inputs);
40458-
for (unsigned I = 0; I != NumInputs; ++I) {
40459-
SDValue &Input = WideInputs[I];
40460-
Input = peekThroughBitcasts(Input);
40461-
while (1) {
40462-
if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40463-
Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40464-
uint64_t Idx = Input.getConstantOperandVal(1);
40465-
if (Idx != 0) {
40466-
++AdjustedMasks;
40467-
unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40468-
Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40469-
40470-
int lo = I * WideMask.size();
40471-
int hi = (I + 1) * WideMask.size();
40472-
for (int &M : WideMask)
40473-
if (lo <= M && M < hi)
40474-
M += Idx;
40475-
}
40476-
Input = peekThroughBitcasts(Input.getOperand(0));
40477-
continue;
40478-
}
40479-
// TODO: Handle insertions into upper subvectors.
40480-
if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40481-
Input.getOperand(0).isUndef() &&
40482-
isNullConstant(Input.getOperand(2))) {
40483-
Input = peekThroughBitcasts(Input.getOperand(1));
40484-
continue;
40485-
}
40486-
break;
40487-
}
40488-
}
40489-
40490-
// Remove unused/repeated shuffle source ops.
40491-
resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40492-
assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40493-
40494-
// Bail if we're always extracting from the lowest subvectors,
40495-
// combineX86ShuffleChain should match this for the current width, or the
40496-
// shuffle still references too many inputs.
40497-
if (AdjustedMasks == 0 || WideInputs.size() > 2)
40498-
return SDValue();
40499-
40500-
// Minor canonicalization of the accumulated shuffle mask to make it easier
40501-
// to match below. All this does is detect masks with sequential pairs of
40502-
// elements, and shrink them to the half-width mask. It does this in a loop
40503-
// so it will reduce the size of the mask to the minimal width mask which
40504-
// performs an equivalent shuffle.
40505-
while (WideMask.size() > 1) {
40506-
SmallVector<int, 64> WidenedMask;
40507-
if (!canWidenShuffleElements(WideMask, WidenedMask))
40508-
break;
40509-
WideMask = std::move(WidenedMask);
40510-
}
40511-
40512-
// Canonicalization of binary shuffle masks to improve pattern matching by
40513-
// commuting the inputs.
40514-
if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40515-
ShuffleVectorSDNode::commuteMask(WideMask);
40516-
std::swap(WideInputs[0], WideInputs[1]);
40517-
}
40518-
40519-
// Increase depth for every upper subvector we've peeked through.
40520-
Depth += AdjustedMasks;
40521-
40522-
// Attempt to combine wider chain.
40523-
// TODO: Can we use a better Root?
40524-
SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40525-
WideInputs.back().getValueSizeInBits()
40526-
? WideInputs.front()
40527-
: WideInputs.back();
40528-
assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40529-
"WideRootSize mismatch");
40530-
40531-
if (SDValue WideShuffle = combineX86ShuffleChain(
40532-
WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40533-
Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40534-
IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40535-
WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40536-
return DAG.getBitcast(RootVT, WideShuffle);
40537-
}
40538-
40539-
return SDValue();
40540-
}
40541-
4054240377
// Canonicalize the combined shuffle mask chain with horizontal ops.
4054340378
// NOTE: This may update the Ops and Mask.
4054440379
static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -40951,6 +40786,54 @@ static SDValue combineX86ShufflesRecursively(
4095140786
OpMask.assign(NumElts, SM_SentinelUndef);
4095240787
std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
4095340788
OpZero = OpUndef = APInt::getZero(NumElts);
40789+
} else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40790+
TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
40791+
Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
40792+
(Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
40793+
// Extracting from vector larger than RootVT - scale the mask and attempt to
40794+
// fold the shuffle with the larger root type, then extract the lower
40795+
// elements.
40796+
unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
40797+
unsigned Scale = NewRootSizeInBits / RootSizeInBits;
40798+
MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
40799+
Scale * RootVT.getVectorNumElements());
40800+
SmallVector<int, 64> NewRootMask;
40801+
growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
40802+
// If we're using the lowest subvector, just replace it directly in the src
40803+
// ops/nodes.
40804+
SmallVector<SDValue, 16> NewSrcOps(SrcOps);
40805+
SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
40806+
if (isNullConstant(Op.getOperand(1))) {
40807+
NewSrcOps[SrcOpIndex] = Op.getOperand(0);
40808+
NewSrcNodes.push_back(Op.getNode());
40809+
}
40810+
// Don't increase the combine depth - we're effectively working on the same
40811+
// nodes, just with a wider type.
40812+
if (SDValue WideShuffle = combineX86ShufflesRecursively(
40813+
NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
40814+
Depth, MaxDepth, AllowVariableCrossLaneMask,
40815+
AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
40816+
return DAG.getBitcast(
40817+
RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
40818+
return SDValue();
40819+
} else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
40820+
Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40821+
Op.getOperand(1).getOperand(0).getValueSizeInBits() >
40822+
RootSizeInBits) {
40823+
// If we're inserting an subvector extracted from a vector larger than
40824+
// RootVT, then combine the insert_subvector as a shuffle, the
40825+
// extract_subvector will be folded in a later recursion.
40826+
SDValue BaseVec = Op.getOperand(0);
40827+
SDValue SubVec = Op.getOperand(1);
40828+
int InsertIdx = Op.getConstantOperandVal(2);
40829+
unsigned NumBaseElts = VT.getVectorNumElements();
40830+
unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
40831+
OpInputs.assign({BaseVec, SubVec});
40832+
OpMask.resize(NumBaseElts);
40833+
std::iota(OpMask.begin(), OpMask.end(), 0);
40834+
std::iota(OpMask.begin() + InsertIdx,
40835+
OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
40836+
OpZero = OpUndef = APInt::getZero(NumBaseElts);
4095440837
} else {
4095540838
return SDValue();
4095640839
}
@@ -41297,25 +41180,9 @@ static SDValue combineX86ShufflesRecursively(
4129741180
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
4129841181
IsMaskedShuffle, DAG, DL, Subtarget))
4129941182
return Shuffle;
41300-
41301-
// If all the operands come from the same larger vector, fallthrough and try
41302-
// to use combineX86ShuffleChainWithExtract.
41303-
SDValue LHS = peekThroughBitcasts(Ops.front());
41304-
SDValue RHS = peekThroughBitcasts(Ops.back());
41305-
if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
41306-
(RootSizeInBits / Mask.size()) != 64 ||
41307-
LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41308-
RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41309-
LHS.getOperand(0) != RHS.getOperand(0))
41310-
return SDValue();
4131141183
}
4131241184

41313-
// If that failed and any input is extracted then try to combine as a
41314-
// shuffle with the larger type.
41315-
return combineX86ShuffleChainWithExtract(
41316-
Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41317-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41318-
DAG, DL, Subtarget);
41185+
return SDValue();
4131941186
}
4132041187

4132141188
/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -43928,6 +43795,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
4392843795
case X86ISD::UNPCKL:
4392943796
case X86ISD::UNPCKH:
4393043797
case X86ISD::BLENDI:
43798+
case X86ISD::SHUFP:
4393143799
// Integer ops.
4393243800
case X86ISD::PACKSS:
4393343801
case X86ISD::PACKUS:

0 commit comments

Comments
 (0)