@@ -41368,6 +41368,22 @@ static SmallVector<int, 4> getPSHUFShuffleMask(SDValue N) {
41368
41368
}
41369
41369
}
41370
41370
41371
+ /// Get the expanded blend mask from a BLENDI node.
41372
+ /// For v16i16 nodes, this will splat the repeated i8 mask.
41373
+ static APInt getBLENDIBlendMask(SDValue V) {
41374
+ assert(V.getOpcode() == X86ISD::BLENDI && "Unknown blend shuffle");
41375
+ unsigned NumElts = V.getSimpleValueType().getVectorNumElements();
41376
+ APInt Mask = V.getConstantOperandAPInt(2);
41377
+ if (Mask.getBitWidth() > NumElts)
41378
+ Mask = Mask.trunc(NumElts);
41379
+ if (NumElts == 16) {
41380
+ assert(Mask.getBitWidth() == 8 && "Unexpected v16i16 blend mask width");
41381
+ Mask = APInt::getSplat(16, Mask);
41382
+ }
41383
+ assert(Mask.getBitWidth() == NumElts && "Unexpected blend mask width");
41384
+ return Mask;
41385
+ }
41386
+
41371
41387
/// Search for a combinable shuffle across a chain ending in pshufd.
41372
41388
///
41373
41389
/// We walk up the chain and look for a combinable shuffle, skipping over
@@ -42266,7 +42282,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
42266
42282
unsigned SrcBits = SrcVT.getScalarSizeInBits();
42267
42283
if ((EltBits % SrcBits) == 0 && SrcBits >= 32) {
42268
42284
unsigned NewSize = SrcVT.getVectorNumElements();
42269
- APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(NumElts );
42285
+ APInt BlendMask = getBLENDIBlendMask(N );
42270
42286
APInt NewBlendMask = APIntOps::ScaleBitMask(BlendMask, NewSize);
42271
42287
return DAG.getBitcast(
42272
42288
VT, DAG.getNode(X86ISD::BLENDI, DL, SrcVT, N0.getOperand(0),
@@ -58488,16 +58504,11 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
58488
58504
break;
58489
58505
case X86ISD::BLENDI:
58490
58506
if (NumOps == 2 && VT.is512BitVector() && Subtarget.useBWIRegs()) {
58491
- uint64_t Mask0 = Ops[0].getConstantOperandVal(2);
58492
- uint64_t Mask1 = Ops[1].getConstantOperandVal(2);
58493
- // MVT::v16i16 has repeated blend mask.
58494
- if (Op0.getSimpleValueType() == MVT::v16i16) {
58495
- Mask0 = (Mask0 << 8) | Mask0;
58496
- Mask1 = (Mask1 << 8) | Mask1;
58497
- }
58498
- uint64_t Mask = (Mask1 << (VT.getVectorNumElements() / 2)) | Mask0;
58499
- MVT MaskSVT = MVT::getIntegerVT(VT.getVectorNumElements());
58500
- MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
58507
+ unsigned NumElts = VT.getVectorNumElements();
58508
+ APInt Mask = getBLENDIBlendMask(Ops[0]).zext(NumElts);
58509
+ Mask.insertBits(getBLENDIBlendMask(Ops[1]), NumElts / 2);
58510
+ MVT MaskSVT = MVT::getIntegerVT(NumElts);
58511
+ MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
58501
58512
SDValue Sel =
58502
58513
DAG.getBitcast(MaskVT, DAG.getConstant(Mask, DL, MaskSVT));
58503
58514
return DAG.getSelect(DL, VT, Sel, ConcatSubOperand(VT, Ops, 1),
0 commit comments