Skip to content

Commit f0dd9c3

Browse files
committed
[WIP][X86] combineX86ShufflesRecursively - attempt to combine shuffles with larger types from EXTRACT_SUBVECTOR nodes
This replaces the rather limited combineX86ShuffleChainWithExtract function with handling for EXTRACT_SUBVECTOR nodes as we recurse down the shuffle chain, widening the shuffle mask to accommodate the larger value type. This will mainly help AVX2/AVX512 cases with cross-lane shuffles, but it also helps collapse some cases where the same subvector has gotten reused in multiple lanes. Exposed missing DemandedElts handling inside ISD::TRUNCATE nodes for ComputeNumSignBits.
1 parent 124e547 commit f0dd9c3

File tree

96 files changed

+49923
-48800
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+49923
-48800
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5108,7 +5108,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
51085108
case ISD::TRUNCATE: {
51095109
// Check if the sign bits of source go down as far as the truncated value.
51105110
unsigned NumSrcBits = Op.getOperand(0).getScalarValueSizeInBits();
5111-
unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
5111+
unsigned NumSrcSignBits =
5112+
ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
51125113
if (NumSrcSignBits > (NumSrcBits - VTBits))
51135114
return NumSrcSignBits - (NumSrcBits - VTBits);
51145115
break;

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 50 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -39737,13 +39737,6 @@ static bool matchBinaryPermuteShuffle(
3973739737
return false;
3973839738
}
3973939739

39740-
static SDValue combineX86ShuffleChainWithExtract(
39741-
ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39742-
ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39743-
bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39744-
bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39745-
const X86Subtarget &Subtarget);
39746-
3974739740
/// Combine an arbitrary chain of shuffles into a single instruction if
3974839741
/// possible.
3974939742
///
@@ -40288,14 +40281,6 @@ static SDValue combineX86ShuffleChain(
4028840281
return DAG.getBitcast(RootVT, Res);
4028940282
}
4029040283

40291-
// If that failed and either input is extracted then try to combine as a
40292-
// shuffle with the larger type.
40293-
if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40294-
Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40295-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40296-
IsMaskedShuffle, DAG, DL, Subtarget))
40297-
return WideShuffle;
40298-
4029940284
// If we have a dual input lane-crossing shuffle then lower to VPERMV3,
4030040285
// (non-VLX will pad to 512-bit shuffles).
4030140286
if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40461,14 +40446,6 @@ static SDValue combineX86ShuffleChain(
4046140446
return DAG.getBitcast(RootVT, Res);
4046240447
}
4046340448

40464-
// If that failed and either input is extracted then try to combine as a
40465-
// shuffle with the larger type.
40466-
if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40467-
Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40468-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40469-
DAG, DL, Subtarget))
40470-
return WideShuffle;
40471-
4047240449
// If we have a dual input shuffle then lower to VPERMV3,
4047340450
// (non-VLX will pad to 512-bit shuffles)
4047440451
if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40494,148 +40471,6 @@ static SDValue combineX86ShuffleChain(
4049440471
return SDValue();
4049540472
}
4049640473

40497-
// Combine an arbitrary chain of shuffles + extract_subvectors into a single
40498-
// instruction if possible.
40499-
//
40500-
// Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40501-
// type size to attempt to combine:
40502-
// shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40503-
// -->
40504-
// extract_subvector(shuffle(x,y,m2),0)
40505-
static SDValue combineX86ShuffleChainWithExtract(
40506-
ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40507-
ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40508-
bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40509-
bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40510-
const X86Subtarget &Subtarget) {
40511-
unsigned NumMaskElts = BaseMask.size();
40512-
unsigned NumInputs = Inputs.size();
40513-
if (NumInputs == 0)
40514-
return SDValue();
40515-
40516-
unsigned RootSizeInBits = RootVT.getSizeInBits();
40517-
unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40518-
assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40519-
40520-
// Peek through subvectors to find widest legal vector.
40521-
// TODO: Handle ISD::TRUNCATE
40522-
unsigned WideSizeInBits = RootSizeInBits;
40523-
for (SDValue Input : Inputs) {
40524-
Input = peekThroughBitcasts(Input);
40525-
while (1) {
40526-
if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40527-
Input = peekThroughBitcasts(Input.getOperand(0));
40528-
continue;
40529-
}
40530-
if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40531-
Input.getOperand(0).isUndef()) {
40532-
Input = peekThroughBitcasts(Input.getOperand(1));
40533-
continue;
40534-
}
40535-
break;
40536-
}
40537-
if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40538-
WideSizeInBits < Input.getValueSizeInBits())
40539-
WideSizeInBits = Input.getValueSizeInBits();
40540-
}
40541-
40542-
// Bail if we fail to find a source larger than the existing root.
40543-
if (WideSizeInBits <= RootSizeInBits ||
40544-
(WideSizeInBits % RootSizeInBits) != 0)
40545-
return SDValue();
40546-
40547-
// Create new mask for larger type.
40548-
SmallVector<int, 64> WideMask;
40549-
growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
40550-
40551-
// Attempt to peek through inputs and adjust mask when we extract from an
40552-
// upper subvector.
40553-
int AdjustedMasks = 0;
40554-
SmallVector<SDValue, 4> WideInputs(Inputs);
40555-
for (unsigned I = 0; I != NumInputs; ++I) {
40556-
SDValue &Input = WideInputs[I];
40557-
Input = peekThroughBitcasts(Input);
40558-
while (1) {
40559-
if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40560-
Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40561-
uint64_t Idx = Input.getConstantOperandVal(1);
40562-
if (Idx != 0) {
40563-
++AdjustedMasks;
40564-
unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40565-
Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40566-
40567-
int lo = I * WideMask.size();
40568-
int hi = (I + 1) * WideMask.size();
40569-
for (int &M : WideMask)
40570-
if (lo <= M && M < hi)
40571-
M += Idx;
40572-
}
40573-
Input = peekThroughBitcasts(Input.getOperand(0));
40574-
continue;
40575-
}
40576-
// TODO: Handle insertions into upper subvectors.
40577-
if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40578-
Input.getOperand(0).isUndef() &&
40579-
isNullConstant(Input.getOperand(2))) {
40580-
Input = peekThroughBitcasts(Input.getOperand(1));
40581-
continue;
40582-
}
40583-
break;
40584-
}
40585-
}
40586-
40587-
// Remove unused/repeated shuffle source ops.
40588-
resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40589-
assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40590-
40591-
// Bail if we're always extracting from the lowest subvectors,
40592-
// combineX86ShuffleChain should match this for the current width, or the
40593-
// shuffle still references too many inputs.
40594-
if (AdjustedMasks == 0 || WideInputs.size() > 2)
40595-
return SDValue();
40596-
40597-
// Minor canonicalization of the accumulated shuffle mask to make it easier
40598-
// to match below. All this does is detect masks with sequential pairs of
40599-
// elements, and shrink them to the half-width mask. It does this in a loop
40600-
// so it will reduce the size of the mask to the minimal width mask which
40601-
// performs an equivalent shuffle.
40602-
while (WideMask.size() > 1) {
40603-
SmallVector<int, 64> WidenedMask;
40604-
if (!canWidenShuffleElements(WideMask, WidenedMask))
40605-
break;
40606-
WideMask = std::move(WidenedMask);
40607-
}
40608-
40609-
// Canonicalization of binary shuffle masks to improve pattern matching by
40610-
// commuting the inputs.
40611-
if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40612-
ShuffleVectorSDNode::commuteMask(WideMask);
40613-
std::swap(WideInputs[0], WideInputs[1]);
40614-
}
40615-
40616-
// Increase depth for every upper subvector we've peeked through.
40617-
Depth += AdjustedMasks;
40618-
40619-
// Attempt to combine wider chain.
40620-
// TODO: Can we use a better Root?
40621-
SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40622-
WideInputs.back().getValueSizeInBits()
40623-
? WideInputs.front()
40624-
: WideInputs.back();
40625-
assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40626-
"WideRootSize mismatch");
40627-
40628-
if (SDValue WideShuffle = combineX86ShuffleChain(
40629-
WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40630-
Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40631-
IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40632-
WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40633-
return DAG.getBitcast(RootVT, WideShuffle);
40634-
}
40635-
40636-
return SDValue();
40637-
}
40638-
4063940474
// Canonicalize the combined shuffle mask chain with horizontal ops.
4064040475
// NOTE: This may update the Ops and Mask.
4064140476
static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -41048,6 +40883,54 @@ static SDValue combineX86ShufflesRecursively(
4104840883
OpMask.assign(NumElts, SM_SentinelUndef);
4104940884
std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
4105040885
OpZero = OpUndef = APInt::getZero(NumElts);
40886+
} else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40887+
TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
40888+
Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
40889+
(Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
40890+
// Extracting from vector larger than RootVT - scale the mask and attempt to
40891+
// fold the shuffle with the larger root type, then extract the lower
40892+
// elements.
40893+
unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
40894+
unsigned Scale = NewRootSizeInBits / RootSizeInBits;
40895+
MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
40896+
Scale * RootVT.getVectorNumElements());
40897+
SmallVector<int, 64> NewRootMask;
40898+
growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
40899+
// If we're using the lowest subvector, just replace it directly in the src
40900+
// ops/nodes.
40901+
SmallVector<SDValue, 16> NewSrcOps(SrcOps);
40902+
SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
40903+
if (isNullConstant(Op.getOperand(1))) {
40904+
NewSrcOps[SrcOpIndex] = Op.getOperand(0);
40905+
NewSrcNodes.push_back(Op.getNode());
40906+
}
40907+
// Don't increase the combine depth - we're effectively working on the same
40908+
// nodes, just with a wider type.
40909+
if (SDValue WideShuffle = combineX86ShufflesRecursively(
40910+
NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
40911+
Depth, MaxDepth, AllowVariableCrossLaneMask,
40912+
AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
40913+
return DAG.getBitcast(
40914+
RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
40915+
return SDValue();
40916+
} else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
40917+
Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40918+
Op.getOperand(1).getOperand(0).getValueSizeInBits() >
40919+
RootSizeInBits) {
40920+
// If we're inserting a subvector extracted from a vector larger than
40921+
// RootVT, then combine the insert_subvector as a shuffle, the
40922+
// extract_subvector will be folded in a later recursion.
40923+
SDValue BaseVec = Op.getOperand(0);
40924+
SDValue SubVec = Op.getOperand(1);
40925+
int InsertIdx = Op.getConstantOperandVal(2);
40926+
unsigned NumBaseElts = VT.getVectorNumElements();
40927+
unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
40928+
OpInputs.assign({BaseVec, SubVec});
40929+
OpMask.resize(NumBaseElts);
40930+
std::iota(OpMask.begin(), OpMask.end(), 0);
40931+
std::iota(OpMask.begin() + InsertIdx,
40932+
OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
40933+
OpZero = OpUndef = APInt::getZero(NumBaseElts);
4105140934
} else {
4105240935
return SDValue();
4105340936
}
@@ -41394,25 +41277,9 @@ static SDValue combineX86ShufflesRecursively(
4139441277
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
4139541278
IsMaskedShuffle, DAG, DL, Subtarget))
4139641279
return Shuffle;
41397-
41398-
// If all the operands come from the same larger vector, fallthrough and try
41399-
// to use combineX86ShuffleChainWithExtract.
41400-
SDValue LHS = peekThroughBitcasts(Ops.front());
41401-
SDValue RHS = peekThroughBitcasts(Ops.back());
41402-
if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
41403-
(RootSizeInBits / Mask.size()) != 64 ||
41404-
LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41405-
RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41406-
LHS.getOperand(0) != RHS.getOperand(0))
41407-
return SDValue();
4140841280
}
4140941281

41410-
// If that failed and any input is extracted then try to combine as a
41411-
// shuffle with the larger type.
41412-
return combineX86ShuffleChainWithExtract(
41413-
Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41414-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41415-
DAG, DL, Subtarget);
41282+
return SDValue();
4141641283
}
4141741284

4141841285
/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -44025,6 +43892,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
4402543892
case X86ISD::UNPCKL:
4402643893
case X86ISD::UNPCKH:
4402743894
case X86ISD::BLENDI:
43895+
case X86ISD::SHUFP:
4402843896
// Integer ops.
4402943897
case X86ISD::PACKSS:
4403043898
case X86ISD::PACKUS:

0 commit comments

Comments
 (0)