@@ -39872,13 +39872,6 @@ static bool matchBinaryPermuteShuffle(
39872
39872
return false;
39873
39873
}
39874
39874
39875
- static SDValue combineX86ShuffleChainWithExtract(
39876
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39877
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39878
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39879
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39880
- const X86Subtarget &Subtarget);
39881
-
39882
39875
/// Combine an arbitrary chain of shuffles into a single instruction if
39883
39876
/// possible.
39884
39877
///
@@ -40423,14 +40416,6 @@ static SDValue combineX86ShuffleChain(
40423
40416
return DAG.getBitcast(RootVT, Res);
40424
40417
}
40425
40418
40426
- // If that failed and either input is extracted then try to combine as a
40427
- // shuffle with the larger type.
40428
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40429
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40430
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40431
- IsMaskedShuffle, DAG, DL, Subtarget))
40432
- return WideShuffle;
40433
-
40434
40419
// If we have a dual input lane-crossing shuffle then lower to VPERMV3,
40435
40420
// (non-VLX will pad to 512-bit shuffles).
40436
40421
if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40596,14 +40581,6 @@ static SDValue combineX86ShuffleChain(
40596
40581
return DAG.getBitcast(RootVT, Res);
40597
40582
}
40598
40583
40599
- // If that failed and either input is extracted then try to combine as a
40600
- // shuffle with the larger type.
40601
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40602
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40603
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40604
- DAG, DL, Subtarget))
40605
- return WideShuffle;
40606
-
40607
40584
// If we have a dual input shuffle then lower to VPERMV3,
40608
40585
// (non-VLX will pad to 512-bit shuffles)
40609
40586
if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40629,149 +40606,6 @@ static SDValue combineX86ShuffleChain(
40629
40606
return SDValue();
40630
40607
}
40631
40608
40632
- // Combine an arbitrary chain of shuffles + extract_subvectors into a single
40633
- // instruction if possible.
40634
- //
40635
- // Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40636
- // type size to attempt to combine:
40637
- // shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40638
- // -->
40639
- // extract_subvector(shuffle(x,y,m2),0)
40640
- static SDValue combineX86ShuffleChainWithExtract(
40641
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40642
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40643
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40644
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40645
- const X86Subtarget &Subtarget) {
40646
- unsigned NumMaskElts = BaseMask.size();
40647
- unsigned NumInputs = Inputs.size();
40648
- if (NumInputs == 0)
40649
- return SDValue();
40650
-
40651
- unsigned RootSizeInBits = RootVT.getSizeInBits();
40652
- unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40653
- assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40654
-
40655
- // Peek through subvectors to find widest legal vector.
40656
- // TODO: Handle ISD::TRUNCATE
40657
- unsigned WideSizeInBits = RootSizeInBits;
40658
- for (SDValue Input : Inputs) {
40659
- Input = peekThroughBitcasts(Input);
40660
- while (1) {
40661
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40662
- Input = peekThroughBitcasts(Input.getOperand(0));
40663
- continue;
40664
- }
40665
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40666
- Input.getOperand(0).isUndef() &&
40667
- isNullConstant(Input.getOperand(2))) {
40668
- Input = peekThroughBitcasts(Input.getOperand(1));
40669
- continue;
40670
- }
40671
- break;
40672
- }
40673
- if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40674
- WideSizeInBits < Input.getValueSizeInBits())
40675
- WideSizeInBits = Input.getValueSizeInBits();
40676
- }
40677
-
40678
- // Bail if we fail to find a source larger than the existing root.
40679
- if (WideSizeInBits <= RootSizeInBits ||
40680
- (WideSizeInBits % RootSizeInBits) != 0)
40681
- return SDValue();
40682
-
40683
- // Create new mask for larger type.
40684
- SmallVector<int, 64> WideMask;
40685
- growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
40686
-
40687
- // Attempt to peek through inputs and adjust mask when we extract from an
40688
- // upper subvector.
40689
- int AdjustedMasks = 0;
40690
- SmallVector<SDValue, 4> WideInputs(Inputs);
40691
- for (unsigned I = 0; I != NumInputs; ++I) {
40692
- SDValue &Input = WideInputs[I];
40693
- Input = peekThroughBitcasts(Input);
40694
- while (1) {
40695
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40696
- Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40697
- uint64_t Idx = Input.getConstantOperandVal(1);
40698
- if (Idx != 0) {
40699
- ++AdjustedMasks;
40700
- unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40701
- Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40702
-
40703
- int lo = I * WideMask.size();
40704
- int hi = (I + 1) * WideMask.size();
40705
- for (int &M : WideMask)
40706
- if (lo <= M && M < hi)
40707
- M += Idx;
40708
- }
40709
- Input = peekThroughBitcasts(Input.getOperand(0));
40710
- continue;
40711
- }
40712
- // TODO: Handle insertions into upper subvectors.
40713
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40714
- Input.getOperand(0).isUndef() &&
40715
- isNullConstant(Input.getOperand(2))) {
40716
- Input = peekThroughBitcasts(Input.getOperand(1));
40717
- continue;
40718
- }
40719
- break;
40720
- }
40721
- }
40722
-
40723
- // Remove unused/repeated shuffle source ops.
40724
- resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40725
- assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40726
-
40727
- // Bail if we're always extracting from the lowest subvectors,
40728
- // combineX86ShuffleChain should match this for the current width, or the
40729
- // shuffle still references too many inputs.
40730
- if (AdjustedMasks == 0 || WideInputs.size() > 2)
40731
- return SDValue();
40732
-
40733
- // Minor canonicalization of the accumulated shuffle mask to make it easier
40734
- // to match below. All this does is detect masks with sequential pairs of
40735
- // elements, and shrink them to the half-width mask. It does this in a loop
40736
- // so it will reduce the size of the mask to the minimal width mask which
40737
- // performs an equivalent shuffle.
40738
- while (WideMask.size() > 1) {
40739
- SmallVector<int, 64> WidenedMask;
40740
- if (!canWidenShuffleElements(WideMask, WidenedMask))
40741
- break;
40742
- WideMask = std::move(WidenedMask);
40743
- }
40744
-
40745
- // Canonicalization of binary shuffle masks to improve pattern matching by
40746
- // commuting the inputs.
40747
- if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40748
- ShuffleVectorSDNode::commuteMask(WideMask);
40749
- std::swap(WideInputs[0], WideInputs[1]);
40750
- }
40751
-
40752
- // Increase depth for every upper subvector we've peeked through.
40753
- Depth += AdjustedMasks;
40754
-
40755
- // Attempt to combine wider chain.
40756
- // TODO: Can we use a better Root?
40757
- SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40758
- WideInputs.back().getValueSizeInBits()
40759
- ? WideInputs.front()
40760
- : WideInputs.back();
40761
- assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40762
- "WideRootSize mismatch");
40763
-
40764
- if (SDValue WideShuffle = combineX86ShuffleChain(
40765
- WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40766
- Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40767
- IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40768
- WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40769
- return DAG.getBitcast(RootVT, WideShuffle);
40770
- }
40771
-
40772
- return SDValue();
40773
- }
40774
-
40775
40609
// Canonicalize the combined shuffle mask chain with horizontal ops.
40776
40610
// NOTE: This may update the Ops and Mask.
40777
40611
static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -41184,6 +41018,54 @@ static SDValue combineX86ShufflesRecursively(
41184
41018
OpMask.assign(NumElts, SM_SentinelUndef);
41185
41019
std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
41186
41020
OpZero = OpUndef = APInt::getZero(NumElts);
41021
+ } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41022
+ TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
41023
+ Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
41024
+ (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
41025
+ // Extracting from vector larger than RootVT - scale the mask and attempt to
41026
+ // fold the shuffle with the larger root type, then extract the lower
41027
+ // elements.
41028
+ unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
41029
+ unsigned Scale = NewRootSizeInBits / RootSizeInBits;
41030
+ MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
41031
+ Scale * RootVT.getVectorNumElements());
41032
+ SmallVector<int, 64> NewRootMask;
41033
+ growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
41034
+ // If we're using the lowest subvector, just replace it directly in the src
41035
+ // ops/nodes.
41036
+ SmallVector<SDValue, 16> NewSrcOps(SrcOps);
41037
+ SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
41038
+ if (isNullConstant(Op.getOperand(1))) {
41039
+ NewSrcOps[SrcOpIndex] = Op.getOperand(0);
41040
+ NewSrcNodes.push_back(Op.getNode());
41041
+ }
41042
+ // Don't increase the combine depth - we're effectively working on the same
41043
+ // nodes, just with a wider type.
41044
+ if (SDValue WideShuffle = combineX86ShufflesRecursively(
41045
+ NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
41046
+ Depth, MaxDepth, AllowVariableCrossLaneMask,
41047
+ AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
41048
+ return DAG.getBitcast(
41049
+ RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
41050
+ return SDValue();
41051
+ } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
41052
+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41053
+ Op.getOperand(1).getOperand(0).getValueSizeInBits() >
41054
+ RootSizeInBits) {
41055
+ // If we're inserting an subvector extracted from a vector larger than
41056
+ // RootVT, then combine the insert_subvector as a shuffle, the
41057
+ // extract_subvector will be folded in a later recursion.
41058
+ SDValue BaseVec = Op.getOperand(0);
41059
+ SDValue SubVec = Op.getOperand(1);
41060
+ int InsertIdx = Op.getConstantOperandVal(2);
41061
+ unsigned NumBaseElts = VT.getVectorNumElements();
41062
+ unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
41063
+ OpInputs.assign({BaseVec, SubVec});
41064
+ OpMask.resize(NumBaseElts);
41065
+ std::iota(OpMask.begin(), OpMask.end(), 0);
41066
+ std::iota(OpMask.begin() + InsertIdx,
41067
+ OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
41068
+ OpZero = OpUndef = APInt::getZero(NumBaseElts);
41187
41069
} else {
41188
41070
return SDValue();
41189
41071
}
@@ -41530,25 +41412,9 @@ static SDValue combineX86ShufflesRecursively(
41530
41412
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
41531
41413
IsMaskedShuffle, DAG, DL, Subtarget))
41532
41414
return Shuffle;
41533
-
41534
- // If all the operands come from the same larger vector, fallthrough and try
41535
- // to use combineX86ShuffleChainWithExtract.
41536
- SDValue LHS = peekThroughBitcasts(Ops.front());
41537
- SDValue RHS = peekThroughBitcasts(Ops.back());
41538
- if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
41539
- (RootSizeInBits / Mask.size()) != 64 ||
41540
- LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41541
- RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41542
- LHS.getOperand(0) != RHS.getOperand(0))
41543
- return SDValue();
41544
41415
}
41545
41416
41546
- // If that failed and any input is extracted then try to combine as a
41547
- // shuffle with the larger type.
41548
- return combineX86ShuffleChainWithExtract(
41549
- Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41550
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41551
- DAG, DL, Subtarget);
41417
+ return SDValue();
41552
41418
}
41553
41419
41554
41420
/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -44230,6 +44096,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
44230
44096
case X86ISD::UNPCKL:
44231
44097
case X86ISD::UNPCKH:
44232
44098
case X86ISD::BLENDI:
44099
+ case X86ISD::SHUFP:
44233
44100
// Integer ops.
44234
44101
case X86ISD::PACKSS:
44235
44102
case X86ISD::PACKUS:
0 commit comments