@@ -39924,13 +39924,6 @@ static bool matchBinaryPermuteShuffle(
   return false;
 }
 
-static SDValue combineX86ShuffleChainWithExtract(
-    ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
-    ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
-    bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
-    bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
-    const X86Subtarget &Subtarget);
-
 /// Combine an arbitrary chain of shuffles into a single instruction if
 /// possible.
 ///
@@ -40475,14 +40468,6 @@ static SDValue combineX86ShuffleChain(
       return DAG.getBitcast(RootVT, Res);
     }
 
-    // If that failed and either input is extracted then try to combine as a
-    // shuffle with the larger type.
-    if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
-            Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
-            AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
-            IsMaskedShuffle, DAG, DL, Subtarget))
-      return WideShuffle;
-
     // If we have a dual input lane-crossing shuffle then lower to VPERMV3,
     // (non-VLX will pad to 512-bit shuffles).
     if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40648,14 +40633,6 @@ static SDValue combineX86ShuffleChain(
     return DAG.getBitcast(RootVT, Res);
   }
 
-  // If that failed and either input is extracted then try to combine as a
-  // shuffle with the larger type.
-  if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
-          Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
-          AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
-          DAG, DL, Subtarget))
-    return WideShuffle;
-
   // If we have a dual input shuffle then lower to VPERMV3,
   // (non-VLX will pad to 512-bit shuffles)
   if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40681,149 +40658,6 @@ static SDValue combineX86ShuffleChain(
   return SDValue();
 }
 
-// Combine an arbitrary chain of shuffles + extract_subvectors into a single
-// instruction if possible.
-//
-// Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
-// type size to attempt to combine:
-// shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
-// -->
-// extract_subvector(shuffle(x,y,m2),0)
-static SDValue combineX86ShuffleChainWithExtract(
-    ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
-    ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
-    bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
-    bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
-    const X86Subtarget &Subtarget) {
-  unsigned NumMaskElts = BaseMask.size();
-  unsigned NumInputs = Inputs.size();
-  if (NumInputs == 0)
-    return SDValue();
-
-  unsigned RootSizeInBits = RootVT.getSizeInBits();
-  unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
-  assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
-
-  // Peek through subvectors to find widest legal vector.
-  // TODO: Handle ISD::TRUNCATE
-  unsigned WideSizeInBits = RootSizeInBits;
-  for (SDValue Input : Inputs) {
-    Input = peekThroughBitcasts(Input);
-    while (1) {
-      if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
-        Input = peekThroughBitcasts(Input.getOperand(0));
-        continue;
-      }
-      if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
-          Input.getOperand(0).isUndef() &&
-          isNullConstant(Input.getOperand(2))) {
-        Input = peekThroughBitcasts(Input.getOperand(1));
-        continue;
-      }
-      break;
-    }
-    if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
-        WideSizeInBits < Input.getValueSizeInBits())
-      WideSizeInBits = Input.getValueSizeInBits();
-  }
-
-  // Bail if we fail to find a source larger than the existing root.
-  if (WideSizeInBits <= RootSizeInBits ||
-      (WideSizeInBits % RootSizeInBits) != 0)
-    return SDValue();
-
-  // Create new mask for larger type.
-  SmallVector<int, 64> WideMask;
-  growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
-
-  // Attempt to peek through inputs and adjust mask when we extract from an
-  // upper subvector.
-  int AdjustedMasks = 0;
-  SmallVector<SDValue, 4> WideInputs(Inputs);
-  for (unsigned I = 0; I != NumInputs; ++I) {
-    SDValue &Input = WideInputs[I];
-    Input = peekThroughBitcasts(Input);
-    while (1) {
-      if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-          Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
-        uint64_t Idx = Input.getConstantOperandVal(1);
-        if (Idx != 0) {
-          ++AdjustedMasks;
-          unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
-          Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
-
-          int lo = I * WideMask.size();
-          int hi = (I + 1) * WideMask.size();
-          for (int &M : WideMask)
-            if (lo <= M && M < hi)
-              M += Idx;
-        }
-        Input = peekThroughBitcasts(Input.getOperand(0));
-        continue;
-      }
-      // TODO: Handle insertions into upper subvectors.
-      if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
-          Input.getOperand(0).isUndef() &&
-          isNullConstant(Input.getOperand(2))) {
-        Input = peekThroughBitcasts(Input.getOperand(1));
-        continue;
-      }
-      break;
-    }
-  }
-
-  // Remove unused/repeated shuffle source ops.
-  resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
-  assert(!WideInputs.empty() && "Shuffle with no inputs detected");
-
-  // Bail if we're always extracting from the lowest subvectors,
-  // combineX86ShuffleChain should match this for the current width, or the
-  // shuffle still references too many inputs.
-  if (AdjustedMasks == 0 || WideInputs.size() > 2)
-    return SDValue();
-
-  // Minor canonicalization of the accumulated shuffle mask to make it easier
-  // to match below. All this does is detect masks with sequential pairs of
-  // elements, and shrink them to the half-width mask. It does this in a loop
-  // so it will reduce the size of the mask to the minimal width mask which
-  // performs an equivalent shuffle.
-  while (WideMask.size() > 1) {
-    SmallVector<int, 64> WidenedMask;
-    if (!canWidenShuffleElements(WideMask, WidenedMask))
-      break;
-    WideMask = std::move(WidenedMask);
-  }
-
-  // Canonicalization of binary shuffle masks to improve pattern matching by
-  // commuting the inputs.
-  if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
-    ShuffleVectorSDNode::commuteMask(WideMask);
-    std::swap(WideInputs[0], WideInputs[1]);
-  }
-
-  // Increase depth for every upper subvector we've peeked through.
-  Depth += AdjustedMasks;
-
-  // Attempt to combine wider chain.
-  // TODO: Can we use a better Root?
-  SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
-                             WideInputs.back().getValueSizeInBits()
-                         ? WideInputs.front()
-                         : WideInputs.back();
-  assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
-         "WideRootSize mismatch");
-
-  if (SDValue WideShuffle = combineX86ShuffleChain(
-          WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
-          Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
-          IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
-    WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
-    return DAG.getBitcast(RootVT, WideShuffle);
-  }
-
-  return SDValue();
-}
-
 // Canonicalize the combined shuffle mask chain with horizontal ops.
 // NOTE: This may update the Ops and Mask.
 static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -41236,6 +41070,54 @@ static SDValue combineX86ShufflesRecursively(
     OpMask.assign(NumElts, SM_SentinelUndef);
     std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
     OpZero = OpUndef = APInt::getZero(NumElts);
+  } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+             TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
+             Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
+             (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
+    // Extracting from vector larger than RootVT - scale the mask and attempt to
+    // fold the shuffle with the larger root type, then extract the lower
+    // elements.
+    unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
+    unsigned Scale = NewRootSizeInBits / RootSizeInBits;
+    MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
+                                     Scale * RootVT.getVectorNumElements());
+    SmallVector<int, 64> NewRootMask;
+    growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
+    // If we're using the lowest subvector, just replace it directly in the src
+    // ops/nodes.
+    SmallVector<SDValue, 16> NewSrcOps(SrcOps);
+    SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
+    if (isNullConstant(Op.getOperand(1))) {
+      NewSrcOps[SrcOpIndex] = Op.getOperand(0);
+      NewSrcNodes.push_back(Op.getNode());
+    }
+    // Don't increase the combine depth - we're effectively working on the same
+    // nodes, just with a wider type.
+    if (SDValue WideShuffle = combineX86ShufflesRecursively(
+            NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
+            Depth, MaxDepth, AllowVariableCrossLaneMask,
+            AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
+      return DAG.getBitcast(
+          RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
+    return SDValue();
+  } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
+             Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+             Op.getOperand(1).getOperand(0).getValueSizeInBits() >
+                 RootSizeInBits) {
+    // If we're inserting a subvector extracted from a vector larger than
+    // RootVT, then combine the insert_subvector as a shuffle, the
+    // extract_subvector will be folded in a later recursion.
+    SDValue BaseVec = Op.getOperand(0);
+    SDValue SubVec = Op.getOperand(1);
+    int InsertIdx = Op.getConstantOperandVal(2);
+    unsigned NumBaseElts = VT.getVectorNumElements();
+    unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
+    OpInputs.assign({BaseVec, SubVec});
+    OpMask.resize(NumBaseElts);
+    std::iota(OpMask.begin(), OpMask.end(), 0);
+    std::iota(OpMask.begin() + InsertIdx,
+              OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
+    OpZero = OpUndef = APInt::getZero(NumBaseElts);
   } else {
     return SDValue();
   }
@@ -41582,25 +41464,9 @@ static SDValue combineX86ShufflesRecursively(
             AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
             IsMaskedShuffle, DAG, DL, Subtarget))
       return Shuffle;
-
-    // If all the operands come from the same larger vector, fallthrough and try
-    // to use combineX86ShuffleChainWithExtract.
-    SDValue LHS = peekThroughBitcasts(Ops.front());
-    SDValue RHS = peekThroughBitcasts(Ops.back());
-    if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
-        (RootSizeInBits / Mask.size()) != 64 ||
-        LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-        RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-        LHS.getOperand(0) != RHS.getOperand(0))
-      return SDValue();
   }
 
-  // If that failed and any input is extracted then try to combine as a
-  // shuffle with the larger type.
-  return combineX86ShuffleChainWithExtract(
-      Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
-      AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
-      DAG, DL, Subtarget);
+  return SDValue();
 }
 
 /// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -44283,6 +44149,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
   case X86ISD::UNPCKL:
   case X86ISD::UNPCKH:
   case X86ISD::BLENDI:
+  case X86ISD::SHUFP:
   // Integer ops.
   case X86ISD::PACKSS:
   case X86ISD::PACKUS: