Skip to content

Commit 93bc1c2

Browse files
committed
[WIP][X86] combineX86ShufflesRecursively - attempt to combine shuffles with larger types from EXTRACT_SUBVECTOR nodes
This replaces the rather limited combineX86ShuffleChainWithExtract function with handling for EXTRACT_SUBVECTOR nodes as we recurse down the shuffle chain, widening the shuffle mask to accommodate the larger value type. This will mainly help AVX2/AVX512 cases with cross-lane shuffles, but it also helps collapse some cases where the same subvector has been reused in multiple lanes. This change also exposed missing DemandedElts handling for ISD::TRUNCATE nodes in ComputeNumSignBits, which is fixed here by forwarding DemandedElts to the recursive call.
1 parent 3dc9f2d commit 93bc1c2

File tree

95 files changed

+50859
-49771
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

95 files changed

+50859
-49771
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5109,7 +5109,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
51095109
case ISD::TRUNCATE: {
51105110
// Check if the sign bits of source go down as far as the truncated value.
51115111
unsigned NumSrcBits = Op.getOperand(0).getScalarValueSizeInBits();
5112-
unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
5112+
unsigned NumSrcSignBits =
5113+
ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
51135114
if (NumSrcSignBits > (NumSrcBits - VTBits))
51145115
return NumSrcSignBits - (NumSrcBits - VTBits);
51155116
break;

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 50 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -39902,13 +39902,6 @@ static bool matchBinaryPermuteShuffle(
3990239902
return false;
3990339903
}
3990439904

39905-
static SDValue combineX86ShuffleChainWithExtract(
39906-
ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39907-
ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39908-
bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39909-
bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39910-
const X86Subtarget &Subtarget);
39911-
3991239905
/// Combine an arbitrary chain of shuffles into a single instruction if
3991339906
/// possible.
3991439907
///
@@ -40453,14 +40446,6 @@ static SDValue combineX86ShuffleChain(
4045340446
return DAG.getBitcast(RootVT, Res);
4045440447
}
4045540448

40456-
// If that failed and either input is extracted then try to combine as a
40457-
// shuffle with the larger type.
40458-
if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40459-
Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40460-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40461-
IsMaskedShuffle, DAG, DL, Subtarget))
40462-
return WideShuffle;
40463-
4046440449
// If we have a dual input lane-crossing shuffle then lower to VPERMV3,
4046540450
// (non-VLX will pad to 512-bit shuffles).
4046640451
if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40626,14 +40611,6 @@ static SDValue combineX86ShuffleChain(
4062640611
return DAG.getBitcast(RootVT, Res);
4062740612
}
4062840613

40629-
// If that failed and either input is extracted then try to combine as a
40630-
// shuffle with the larger type.
40631-
if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40632-
Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40633-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40634-
DAG, DL, Subtarget))
40635-
return WideShuffle;
40636-
4063740614
// If we have a dual input shuffle then lower to VPERMV3,
4063840615
// (non-VLX will pad to 512-bit shuffles)
4063940616
if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40659,149 +40636,6 @@ static SDValue combineX86ShuffleChain(
4065940636
return SDValue();
4066040637
}
4066140638

40662-
// Combine an arbitrary chain of shuffles + extract_subvectors into a single
40663-
// instruction if possible.
40664-
//
40665-
// Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40666-
// type size to attempt to combine:
40667-
// shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40668-
// -->
40669-
// extract_subvector(shuffle(x,y,m2),0)
40670-
static SDValue combineX86ShuffleChainWithExtract(
40671-
ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40672-
ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40673-
bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40674-
bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40675-
const X86Subtarget &Subtarget) {
40676-
unsigned NumMaskElts = BaseMask.size();
40677-
unsigned NumInputs = Inputs.size();
40678-
if (NumInputs == 0)
40679-
return SDValue();
40680-
40681-
unsigned RootSizeInBits = RootVT.getSizeInBits();
40682-
unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40683-
assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40684-
40685-
// Peek through subvectors to find widest legal vector.
40686-
// TODO: Handle ISD::TRUNCATE
40687-
unsigned WideSizeInBits = RootSizeInBits;
40688-
for (SDValue Input : Inputs) {
40689-
Input = peekThroughBitcasts(Input);
40690-
while (1) {
40691-
if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40692-
Input = peekThroughBitcasts(Input.getOperand(0));
40693-
continue;
40694-
}
40695-
if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40696-
Input.getOperand(0).isUndef() &&
40697-
isNullConstant(Input.getOperand(2))) {
40698-
Input = peekThroughBitcasts(Input.getOperand(1));
40699-
continue;
40700-
}
40701-
break;
40702-
}
40703-
if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40704-
WideSizeInBits < Input.getValueSizeInBits())
40705-
WideSizeInBits = Input.getValueSizeInBits();
40706-
}
40707-
40708-
// Bail if we fail to find a source larger than the existing root.
40709-
if (WideSizeInBits <= RootSizeInBits ||
40710-
(WideSizeInBits % RootSizeInBits) != 0)
40711-
return SDValue();
40712-
40713-
// Create new mask for larger type.
40714-
SmallVector<int, 64> WideMask;
40715-
growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
40716-
40717-
// Attempt to peek through inputs and adjust mask when we extract from an
40718-
// upper subvector.
40719-
int AdjustedMasks = 0;
40720-
SmallVector<SDValue, 4> WideInputs(Inputs);
40721-
for (unsigned I = 0; I != NumInputs; ++I) {
40722-
SDValue &Input = WideInputs[I];
40723-
Input = peekThroughBitcasts(Input);
40724-
while (1) {
40725-
if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40726-
Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40727-
uint64_t Idx = Input.getConstantOperandVal(1);
40728-
if (Idx != 0) {
40729-
++AdjustedMasks;
40730-
unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40731-
Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40732-
40733-
int lo = I * WideMask.size();
40734-
int hi = (I + 1) * WideMask.size();
40735-
for (int &M : WideMask)
40736-
if (lo <= M && M < hi)
40737-
M += Idx;
40738-
}
40739-
Input = peekThroughBitcasts(Input.getOperand(0));
40740-
continue;
40741-
}
40742-
// TODO: Handle insertions into upper subvectors.
40743-
if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40744-
Input.getOperand(0).isUndef() &&
40745-
isNullConstant(Input.getOperand(2))) {
40746-
Input = peekThroughBitcasts(Input.getOperand(1));
40747-
continue;
40748-
}
40749-
break;
40750-
}
40751-
}
40752-
40753-
// Remove unused/repeated shuffle source ops.
40754-
resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40755-
assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40756-
40757-
// Bail if we're always extracting from the lowest subvectors,
40758-
// combineX86ShuffleChain should match this for the current width, or the
40759-
// shuffle still references too many inputs.
40760-
if (AdjustedMasks == 0 || WideInputs.size() > 2)
40761-
return SDValue();
40762-
40763-
// Minor canonicalization of the accumulated shuffle mask to make it easier
40764-
// to match below. All this does is detect masks with sequential pairs of
40765-
// elements, and shrink them to the half-width mask. It does this in a loop
40766-
// so it will reduce the size of the mask to the minimal width mask which
40767-
// performs an equivalent shuffle.
40768-
while (WideMask.size() > 1) {
40769-
SmallVector<int, 64> WidenedMask;
40770-
if (!canWidenShuffleElements(WideMask, WidenedMask))
40771-
break;
40772-
WideMask = std::move(WidenedMask);
40773-
}
40774-
40775-
// Canonicalization of binary shuffle masks to improve pattern matching by
40776-
// commuting the inputs.
40777-
if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40778-
ShuffleVectorSDNode::commuteMask(WideMask);
40779-
std::swap(WideInputs[0], WideInputs[1]);
40780-
}
40781-
40782-
// Increase depth for every upper subvector we've peeked through.
40783-
Depth += AdjustedMasks;
40784-
40785-
// Attempt to combine wider chain.
40786-
// TODO: Can we use a better Root?
40787-
SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40788-
WideInputs.back().getValueSizeInBits()
40789-
? WideInputs.front()
40790-
: WideInputs.back();
40791-
assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40792-
"WideRootSize mismatch");
40793-
40794-
if (SDValue WideShuffle = combineX86ShuffleChain(
40795-
WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40796-
Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40797-
IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40798-
WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40799-
return DAG.getBitcast(RootVT, WideShuffle);
40800-
}
40801-
40802-
return SDValue();
40803-
}
40804-
4080540639
// Canonicalize the combined shuffle mask chain with horizontal ops.
4080640640
// NOTE: This may update the Ops and Mask.
4080740641
static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -41214,6 +41048,54 @@ static SDValue combineX86ShufflesRecursively(
4121441048
OpMask.assign(NumElts, SM_SentinelUndef);
4121541049
std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
4121641050
OpZero = OpUndef = APInt::getZero(NumElts);
41051+
} else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41052+
TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
41053+
Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
41054+
(Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
41055+
// Extracting from vector larger than RootVT - scale the mask and attempt to
41056+
// fold the shuffle with the larger root type, then extract the lower
41057+
// elements.
41058+
unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
41059+
unsigned Scale = NewRootSizeInBits / RootSizeInBits;
41060+
MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
41061+
Scale * RootVT.getVectorNumElements());
41062+
SmallVector<int, 64> NewRootMask;
41063+
growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
41064+
// If we're using the lowest subvector, just replace it directly in the src
41065+
// ops/nodes.
41066+
SmallVector<SDValue, 16> NewSrcOps(SrcOps);
41067+
SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
41068+
if (isNullConstant(Op.getOperand(1))) {
41069+
NewSrcOps[SrcOpIndex] = Op.getOperand(0);
41070+
NewSrcNodes.push_back(Op.getNode());
41071+
}
41072+
// Don't increase the combine depth - we're effectively working on the same
41073+
// nodes, just with a wider type.
41074+
if (SDValue WideShuffle = combineX86ShufflesRecursively(
41075+
NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
41076+
Depth, MaxDepth, AllowVariableCrossLaneMask,
41077+
AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
41078+
return DAG.getBitcast(
41079+
RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
41080+
return SDValue();
41081+
} else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
41082+
Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41083+
Op.getOperand(1).getOperand(0).getValueSizeInBits() >
41084+
RootSizeInBits) {
41085+
// If we're inserting a subvector extracted from a vector larger than
41086+
// RootVT, then combine the insert_subvector as a shuffle, the
41087+
// extract_subvector will be folded in a later recursion.
41088+
SDValue BaseVec = Op.getOperand(0);
41089+
SDValue SubVec = Op.getOperand(1);
41090+
int InsertIdx = Op.getConstantOperandVal(2);
41091+
unsigned NumBaseElts = VT.getVectorNumElements();
41092+
unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
41093+
OpInputs.assign({BaseVec, SubVec});
41094+
OpMask.resize(NumBaseElts);
41095+
std::iota(OpMask.begin(), OpMask.end(), 0);
41096+
std::iota(OpMask.begin() + InsertIdx,
41097+
OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
41098+
OpZero = OpUndef = APInt::getZero(NumBaseElts);
4121741099
} else {
4121841100
return SDValue();
4121941101
}
@@ -41560,25 +41442,9 @@ static SDValue combineX86ShufflesRecursively(
4156041442
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
4156141443
IsMaskedShuffle, DAG, DL, Subtarget))
4156241444
return Shuffle;
41563-
41564-
// If all the operands come from the same larger vector, fallthrough and try
41565-
// to use combineX86ShuffleChainWithExtract.
41566-
SDValue LHS = peekThroughBitcasts(Ops.front());
41567-
SDValue RHS = peekThroughBitcasts(Ops.back());
41568-
if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
41569-
(RootSizeInBits / Mask.size()) != 64 ||
41570-
LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41571-
RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41572-
LHS.getOperand(0) != RHS.getOperand(0))
41573-
return SDValue();
4157441445
}
4157541446

41576-
// If that failed and any input is extracted then try to combine as a
41577-
// shuffle with the larger type.
41578-
return combineX86ShuffleChainWithExtract(
41579-
Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41580-
AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41581-
DAG, DL, Subtarget);
41447+
return SDValue();
4158241448
}
4158341449

4158441450
/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -44212,6 +44078,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
4421244078
case X86ISD::UNPCKL:
4421344079
case X86ISD::UNPCKH:
4421444080
case X86ISD::BLENDI:
44081+
case X86ISD::SHUFP:
4421544082
// Integer ops.
4421644083
case X86ISD::PACKSS:
4421744084
case X86ISD::PACKUS:

0 commit comments

Comments
 (0)