@@ -39737,13 +39737,6 @@ static bool matchBinaryPermuteShuffle(
   return false;
 }
 
-static SDValue combineX86ShuffleChainWithExtract(
-    ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
-    ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
-    bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
-    bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
-    const X86Subtarget &Subtarget);
-
 /// Combine an arbitrary chain of shuffles into a single instruction if
 /// possible.
 ///
@@ -40288,14 +40281,6 @@ static SDValue combineX86ShuffleChain(
     return DAG.getBitcast(RootVT, Res);
   }
 
-  // If that failed and either input is extracted then try to combine as a
-  // shuffle with the larger type.
-  if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
-          Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
-          AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
-          IsMaskedShuffle, DAG, DL, Subtarget))
-    return WideShuffle;
-
   // If we have a dual input lane-crossing shuffle then lower to VPERMV3,
   // (non-VLX will pad to 512-bit shuffles).
   if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40461,14 +40446,6 @@ static SDValue combineX86ShuffleChain(
     return DAG.getBitcast(RootVT, Res);
   }
 
-  // If that failed and either input is extracted then try to combine as a
-  // shuffle with the larger type.
-  if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
-          Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
-          AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
-          DAG, DL, Subtarget))
-    return WideShuffle;
-
   // If we have a dual input shuffle then lower to VPERMV3,
   // (non-VLX will pad to 512-bit shuffles)
   if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40494,148 +40471,6 @@ static SDValue combineX86ShuffleChain(
   return SDValue();
 }
 
-// Combine an arbitrary chain of shuffles + extract_subvectors into a single
-// instruction if possible.
-//
-// Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
-// type size to attempt to combine:
-// shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
-// -->
-// extract_subvector(shuffle(x,y,m2),0)
-static SDValue combineX86ShuffleChainWithExtract(
-    ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
-    ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
-    bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
-    bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
-    const X86Subtarget &Subtarget) {
-  unsigned NumMaskElts = BaseMask.size();
-  unsigned NumInputs = Inputs.size();
-  if (NumInputs == 0)
-    return SDValue();
-
-  unsigned RootSizeInBits = RootVT.getSizeInBits();
-  unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
-  assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
-
-  // Peek through subvectors to find widest legal vector.
-  // TODO: Handle ISD::TRUNCATE
-  unsigned WideSizeInBits = RootSizeInBits;
-  for (SDValue Input : Inputs) {
-    Input = peekThroughBitcasts(Input);
-    while (1) {
-      if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
-        Input = peekThroughBitcasts(Input.getOperand(0));
-        continue;
-      }
-      if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
-          Input.getOperand(0).isUndef()) {
-        Input = peekThroughBitcasts(Input.getOperand(1));
-        continue;
-      }
-      break;
-    }
-    if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
-        WideSizeInBits < Input.getValueSizeInBits())
-      WideSizeInBits = Input.getValueSizeInBits();
-  }
-
-  // Bail if we fail to find a source larger than the existing root.
-  if (WideSizeInBits <= RootSizeInBits ||
-      (WideSizeInBits % RootSizeInBits) != 0)
-    return SDValue();
-
-  // Create new mask for larger type.
-  SmallVector<int, 64> WideMask;
-  growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
-
-  // Attempt to peek through inputs and adjust mask when we extract from an
-  // upper subvector.
-  int AdjustedMasks = 0;
-  SmallVector<SDValue, 4> WideInputs(Inputs);
-  for (unsigned I = 0; I != NumInputs; ++I) {
-    SDValue &Input = WideInputs[I];
-    Input = peekThroughBitcasts(Input);
-    while (1) {
-      if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-          Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
-        uint64_t Idx = Input.getConstantOperandVal(1);
-        if (Idx != 0) {
-          ++AdjustedMasks;
-          unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
-          Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
-
-          int lo = I * WideMask.size();
-          int hi = (I + 1) * WideMask.size();
-          for (int &M : WideMask)
-            if (lo <= M && M < hi)
-              M += Idx;
-        }
-        Input = peekThroughBitcasts(Input.getOperand(0));
-        continue;
-      }
-      // TODO: Handle insertions into upper subvectors.
-      if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
-          Input.getOperand(0).isUndef() &&
-          isNullConstant(Input.getOperand(2))) {
-        Input = peekThroughBitcasts(Input.getOperand(1));
-        continue;
-      }
-      break;
-    }
-  }
-
-  // Remove unused/repeated shuffle source ops.
-  resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
-  assert(!WideInputs.empty() && "Shuffle with no inputs detected");
-
-  // Bail if we're always extracting from the lowest subvectors,
-  // combineX86ShuffleChain should match this for the current width, or the
-  // shuffle still references too many inputs.
-  if (AdjustedMasks == 0 || WideInputs.size() > 2)
-    return SDValue();
-
-  // Minor canonicalization of the accumulated shuffle mask to make it easier
-  // to match below. All this does is detect masks with sequential pairs of
-  // elements, and shrink them to the half-width mask. It does this in a loop
-  // so it will reduce the size of the mask to the minimal width mask which
-  // performs an equivalent shuffle.
-  while (WideMask.size() > 1) {
-    SmallVector<int, 64> WidenedMask;
-    if (!canWidenShuffleElements(WideMask, WidenedMask))
-      break;
-    WideMask = std::move(WidenedMask);
-  }
-
-  // Canonicalization of binary shuffle masks to improve pattern matching by
-  // commuting the inputs.
-  if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
-    ShuffleVectorSDNode::commuteMask(WideMask);
-    std::swap(WideInputs[0], WideInputs[1]);
-  }
-
-  // Increase depth for every upper subvector we've peeked through.
-  Depth += AdjustedMasks;
-
-  // Attempt to combine wider chain.
-  // TODO: Can we use a better Root?
-  SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
-                             WideInputs.back().getValueSizeInBits()
-                         ? WideInputs.front()
-                         : WideInputs.back();
-  assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
-         "WideRootSize mismatch");
-
-  if (SDValue WideShuffle = combineX86ShuffleChain(
-          WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
-          Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
-          IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
-    WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
-    return DAG.getBitcast(RootVT, WideShuffle);
-  }
-
-  return SDValue();
-}
-
 // Canonicalize the combined shuffle mask chain with horizontal ops.
 // NOTE: This may update the Ops and Mask.
 static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -41048,6 +40883,54 @@ static SDValue combineX86ShufflesRecursively(
     OpMask.assign(NumElts, SM_SentinelUndef);
     std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
     OpZero = OpUndef = APInt::getZero(NumElts);
+  } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+             TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
+             Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
+             (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
+    // Extracting from vector larger than RootVT - scale the mask and attempt to
+    // fold the shuffle with the larger root type, then extract the lower
+    // elements.
+    unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
+    unsigned Scale = NewRootSizeInBits / RootSizeInBits;
+    MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
+                                     Scale * RootVT.getVectorNumElements());
+    SmallVector<int, 64> NewRootMask;
+    growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
+    // If we're using the lowest subvector, just replace it directly in the src
+    // ops/nodes.
+    SmallVector<SDValue, 16> NewSrcOps(SrcOps);
+    SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
+    if (isNullConstant(Op.getOperand(1))) {
+      NewSrcOps[SrcOpIndex] = Op.getOperand(0);
+      NewSrcNodes.push_back(Op.getNode());
+    }
+    // Don't increase the combine depth - we're effectively working on the same
+    // nodes, just with a wider type.
+    if (SDValue WideShuffle = combineX86ShufflesRecursively(
+            NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
+            Depth, MaxDepth, AllowVariableCrossLaneMask,
+            AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
+      return DAG.getBitcast(
+          RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
+    return SDValue();
+  } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
+             Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+             Op.getOperand(1).getOperand(0).getValueSizeInBits() >
+                 RootSizeInBits) {
+    // If we're inserting a subvector extracted from a vector larger than
+    // RootVT, then combine the insert_subvector as a shuffle, the
+    // extract_subvector will be folded in a later recursion.
+    SDValue BaseVec = Op.getOperand(0);
+    SDValue SubVec = Op.getOperand(1);
+    int InsertIdx = Op.getConstantOperandVal(2);
+    unsigned NumBaseElts = VT.getVectorNumElements();
+    unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
+    OpInputs.assign({BaseVec, SubVec});
+    OpMask.resize(NumBaseElts);
+    std::iota(OpMask.begin(), OpMask.end(), 0);
+    std::iota(OpMask.begin() + InsertIdx,
+              OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
+    OpZero = OpUndef = APInt::getZero(NumBaseElts);
   } else {
     return SDValue();
   }
@@ -41394,25 +41277,9 @@ static SDValue combineX86ShufflesRecursively(
             AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
             IsMaskedShuffle, DAG, DL, Subtarget))
       return Shuffle;
-
-    // If all the operands come from the same larger vector, fallthrough and try
-    // to use combineX86ShuffleChainWithExtract.
-    SDValue LHS = peekThroughBitcasts(Ops.front());
-    SDValue RHS = peekThroughBitcasts(Ops.back());
-    if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
-        (RootSizeInBits / Mask.size()) != 64 ||
-        LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-        RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
-        LHS.getOperand(0) != RHS.getOperand(0))
-      return SDValue();
   }
 
-  // If that failed and any input is extracted then try to combine as a
-  // shuffle with the larger type.
-  return combineX86ShuffleChainWithExtract(
-      Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
-      AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
-      DAG, DL, Subtarget);
+  return SDValue();
 }
 
 /// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -44025,6 +43892,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
   case X86ISD::UNPCKL:
   case X86ISD::UNPCKH:
   case X86ISD::BLENDI:
+  case X86ISD::SHUFP:
   // Integer ops.
   case X86ISD::PACKSS:
   case X86ISD::PACKUS: