Skip to content

Commit ca4b1f8

Browse files
committed
[X86] computeKnownBitsForTargetNode - add handling for PMADDWD/PMADDUBSW nodes
These were reverted in fa0e9ac while we triaged an infinite loop regression
1 parent a52be0c commit ca4b1f8

File tree

2 files changed

+94
-29
lines changed

2 files changed

+94
-29
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37139,6 +37139,52 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
3713937139
Known = Known.zext(64);
3714037140
}
3714137141

37142+
static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
37143+
KnownBits &Known,
37144+
const APInt &DemandedElts,
37145+
const SelectionDAG &DAG,
37146+
unsigned Depth) {
37147+
unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37148+
37149+
// Multiply signed i16 elements to create i32 values and add Lo/Hi pairs.
37150+
APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37151+
APInt DemandedLoElts =
37152+
DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
37153+
APInt DemandedHiElts =
37154+
DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
37155+
KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
37156+
KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
37157+
KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
37158+
KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
37159+
KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
37160+
KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
37161+
Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
37162+
/*NUW=*/false, Lo, Hi);
37163+
}
37164+
37165+
static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
37166+
KnownBits &Known,
37167+
const APInt &DemandedElts,
37168+
const SelectionDAG &DAG,
37169+
unsigned Depth) {
37170+
unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37171+
37172+
// Multiply unsigned/signed i8 elements to create i16 values and add_sat Lo/Hi
37173+
// pairs.
37174+
APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37175+
APInt DemandedLoElts =
37176+
DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
37177+
APInt DemandedHiElts =
37178+
DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
37179+
KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
37180+
KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
37181+
KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
37182+
KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
37183+
KnownBits Lo = KnownBits::mul(LHSLo.zext(16), RHSLo.sext(16));
37184+
KnownBits Hi = KnownBits::mul(LHSHi.zext(16), RHSHi.sext(16));
37185+
Known = KnownBits::sadd_sat(Lo, Hi);
37186+
}
37187+
3714237188
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3714337189
KnownBits &Known,
3714437190
const APInt &DemandedElts,
@@ -37314,6 +37360,26 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3731437360
}
3731537361
break;
3731637362
}
37363+
case X86ISD::VPMADDWD: {
37364+
SDValue LHS = Op.getOperand(0);
37365+
SDValue RHS = Op.getOperand(1);
37366+
assert(VT.getVectorElementType() == MVT::i32 &&
37367+
LHS.getValueType() == RHS.getValueType() &&
37368+
LHS.getValueType().getVectorElementType() == MVT::i16 &&
37369+
"Unexpected PMADDWD types");
37370+
computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
37371+
break;
37372+
}
37373+
case X86ISD::VPMADDUBSW: {
37374+
SDValue LHS = Op.getOperand(0);
37375+
SDValue RHS = Op.getOperand(1);
37376+
assert(VT.getVectorElementType() == MVT::i16 &&
37377+
LHS.getValueType() == RHS.getValueType() &&
37378+
LHS.getValueType().getVectorElementType() == MVT::i8 &&
37379+
"Unexpected PMADDUBSW types");
37380+
computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37381+
break;
37382+
}
3731737383
case X86ISD::PMULUDQ: {
3731837384
KnownBits Known2;
3731937385
Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37450,6 +37516,30 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3745037516
}
3745137517
case ISD::INTRINSIC_WO_CHAIN: {
3745237518
switch (Op->getConstantOperandVal(0)) {
37519+
case Intrinsic::x86_sse2_pmadd_wd:
37520+
case Intrinsic::x86_avx2_pmadd_wd:
37521+
case Intrinsic::x86_avx512_pmaddw_d_512: {
37522+
SDValue LHS = Op.getOperand(1);
37523+
SDValue RHS = Op.getOperand(2);
37524+
assert(VT.getScalarType() == MVT::i32 &&
37525+
LHS.getValueType() == RHS.getValueType() &&
37526+
LHS.getValueType().getScalarType() == MVT::i16 &&
37527+
"Unexpected PMADDWD types");
37528+
computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
37529+
break;
37530+
}
37531+
case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
37532+
case Intrinsic::x86_avx2_pmadd_ub_sw:
37533+
case Intrinsic::x86_avx512_pmaddubs_w_512: {
37534+
SDValue LHS = Op.getOperand(1);
37535+
SDValue RHS = Op.getOperand(2);
37536+
assert(VT.getScalarType() == MVT::i16 &&
37537+
LHS.getValueType() == RHS.getValueType() &&
37538+
LHS.getValueType().getScalarType() == MVT::i8 &&
37539+
"Unexpected PMADDUBSW types");
37540+
computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37541+
break;
37542+
}
3745337543
case Intrinsic::x86_sse2_psad_bw:
3745437544
case Intrinsic::x86_avx2_psad_bw:
3745537545
case Intrinsic::x86_avx512_psad_bw_512: {

llvm/test/CodeGen/X86/combine-pmadd.ll

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -229,35 +229,10 @@ define i32 @combine_pmaddubsw_constant_sat() {
229229

230230
; Constant folding PMADDWD was causing an infinite loop in the PCMPGT commuting between 2 constant values.
231231
define i1 @pmaddwd_pcmpgt_infinite_loop() {
232-
; SSE-LABEL: pmaddwd_pcmpgt_infinite_loop:
233-
; SSE: # %bb.0:
234-
; SSE-NEXT: movdqa {{.*#+}} xmm0 = [2147483647,2147483647,2147483647,2147483647]
235-
; SSE-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
236-
; SSE-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
237-
; SSE-NEXT: movmskps %xmm0, %eax
238-
; SSE-NEXT: testl %eax, %eax
239-
; SSE-NEXT: sete %al
240-
; SSE-NEXT: retq
241-
;
242-
; AVX1-LABEL: pmaddwd_pcmpgt_infinite_loop:
243-
; AVX1: # %bb.0:
244-
; AVX1-NEXT: vpcmpeqd %xmm0, %xmm0, %xmm0
245-
; AVX1-NEXT: vbroadcastss {{.*#+}} xmm1 = [2147483647,2147483647,2147483647,2147483647]
246-
; AVX1-NEXT: vpaddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
247-
; AVX1-NEXT: vpcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
248-
; AVX1-NEXT: vtestps %xmm1, %xmm0
249-
; AVX1-NEXT: sete %al
250-
; AVX1-NEXT: retq
251-
;
252-
; AVX2-LABEL: pmaddwd_pcmpgt_infinite_loop:
253-
; AVX2: # %bb.0:
254-
; AVX2-NEXT: vpcmpeqd %xmm0, %xmm0, %xmm0
255-
; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm1 = [2147483647,2147483647,2147483647,2147483647]
256-
; AVX2-NEXT: vpaddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
257-
; AVX2-NEXT: vpcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
258-
; AVX2-NEXT: vtestps %xmm1, %xmm0
259-
; AVX2-NEXT: sete %al
260-
; AVX2-NEXT: retq
232+
; CHECK-LABEL: pmaddwd_pcmpgt_infinite_loop:
233+
; CHECK: # %bb.0:
234+
; CHECK-NEXT: movb $1, %al
235+
; CHECK-NEXT: retq
261236
%1 = tail call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>, <8 x i16> <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>)
262237
%2 = icmp eq <4 x i32> %1, <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
263238
%3 = select <4 x i1> %2, <4 x i32> <i32 2147483647, i32 2147483647, i32 2147483647, i32 2147483647>, <4 x i32> zeroinitializer

0 commit comments

Comments
 (0)