@@ -39653,13 +39653,6 @@ static bool matchBinaryPermuteShuffle(
39653
39653
return false;
39654
39654
}
39655
39655
39656
- static SDValue combineX86ShuffleChainWithExtract(
39657
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39658
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39659
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39660
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39661
- const X86Subtarget &Subtarget);
39662
-
39663
39656
/// Combine an arbitrary chain of shuffles into a single instruction if
39664
39657
/// possible.
39665
39658
///
@@ -40203,14 +40196,6 @@ static SDValue combineX86ShuffleChain(
40203
40196
return DAG.getBitcast(RootVT, Res);
40204
40197
}
40205
40198
40206
- // If that failed and either input is extracted then try to combine as a
40207
- // shuffle with the larger type.
40208
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40209
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40210
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40211
- IsMaskedShuffle, DAG, DL, Subtarget))
40212
- return WideShuffle;
40213
-
40214
40199
// If we have a dual input lane-crossing shuffle then lower to VPERMV3,
40215
40200
// (non-VLX will pad to 512-bit shuffles).
40216
40201
if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40376,14 +40361,6 @@ static SDValue combineX86ShuffleChain(
40376
40361
return DAG.getBitcast(RootVT, Res);
40377
40362
}
40378
40363
40379
- // If that failed and either input is extracted then try to combine as a
40380
- // shuffle with the larger type.
40381
- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40382
- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40383
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40384
- DAG, DL, Subtarget))
40385
- return WideShuffle;
40386
-
40387
40364
// If we have a dual input shuffle then lower to VPERMV3,
40388
40365
// (non-VLX will pad to 512-bit shuffles)
40389
40366
if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40409,154 +40386,6 @@ static SDValue combineX86ShuffleChain(
40409
40386
return SDValue();
40410
40387
}
40411
40388
40412
- // Combine an arbitrary chain of shuffles + extract_subvectors into a single
40413
- // instruction if possible.
40414
- //
40415
- // Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40416
- // type size to attempt to combine:
40417
- // shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40418
- // -->
40419
- // extract_subvector(shuffle(x,y,m2),0)
40420
- static SDValue combineX86ShuffleChainWithExtract(
40421
- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40422
- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40423
- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40424
- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40425
- const X86Subtarget &Subtarget) {
40426
- unsigned NumMaskElts = BaseMask.size();
40427
- unsigned NumInputs = Inputs.size();
40428
- if (NumInputs == 0)
40429
- return SDValue();
40430
-
40431
- unsigned RootSizeInBits = RootVT.getSizeInBits();
40432
- unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40433
- assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40434
-
40435
- // Peek through subvectors to find widest legal vector.
40436
- // TODO: Handle ISD::TRUNCATE
40437
- unsigned WideSizeInBits = RootSizeInBits;
40438
- for (SDValue Input : Inputs) {
40439
- Input = peekThroughBitcasts(Input);
40440
- while (1) {
40441
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40442
- Input = peekThroughBitcasts(Input.getOperand(0));
40443
- continue;
40444
- }
40445
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40446
- Input.getOperand(0).isUndef()) {
40447
- Input = peekThroughBitcasts(Input.getOperand(1));
40448
- continue;
40449
- }
40450
- break;
40451
- }
40452
- if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40453
- WideSizeInBits < Input.getValueSizeInBits())
40454
- WideSizeInBits = Input.getValueSizeInBits();
40455
- }
40456
-
40457
- // Bail if we fail to find a source larger than the existing root.
40458
- unsigned Scale = WideSizeInBits / RootSizeInBits;
40459
- if (WideSizeInBits <= RootSizeInBits ||
40460
- (WideSizeInBits % RootSizeInBits) != 0)
40461
- return SDValue();
40462
-
40463
- // Create new mask for larger type.
40464
- SmallVector<int, 64> WideMask(BaseMask);
40465
- for (int &M : WideMask) {
40466
- if (M < 0)
40467
- continue;
40468
- M = (M % NumMaskElts) + ((M / NumMaskElts) * Scale * NumMaskElts);
40469
- }
40470
- WideMask.append((Scale - 1) * NumMaskElts, SM_SentinelUndef);
40471
-
40472
- // Attempt to peek through inputs and adjust mask when we extract from an
40473
- // upper subvector.
40474
- int AdjustedMasks = 0;
40475
- SmallVector<SDValue, 4> WideInputs(Inputs);
40476
- for (unsigned I = 0; I != NumInputs; ++I) {
40477
- SDValue &Input = WideInputs[I];
40478
- Input = peekThroughBitcasts(Input);
40479
- while (1) {
40480
- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40481
- Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40482
- uint64_t Idx = Input.getConstantOperandVal(1);
40483
- if (Idx != 0) {
40484
- ++AdjustedMasks;
40485
- unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40486
- Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40487
-
40488
- int lo = I * WideMask.size();
40489
- int hi = (I + 1) * WideMask.size();
40490
- for (int &M : WideMask)
40491
- if (lo <= M && M < hi)
40492
- M += Idx;
40493
- }
40494
- Input = peekThroughBitcasts(Input.getOperand(0));
40495
- continue;
40496
- }
40497
- // TODO: Handle insertions into upper subvectors.
40498
- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40499
- Input.getOperand(0).isUndef() &&
40500
- isNullConstant(Input.getOperand(2))) {
40501
- Input = peekThroughBitcasts(Input.getOperand(1));
40502
- continue;
40503
- }
40504
- break;
40505
- }
40506
- }
40507
-
40508
- // Remove unused/repeated shuffle source ops.
40509
- resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40510
- assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40511
-
40512
- // Bail if we're always extracting from the lowest subvectors,
40513
- // combineX86ShuffleChain should match this for the current width, or the
40514
- // shuffle still references too many inputs.
40515
- if (AdjustedMasks == 0 || WideInputs.size() > 2)
40516
- return SDValue();
40517
-
40518
- // Minor canonicalization of the accumulated shuffle mask to make it easier
40519
- // to match below. All this does is detect masks with sequential pairs of
40520
- // elements, and shrink them to the half-width mask. It does this in a loop
40521
- // so it will reduce the size of the mask to the minimal width mask which
40522
- // performs an equivalent shuffle.
40523
- while (WideMask.size() > 1) {
40524
- SmallVector<int, 64> WidenedMask;
40525
- if (!canWidenShuffleElements(WideMask, WidenedMask))
40526
- break;
40527
- WideMask = std::move(WidenedMask);
40528
- }
40529
-
40530
- // Canonicalization of binary shuffle masks to improve pattern matching by
40531
- // commuting the inputs.
40532
- if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40533
- ShuffleVectorSDNode::commuteMask(WideMask);
40534
- std::swap(WideInputs[0], WideInputs[1]);
40535
- }
40536
-
40537
- // Increase depth for every upper subvector we've peeked through.
40538
- Depth += AdjustedMasks;
40539
-
40540
- // Attempt to combine wider chain.
40541
- // TODO: Can we use a better Root?
40542
- SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40543
- WideInputs.back().getValueSizeInBits()
40544
- ? WideInputs.front()
40545
- : WideInputs.back();
40546
- assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40547
- "WideRootSize mismatch");
40548
-
40549
- if (SDValue WideShuffle = combineX86ShuffleChain(
40550
- WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40551
- Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40552
- IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40553
- WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40554
- return DAG.getBitcast(RootVT, WideShuffle);
40555
- }
40556
-
40557
- return SDValue();
40558
- }
40559
-
40560
40389
// Canonicalize the combined shuffle mask chain with horizontal ops.
40561
40390
// NOTE: This may update the Ops and Mask.
40562
40391
static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -40969,6 +40798,57 @@ static SDValue combineX86ShufflesRecursively(
40969
40798
OpMask.assign(NumElts, SM_SentinelUndef);
40970
40799
std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
40971
40800
OpZero = OpUndef = APInt::getZero(NumElts);
40801
+ } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40802
+ TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
40803
+ Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
40804
+ (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
40805
+ // Extracting from vector larger than RootVT - scale the mask and attempt to
40806
+ // fold the shuffle with the larger root type, then extract the lower
40807
+ // elements.
40808
+ unsigned Scale = Op.getOperand(0).getValueSizeInBits() / RootSizeInBits;
40809
+ MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
40810
+ Scale * RootVT.getVectorNumElements());
40811
+ SmallVector<int, 64> NewRootMask(RootMask);
40812
+ NewRootMask.append((Scale - 1) * RootMask.size(), SM_SentinelUndef);
40813
+ for (int &M : NewRootMask)
40814
+ if (0 <= M)
40815
+ M = (M % RootMask.size()) +
40816
+ ((M / RootMask.size()) * NewRootMask.size());
40817
+ // If we're using the lowest subvector, just replace it directly in the src
40818
+ // ops/nodes.
40819
+ SmallVector<SDValue, 16> NewSrcOps(SrcOps);
40820
+ SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
40821
+ if (isNullConstant(Op.getOperand(1))) {
40822
+ NewSrcOps[SrcOpIndex] = Op.getOperand(0);
40823
+ NewSrcNodes.push_back(Op.getNode());
40824
+ }
40825
+ // Don't increase the combine depth - we're effectively working on the same
40826
+ // nodes, just with a wider type.
40827
+ if (SDValue WideShuffle = combineX86ShufflesRecursively(
40828
+ NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
40829
+ Depth, MaxDepth, AllowVariableCrossLaneMask,
40830
+ AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
40831
+ return DAG.getBitcast(
40832
+ RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
40833
+ return SDValue();
40834
+ } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
40835
+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40836
+ Op.getOperand(1).getOperand(0).getValueSizeInBits() >
40837
+ RootSizeInBits) {
40838
+ // If we're inserting a subvector extracted from a vector larger than
40839
+ // RootVT, then combine the insert_subvector as a shuffle, the
40840
+ // extract_subvector will be folded in a later recursion.
40841
+ SDValue BaseVec = Op.getOperand(0);
40842
+ SDValue SubVec = Op.getOperand(1);
40843
+ int InsertIdx = Op.getConstantOperandVal(2);
40844
+ unsigned NumBaseElts = VT.getVectorNumElements();
40845
+ unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
40846
+ OpInputs.assign({BaseVec, SubVec});
40847
+ OpMask.assign(NumBaseElts, SM_SentinelUndef);
40848
+ std::iota(OpMask.begin(), OpMask.end(), 0);
40849
+ std::iota(OpMask.begin() + InsertIdx,
40850
+ OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
40851
+ OpZero = OpUndef = APInt::getZero(NumBaseElts);
40972
40852
} else {
40973
40853
return SDValue();
40974
40854
}
@@ -41324,12 +41204,7 @@ static SDValue combineX86ShufflesRecursively(
41324
41204
return SDValue();
41325
41205
}
41326
41206
41327
- // If that failed and any input is extracted then try to combine as a
41328
- // shuffle with the larger type.
41329
- return combineX86ShuffleChainWithExtract(
41330
- Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41331
- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41332
- DAG, DL, Subtarget);
41207
+ return SDValue();
41333
41208
}
41334
41209
41335
41210
/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -43866,6 +43741,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
43866
43741
case X86ISD::UNPCKL:
43867
43742
case X86ISD::UNPCKH:
43868
43743
case X86ISD::BLENDI:
43744
+ case X86ISD::SHUFP:
43869
43745
// Integer ops.
43870
43746
case X86ISD::PACKSS:
43871
43747
case X86ISD::PACKUS:
0 commit comments