@@ -39902,13 +39902,6 @@ static bool matchBinaryPermuteShuffle(
39902
39902
return false;
39903
39903
}
39904
39904
39905
- static SDValue combineX86ShuffleChainWithExtract(
39906
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39907
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39908
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39909
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39910
- const X86Subtarget &Subtarget);
39911
-
39912
39905
/// Combine an arbitrary chain of shuffles into a single instruction if
39913
39906
/// possible.
39914
39907
///
@@ -40453,14 +40446,6 @@ static SDValue combineX86ShuffleChain(
40453
40446
return DAG.getBitcast(RootVT, Res);
40454
40447
}
40455
40448
40456
- // If that failed and either input is extracted then try to combine as a
40457
- // shuffle with the larger type.
40458
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40459
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40460
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40461
- IsMaskedShuffle, DAG, DL, Subtarget))
40462
- return WideShuffle;
40463
-
40464
40449
// If we have a dual input lane-crossing shuffle then lower to VPERMV3,
40465
40450
// (non-VLX will pad to 512-bit shuffles).
40466
40451
if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40626,14 +40611,6 @@ static SDValue combineX86ShuffleChain(
40626
40611
return DAG.getBitcast(RootVT, Res);
40627
40612
}
40628
40613
40629
- // If that failed and either input is extracted then try to combine as a
40630
- // shuffle with the larger type.
40631
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40632
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40633
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40634
- DAG, DL, Subtarget))
40635
- return WideShuffle;
40636
-
40637
40614
// If we have a dual input shuffle then lower to VPERMV3,
40638
40615
// (non-VLX will pad to 512-bit shuffles)
40639
40616
if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40659,149 +40636,6 @@ static SDValue combineX86ShuffleChain(
40659
40636
return SDValue();
40660
40637
}
40661
40638
40662
- // Combine an arbitrary chain of shuffles + extract_subvectors into a single
40663
- // instruction if possible.
40664
- //
40665
- // Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40666
- // type size to attempt to combine:
40667
- // shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40668
- // -->
40669
- // extract_subvector(shuffle(x,y,m2),0)
40670
- static SDValue combineX86ShuffleChainWithExtract(
40671
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40672
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40673
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40674
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40675
- const X86Subtarget &Subtarget) {
40676
- unsigned NumMaskElts = BaseMask.size();
40677
- unsigned NumInputs = Inputs.size();
40678
- if (NumInputs == 0)
40679
- return SDValue();
40680
-
40681
- unsigned RootSizeInBits = RootVT.getSizeInBits();
40682
- unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40683
- assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40684
-
40685
- // Peek through subvectors to find widest legal vector.
40686
- // TODO: Handle ISD::TRUNCATE
40687
- unsigned WideSizeInBits = RootSizeInBits;
40688
- for (SDValue Input : Inputs) {
40689
- Input = peekThroughBitcasts(Input);
40690
- while (1) {
40691
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40692
- Input = peekThroughBitcasts(Input.getOperand(0));
40693
- continue;
40694
- }
40695
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40696
- Input.getOperand(0).isUndef() &&
40697
- isNullConstant(Input.getOperand(2))) {
40698
- Input = peekThroughBitcasts(Input.getOperand(1));
40699
- continue;
40700
- }
40701
- break;
40702
- }
40703
- if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40704
- WideSizeInBits < Input.getValueSizeInBits())
40705
- WideSizeInBits = Input.getValueSizeInBits();
40706
- }
40707
-
40708
- // Bail if we fail to find a source larger than the existing root.
40709
- if (WideSizeInBits <= RootSizeInBits ||
40710
- (WideSizeInBits % RootSizeInBits) != 0)
40711
- return SDValue();
40712
-
40713
- // Create new mask for larger type.
40714
- SmallVector<int, 64> WideMask;
40715
- growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
40716
-
40717
- // Attempt to peek through inputs and adjust mask when we extract from an
40718
- // upper subvector.
40719
- int AdjustedMasks = 0;
40720
- SmallVector<SDValue, 4> WideInputs(Inputs);
40721
- for (unsigned I = 0; I != NumInputs; ++I) {
40722
- SDValue &Input = WideInputs[I];
40723
- Input = peekThroughBitcasts(Input);
40724
- while (1) {
40725
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40726
- Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40727
- uint64_t Idx = Input.getConstantOperandVal(1);
40728
- if (Idx != 0) {
40729
- ++AdjustedMasks;
40730
- unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40731
- Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40732
-
40733
- int lo = I * WideMask.size();
40734
- int hi = (I + 1) * WideMask.size();
40735
- for (int &M : WideMask)
40736
- if (lo <= M && M < hi)
40737
- M += Idx;
40738
- }
40739
- Input = peekThroughBitcasts(Input.getOperand(0));
40740
- continue;
40741
- }
40742
- // TODO: Handle insertions into upper subvectors.
40743
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40744
- Input.getOperand(0).isUndef() &&
40745
- isNullConstant(Input.getOperand(2))) {
40746
- Input = peekThroughBitcasts(Input.getOperand(1));
40747
- continue;
40748
- }
40749
- break;
40750
- }
40751
- }
40752
-
40753
- // Remove unused/repeated shuffle source ops.
40754
- resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40755
- assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40756
-
40757
- // Bail if we're always extracting from the lowest subvectors,
40758
- // combineX86ShuffleChain should match this for the current width, or the
40759
- // shuffle still references too many inputs.
40760
- if (AdjustedMasks == 0 || WideInputs.size() > 2)
40761
- return SDValue();
40762
-
40763
- // Minor canonicalization of the accumulated shuffle mask to make it easier
40764
- // to match below. All this does is detect masks with sequential pairs of
40765
- // elements, and shrink them to the half-width mask. It does this in a loop
40766
- // so it will reduce the size of the mask to the minimal width mask which
40767
- // performs an equivalent shuffle.
40768
- while (WideMask.size() > 1) {
40769
- SmallVector<int, 64> WidenedMask;
40770
- if (!canWidenShuffleElements(WideMask, WidenedMask))
40771
- break;
40772
- WideMask = std::move(WidenedMask);
40773
- }
40774
-
40775
- // Canonicalization of binary shuffle masks to improve pattern matching by
40776
- // commuting the inputs.
40777
- if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40778
- ShuffleVectorSDNode::commuteMask(WideMask);
40779
- std::swap(WideInputs[0], WideInputs[1]);
40780
- }
40781
-
40782
- // Increase depth for every upper subvector we've peeked through.
40783
- Depth += AdjustedMasks;
40784
-
40785
- // Attempt to combine wider chain.
40786
- // TODO: Can we use a better Root?
40787
- SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40788
- WideInputs.back().getValueSizeInBits()
40789
- ? WideInputs.front()
40790
- : WideInputs.back();
40791
- assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40792
- "WideRootSize mismatch");
40793
-
40794
- if (SDValue WideShuffle = combineX86ShuffleChain(
40795
- WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40796
- Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40797
- IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40798
- WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40799
- return DAG.getBitcast(RootVT, WideShuffle);
40800
- }
40801
-
40802
- return SDValue();
40803
- }
40804
-
40805
40639
// Canonicalize the combined shuffle mask chain with horizontal ops.
40806
40640
// NOTE: This may update the Ops and Mask.
40807
40641
static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -41214,6 +41048,54 @@ static SDValue combineX86ShufflesRecursively(
41214
41048
OpMask.assign(NumElts, SM_SentinelUndef);
41215
41049
std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
41216
41050
OpZero = OpUndef = APInt::getZero(NumElts);
41051
+ } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41052
+ TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
41053
+ Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
41054
+ (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
41055
+ // Extracting from vector larger than RootVT - scale the mask and attempt to
41056
+ // fold the shuffle with the larger root type, then extract the lower
41057
+ // elements.
41058
+ unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
41059
+ unsigned Scale = NewRootSizeInBits / RootSizeInBits;
41060
+ MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
41061
+ Scale * RootVT.getVectorNumElements());
41062
+ SmallVector<int, 64> NewRootMask;
41063
+ growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
41064
+ // If we're using the lowest subvector, just replace it directly in the src
41065
+ // ops/nodes.
41066
+ SmallVector<SDValue, 16> NewSrcOps(SrcOps);
41067
+ SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
41068
+ if (isNullConstant(Op.getOperand(1))) {
41069
+ NewSrcOps[SrcOpIndex] = Op.getOperand(0);
41070
+ NewSrcNodes.push_back(Op.getNode());
41071
+ }
41072
+ // Don't increase the combine depth - we're effectively working on the same
41073
+ // nodes, just with a wider type.
41074
+ if (SDValue WideShuffle = combineX86ShufflesRecursively(
41075
+ NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
41076
+ Depth, MaxDepth, AllowVariableCrossLaneMask,
41077
+ AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
41078
+ return DAG.getBitcast(
41079
+ RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
41080
+ return SDValue();
41081
+ } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
41082
+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41083
+ Op.getOperand(1).getOperand(0).getValueSizeInBits() >
41084
+ RootSizeInBits) {
41085
+ // If we're inserting an subvector extracted from a vector larger than
41086
+ // RootVT, then combine the insert_subvector as a shuffle, the
41087
+ // extract_subvector will be folded in a later recursion.
41088
+ SDValue BaseVec = Op.getOperand(0);
41089
+ SDValue SubVec = Op.getOperand(1);
41090
+ int InsertIdx = Op.getConstantOperandVal(2);
41091
+ unsigned NumBaseElts = VT.getVectorNumElements();
41092
+ unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
41093
+ OpInputs.assign({BaseVec, SubVec});
41094
+ OpMask.resize(NumBaseElts);
41095
+ std::iota(OpMask.begin(), OpMask.end(), 0);
41096
+ std::iota(OpMask.begin() + InsertIdx,
41097
+ OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
41098
+ OpZero = OpUndef = APInt::getZero(NumBaseElts);
41217
41099
} else {
41218
41100
return SDValue();
41219
41101
}
@@ -41560,25 +41442,9 @@ static SDValue combineX86ShufflesRecursively(
41560
41442
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
41561
41443
IsMaskedShuffle, DAG, DL, Subtarget))
41562
41444
return Shuffle;
41563
-
41564
- // If all the operands come from the same larger vector, fallthrough and try
41565
- // to use combineX86ShuffleChainWithExtract.
41566
- SDValue LHS = peekThroughBitcasts(Ops.front());
41567
- SDValue RHS = peekThroughBitcasts(Ops.back());
41568
- if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
41569
- (RootSizeInBits / Mask.size()) != 64 ||
41570
- LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41571
- RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41572
- LHS.getOperand(0) != RHS.getOperand(0))
41573
- return SDValue();
41574
41445
}
41575
41446
41576
- // If that failed and any input is extracted then try to combine as a
41577
- // shuffle with the larger type.
41578
- return combineX86ShuffleChainWithExtract(
41579
- Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41580
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41581
- DAG, DL, Subtarget);
41447
+ return SDValue();
41582
41448
}
41583
41449
41584
41450
/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -44212,6 +44078,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
44212
44078
case X86ISD::UNPCKL:
44213
44079
case X86ISD::UNPCKH:
44214
44080
case X86ISD::BLENDI:
44081
+ case X86ISD::SHUFP:
44215
44082
// Integer ops.
44216
44083
case X86ISD::PACKSS:
44217
44084
case X86ISD::PACKUS:
0 commit comments