@@ -39659,13 +39659,6 @@ static bool matchBinaryPermuteShuffle(
39659
39659
return false;
39660
39660
}
39661
39661
39662
- static SDValue combineX86ShuffleChainWithExtract(
39663
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39664
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39665
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39666
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39667
- const X86Subtarget &Subtarget);
39668
-
39669
39662
/// Combine an arbitrary chain of shuffles into a single instruction if
39670
39663
/// possible.
39671
39664
///
@@ -40210,14 +40203,6 @@ static SDValue combineX86ShuffleChain(
40210
40203
return DAG.getBitcast(RootVT, Res);
40211
40204
}
40212
40205
40213
- // If that failed and either input is extracted then try to combine as a
40214
- // shuffle with the larger type.
40215
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40216
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40217
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40218
- IsMaskedShuffle, DAG, DL, Subtarget))
40219
- return WideShuffle;
40220
-
40221
40206
// If we have a dual input lane-crossing shuffle then lower to VPERMV3,
40222
40207
// (non-VLX will pad to 512-bit shuffles).
40223
40208
if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40383,14 +40368,6 @@ static SDValue combineX86ShuffleChain(
40383
40368
return DAG.getBitcast(RootVT, Res);
40384
40369
}
40385
40370
40386
- // If that failed and either input is extracted then try to combine as a
40387
- // shuffle with the larger type.
40388
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40389
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40390
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40391
- DAG, DL, Subtarget))
40392
- return WideShuffle;
40393
-
40394
40371
// If we have a dual input shuffle then lower to VPERMV3,
40395
40372
// (non-VLX will pad to 512-bit shuffles)
40396
40373
if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40416,148 +40393,6 @@ static SDValue combineX86ShuffleChain(
40416
40393
return SDValue();
40417
40394
}
40418
40395
40419
- // Combine an arbitrary chain of shuffles + extract_subvectors into a single
40420
- // instruction if possible.
40421
- //
40422
- // Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40423
- // type size to attempt to combine:
40424
- // shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40425
- // -->
40426
- // extract_subvector(shuffle(x,y,m2),0)
40427
- static SDValue combineX86ShuffleChainWithExtract(
40428
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40429
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40430
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40431
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40432
- const X86Subtarget &Subtarget) {
40433
- unsigned NumMaskElts = BaseMask.size();
40434
- unsigned NumInputs = Inputs.size();
40435
- if (NumInputs == 0)
40436
- return SDValue();
40437
-
40438
- unsigned RootSizeInBits = RootVT.getSizeInBits();
40439
- unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40440
- assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40441
-
40442
- // Peek through subvectors to find widest legal vector.
40443
- // TODO: Handle ISD::TRUNCATE
40444
- unsigned WideSizeInBits = RootSizeInBits;
40445
- for (SDValue Input : Inputs) {
40446
- Input = peekThroughBitcasts(Input);
40447
- while (1) {
40448
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40449
- Input = peekThroughBitcasts(Input.getOperand(0));
40450
- continue;
40451
- }
40452
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40453
- Input.getOperand(0).isUndef()) {
40454
- Input = peekThroughBitcasts(Input.getOperand(1));
40455
- continue;
40456
- }
40457
- break;
40458
- }
40459
- if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40460
- WideSizeInBits < Input.getValueSizeInBits())
40461
- WideSizeInBits = Input.getValueSizeInBits();
40462
- }
40463
-
40464
- // Bail if we fail to find a source larger than the existing root.
40465
- if (WideSizeInBits <= RootSizeInBits ||
40466
- (WideSizeInBits % RootSizeInBits) != 0)
40467
- return SDValue();
40468
-
40469
- // Create new mask for larger type.
40470
- SmallVector<int, 64> WideMask;
40471
- growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
40472
-
40473
- // Attempt to peek through inputs and adjust mask when we extract from an
40474
- // upper subvector.
40475
- int AdjustedMasks = 0;
40476
- SmallVector<SDValue, 4> WideInputs(Inputs);
40477
- for (unsigned I = 0; I != NumInputs; ++I) {
40478
- SDValue &Input = WideInputs[I];
40479
- Input = peekThroughBitcasts(Input);
40480
- while (1) {
40481
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40482
- Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40483
- uint64_t Idx = Input.getConstantOperandVal(1);
40484
- if (Idx != 0) {
40485
- ++AdjustedMasks;
40486
- unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40487
- Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40488
-
40489
- int lo = I * WideMask.size();
40490
- int hi = (I + 1) * WideMask.size();
40491
- for (int &M : WideMask)
40492
- if (lo <= M && M < hi)
40493
- M += Idx;
40494
- }
40495
- Input = peekThroughBitcasts(Input.getOperand(0));
40496
- continue;
40497
- }
40498
- // TODO: Handle insertions into upper subvectors.
40499
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40500
- Input.getOperand(0).isUndef() &&
40501
- isNullConstant(Input.getOperand(2))) {
40502
- Input = peekThroughBitcasts(Input.getOperand(1));
40503
- continue;
40504
- }
40505
- break;
40506
- }
40507
- }
40508
-
40509
- // Remove unused/repeated shuffle source ops.
40510
- resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40511
- assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40512
-
40513
- // Bail if we're always extracting from the lowest subvectors,
40514
- // combineX86ShuffleChain should match this for the current width, or the
40515
- // shuffle still references too many inputs.
40516
- if (AdjustedMasks == 0 || WideInputs.size() > 2)
40517
- return SDValue();
40518
-
40519
- // Minor canonicalization of the accumulated shuffle mask to make it easier
40520
- // to match below. All this does is detect masks with sequential pairs of
40521
- // elements, and shrink them to the half-width mask. It does this in a loop
40522
- // so it will reduce the size of the mask to the minimal width mask which
40523
- // performs an equivalent shuffle.
40524
- while (WideMask.size() > 1) {
40525
- SmallVector<int, 64> WidenedMask;
40526
- if (!canWidenShuffleElements(WideMask, WidenedMask))
40527
- break;
40528
- WideMask = std::move(WidenedMask);
40529
- }
40530
-
40531
- // Canonicalization of binary shuffle masks to improve pattern matching by
40532
- // commuting the inputs.
40533
- if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40534
- ShuffleVectorSDNode::commuteMask(WideMask);
40535
- std::swap(WideInputs[0], WideInputs[1]);
40536
- }
40537
-
40538
- // Increase depth for every upper subvector we've peeked through.
40539
- Depth += AdjustedMasks;
40540
-
40541
- // Attempt to combine wider chain.
40542
- // TODO: Can we use a better Root?
40543
- SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40544
- WideInputs.back().getValueSizeInBits()
40545
- ? WideInputs.front()
40546
- : WideInputs.back();
40547
- assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40548
- "WideRootSize mismatch");
40549
-
40550
- if (SDValue WideShuffle = combineX86ShuffleChain(
40551
- WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40552
- Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40553
- IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40554
- WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40555
- return DAG.getBitcast(RootVT, WideShuffle);
40556
- }
40557
-
40558
- return SDValue();
40559
- }
40560
-
40561
40396
// Canonicalize the combined shuffle mask chain with horizontal ops.
40562
40397
// NOTE: This may update the Ops and Mask.
40563
40398
static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -40970,6 +40805,54 @@ static SDValue combineX86ShufflesRecursively(
40970
40805
OpMask.assign(NumElts, SM_SentinelUndef);
40971
40806
std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
40972
40807
OpZero = OpUndef = APInt::getZero(NumElts);
40808
+ } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40809
+ TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
40810
+ Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
40811
+ (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
40812
+ // Extracting from vector larger than RootVT - scale the mask and attempt to
40813
+ // fold the shuffle with the larger root type, then extract the lower
40814
+ // elements.
40815
+ unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
40816
+ unsigned Scale = NewRootSizeInBits / RootSizeInBits;
40817
+ MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
40818
+ Scale * RootVT.getVectorNumElements());
40819
+ SmallVector<int, 64> NewRootMask;
40820
+ growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
40821
+ // If we're using the lowest subvector, just replace it directly in the src
40822
+ // ops/nodes.
40823
+ SmallVector<SDValue, 16> NewSrcOps(SrcOps);
40824
+ SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
40825
+ if (isNullConstant(Op.getOperand(1))) {
40826
+ NewSrcOps[SrcOpIndex] = Op.getOperand(0);
40827
+ NewSrcNodes.push_back(Op.getNode());
40828
+ }
40829
+ // Don't increase the combine depth - we're effectively working on the same
40830
+ // nodes, just with a wider type.
40831
+ if (SDValue WideShuffle = combineX86ShufflesRecursively(
40832
+ NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
40833
+ Depth, MaxDepth, AllowVariableCrossLaneMask,
40834
+ AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
40835
+ return DAG.getBitcast(
40836
+ RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
40837
+ return SDValue();
40838
+ } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
40839
+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40840
+ Op.getOperand(1).getOperand(0).getValueSizeInBits() >
40841
+ RootSizeInBits) {
40842
+ // If we're inserting an subvector extracted from a vector larger than
40843
+ // RootVT, then combine the insert_subvector as a shuffle, the
40844
+ // extract_subvector will be folded in a later recursion.
40845
+ SDValue BaseVec = Op.getOperand(0);
40846
+ SDValue SubVec = Op.getOperand(1);
40847
+ int InsertIdx = Op.getConstantOperandVal(2);
40848
+ unsigned NumBaseElts = VT.getVectorNumElements();
40849
+ unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
40850
+ OpInputs.assign({BaseVec, SubVec});
40851
+ OpMask.resize(NumBaseElts);
40852
+ std::iota(OpMask.begin(), OpMask.end(), 0);
40853
+ std::iota(OpMask.begin() + InsertIdx,
40854
+ OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
40855
+ OpZero = OpUndef = APInt::getZero(NumBaseElts);
40973
40856
} else {
40974
40857
return SDValue();
40975
40858
}
@@ -41316,25 +41199,9 @@ static SDValue combineX86ShufflesRecursively(
41316
41199
AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
41317
41200
IsMaskedShuffle, DAG, DL, Subtarget))
41318
41201
return Shuffle;
41319
-
41320
- // If all the operands come from the same larger vector, fallthrough and try
41321
- // to use combineX86ShuffleChainWithExtract.
41322
- SDValue LHS = peekThroughBitcasts(Ops.front());
41323
- SDValue RHS = peekThroughBitcasts(Ops.back());
41324
- if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
41325
- (RootSizeInBits / Mask.size()) != 64 ||
41326
- LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41327
- RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41328
- LHS.getOperand(0) != RHS.getOperand(0))
41329
- return SDValue();
41330
41202
}
41331
41203
41332
- // If that failed and any input is extracted then try to combine as a
41333
- // shuffle with the larger type.
41334
- return combineX86ShuffleChainWithExtract(
41335
- Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41336
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41337
- DAG, DL, Subtarget);
41204
+ return SDValue();
41338
41205
}
41339
41206
41340
41207
/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -43947,6 +43814,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
43947
43814
case X86ISD::UNPCKL:
43948
43815
case X86ISD::UNPCKH:
43949
43816
case X86ISD::BLENDI:
43817
+ case X86ISD::SHUFP:
43950
43818
// Integer ops.
43951
43819
case X86ISD::PACKSS:
43952
43820
case X86ISD::PACKUS:
0 commit comments