@@ -37139,6 +37139,52 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
37139
37139
Known = Known.zext(64);
37140
37140
}
37141
37141
37142
+ static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
37143
+ KnownBits &Known,
37144
+ const APInt &DemandedElts,
37145
+ const SelectionDAG &DAG,
37146
+ unsigned Depth) {
37147
+ unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37148
+
37149
+ // Multiply signed i16 elements to create i32 values and add Lo/Hi pairs.
37150
+ APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37151
+ APInt DemandedLoElts =
37152
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
37153
+ APInt DemandedHiElts =
37154
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
37155
+ KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
37156
+ KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
37157
+ KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
37158
+ KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
37159
+ KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
37160
+ KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
37161
+ Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
37162
+ /*NUW=*/false, Lo, Hi);
37163
+ }
37164
+
37165
+ static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
37166
+ KnownBits &Known,
37167
+ const APInt &DemandedElts,
37168
+ const SelectionDAG &DAG,
37169
+ unsigned Depth) {
37170
+ unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37171
+
37172
+ // Multiply unsigned/signed i8 elements to create i16 values and add_sat Lo/Hi
37173
+ // pairs.
37174
+ APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37175
+ APInt DemandedLoElts =
37176
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
37177
+ APInt DemandedHiElts =
37178
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
37179
+ KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
37180
+ KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
37181
+ KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
37182
+ KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
37183
+ KnownBits Lo = KnownBits::mul(LHSLo.zext(16), RHSLo.sext(16));
37184
+ KnownBits Hi = KnownBits::mul(LHSHi.zext(16), RHSHi.sext(16));
37185
+ Known = KnownBits::sadd_sat(Lo, Hi);
37186
+ }
37187
+
37142
37188
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
37143
37189
KnownBits &Known,
37144
37190
const APInt &DemandedElts,
@@ -37314,6 +37360,26 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
37314
37360
}
37315
37361
break;
37316
37362
}
37363
+ case X86ISD::VPMADDWD: {
37364
+ SDValue LHS = Op.getOperand(0);
37365
+ SDValue RHS = Op.getOperand(1);
37366
+ assert(VT.getVectorElementType() == MVT::i32 &&
37367
+ LHS.getValueType() == RHS.getValueType() &&
37368
+ LHS.getValueType().getVectorElementType() == MVT::i16 &&
37369
+ "Unexpected PMADDWD types");
37370
+ computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
37371
+ break;
37372
+ }
37373
+ case X86ISD::VPMADDUBSW: {
37374
+ SDValue LHS = Op.getOperand(0);
37375
+ SDValue RHS = Op.getOperand(1);
37376
+ assert(VT.getVectorElementType() == MVT::i16 &&
37377
+ LHS.getValueType() == RHS.getValueType() &&
37378
+ LHS.getValueType().getVectorElementType() == MVT::i8 &&
37379
+ "Unexpected PMADDUBSW types");
37380
+ computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37381
+ break;
37382
+ }
37317
37383
case X86ISD::PMULUDQ: {
37318
37384
KnownBits Known2;
37319
37385
Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37450,6 +37516,30 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
37450
37516
}
37451
37517
case ISD::INTRINSIC_WO_CHAIN: {
37452
37518
switch (Op->getConstantOperandVal(0)) {
37519
+ case Intrinsic::x86_sse2_pmadd_wd:
37520
+ case Intrinsic::x86_avx2_pmadd_wd:
37521
+ case Intrinsic::x86_avx512_pmaddw_d_512: {
37522
+ SDValue LHS = Op.getOperand(1);
37523
+ SDValue RHS = Op.getOperand(2);
37524
+ assert(VT.getScalarType() == MVT::i32 &&
37525
+ LHS.getValueType() == RHS.getValueType() &&
37526
+ LHS.getValueType().getScalarType() == MVT::i16 &&
37527
+ "Unexpected PMADDWD types");
37528
+ computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
37529
+ break;
37530
+ }
37531
+ case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
37532
+ case Intrinsic::x86_avx2_pmadd_ub_sw:
37533
+ case Intrinsic::x86_avx512_pmaddubs_w_512: {
37534
+ SDValue LHS = Op.getOperand(1);
37535
+ SDValue RHS = Op.getOperand(2);
37536
+ assert(VT.getScalarType() == MVT::i16 &&
37537
+ LHS.getValueType() == RHS.getValueType() &&
37538
+ LHS.getValueType().getScalarType() == MVT::i8 &&
37539
+ "Unexpected PMADDUBSW types");
37540
+ computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37541
+ break;
37542
+ }
37453
37543
case Intrinsic::x86_sse2_psad_bw:
37454
37544
case Intrinsic::x86_avx2_psad_bw:
37455
37545
case Intrinsic::x86_avx512_psad_bw_512: {
0 commit comments