Skip to content

Commit 9de14e2

Browse files
committed
[InstCombine][X86] Add zero arg handling for PMADDWD/PMADDUBSW intrinsics
PMADDWD/PMADDUBSW - multiply by zero folds Initial setup to handle future PMADDWD/PMADDUBSW simplification / constant folding
1 parent e1751a1 commit 9de14e2

File tree

3 files changed

+50
-28
lines changed

3 files changed

+50
-28
lines changed

llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,24 @@ static Value *simplifyX86pack(IntrinsicInst &II,
502502
return Builder.CreateTrunc(Shuffle, ResTy);
503503
}
504504

505+
static Value *simplifyX86pmadd(IntrinsicInst &II,
506+
InstCombiner::BuilderTy &Builder) {
507+
Value *Arg0 = II.getArgOperand(0);
508+
Value *Arg1 = II.getArgOperand(1);
509+
auto *ResTy = cast<FixedVectorType>(II.getType());
510+
[[maybe_unused]] auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
511+
512+
assert(ArgTy->getNumElements() == (2 * ResTy->getNumElements()) &&
513+
ResTy->getScalarSizeInBits() == (2 * ArgTy->getScalarSizeInBits()) &&
514+
"Unexpected PMADD types");
515+
516+
// Multiply by zero.
517+
if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1))
518+
return ConstantAggregateZero::get(ResTy);
519+
520+
return nullptr;
521+
}
522+
505523
static Value *simplifyX86movmsk(const IntrinsicInst &II,
506524
InstCombiner::BuilderTy &Builder) {
507525
Value *Arg = II.getArgOperand(0);
@@ -2478,6 +2496,22 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
24782496
}
24792497
break;
24802498

2499+
case Intrinsic::x86_sse2_pmadd_wd:
2500+
case Intrinsic::x86_avx2_pmadd_wd:
2501+
case Intrinsic::x86_avx512_pmaddw_d_512:
2502+
if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
2503+
return IC.replaceInstUsesWith(II, V);
2504+
}
2505+
break;
2506+
2507+
case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
2508+
case Intrinsic::x86_avx2_pmadd_ub_sw:
2509+
case Intrinsic::x86_avx512_pmaddubs_w_512:
2510+
if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
2511+
return IC.replaceInstUsesWith(II, V);
2512+
}
2513+
break;
2514+
24812515
case Intrinsic::x86_pclmulqdq:
24822516
case Intrinsic::x86_pclmulqdq_256:
24832517
case Intrinsic::x86_pclmulqdq_512: {

llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,53 +38,47 @@ define <32 x i16> @undef_pmaddubsw_512() {
3838

3939
define <8 x i16> @zero_pmaddubsw_128(<16 x i8> %a0) {
4040
; CHECK-LABEL: @zero_pmaddubsw_128(
41-
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> [[A0:%.*]], <16 x i8> zeroinitializer)
42-
; CHECK-NEXT: ret <8 x i16> [[TMP1]]
41+
; CHECK-NEXT: ret <8 x i16> zeroinitializer
4342
;
4443
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a0, <16 x i8> zeroinitializer)
4544
ret <8 x i16> %1
4645
}
4746

4847
define <8 x i16> @zero_pmaddubsw_128_commute(<16 x i8> %a0) {
4948
; CHECK-LABEL: @zero_pmaddubsw_128_commute(
50-
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> zeroinitializer, <16 x i8> [[A0:%.*]])
51-
; CHECK-NEXT: ret <8 x i16> [[TMP1]]
49+
; CHECK-NEXT: ret <8 x i16> zeroinitializer
5250
;
5351
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> zeroinitializer, <16 x i8> %a0)
5452
ret <8 x i16> %1
5553
}
5654

5755
define <16 x i16> @zero_pmaddubsw_256(<32 x i8>%a0) {
5856
; CHECK-LABEL: @zero_pmaddubsw_256(
59-
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> [[A0:%.*]], <32 x i8> zeroinitializer)
60-
; CHECK-NEXT: ret <16 x i16> [[TMP1]]
57+
; CHECK-NEXT: ret <16 x i16> zeroinitializer
6158
;
6259
%1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a0, <32 x i8> zeroinitializer)
6360
ret <16 x i16> %1
6461
}
6562

6663
define <16 x i16> @zero_pmaddubsw_256_commute(<32 x i8> %a0) {
6764
; CHECK-LABEL: @zero_pmaddubsw_256_commute(
68-
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> zeroinitializer, <32 x i8> [[A0:%.*]])
69-
; CHECK-NEXT: ret <16 x i16> [[TMP1]]
65+
; CHECK-NEXT: ret <16 x i16> zeroinitializer
7066
;
7167
%1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> zeroinitializer, <32 x i8> %a0)
7268
ret <16 x i16> %1
7369
}
7470

7571
define <32 x i16> @zero_pmaddubsw_512(<64 x i8> %a0) {
7672
; CHECK-LABEL: @zero_pmaddubsw_512(
77-
; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> [[A0:%.*]], <64 x i8> zeroinitializer)
78-
; CHECK-NEXT: ret <32 x i16> [[TMP1]]
73+
; CHECK-NEXT: ret <32 x i16> zeroinitializer
7974
;
8075
%1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> %a0, <64 x i8> zeroinitializer)
8176
ret <32 x i16> %1
8277
}
8378

84-
define <32 x i16> @zero_pmaddubsw_512_commuite(<64 x i8> %a0) {
85-
; CHECK-LABEL: @zero_pmaddubsw_512_commuite(
86-
; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> zeroinitializer, <64 x i8> [[A0:%.*]])
87-
; CHECK-NEXT: ret <32 x i16> [[TMP1]]
79+
define <32 x i16> @zero_pmaddubsw_512_commute(<64 x i8> %a0) {
80+
; CHECK-LABEL: @zero_pmaddubsw_512_commute(
81+
; CHECK-NEXT: ret <32 x i16> zeroinitializer
8882
;
8983
%1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> zeroinitializer, <64 x i8> %a0)
9084
ret <32 x i16> %1

llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,53 +38,47 @@ define <16 x i32> @undef_pmaddwd_512() {
3838

3939
define <4 x i32> @zero_pmaddwd_128(<8 x i16> %a0) {
4040
; CHECK-LABEL: @zero_pmaddwd_128(
41-
; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> [[A0:%.*]], <8 x i16> zeroinitializer)
42-
; CHECK-NEXT: ret <4 x i32> [[TMP1]]
41+
; CHECK-NEXT: ret <4 x i32> zeroinitializer
4342
;
4443
%1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a0, <8 x i16> zeroinitializer)
4544
ret <4 x i32> %1
4645
}
4746

4847
define <4 x i32> @zero_pmaddwd_128_commute(<8 x i16> %a0) {
4948
; CHECK-LABEL: @zero_pmaddwd_128_commute(
50-
; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> zeroinitializer, <8 x i16> [[A0:%.*]])
51-
; CHECK-NEXT: ret <4 x i32> [[TMP1]]
49+
; CHECK-NEXT: ret <4 x i32> zeroinitializer
5250
;
5351
%1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> zeroinitializer, <8 x i16> %a0)
5452
ret <4 x i32> %1
5553
}
5654

5755
define <8 x i32> @zero_pmaddwd_256(<16 x i16> %a0) {
5856
; CHECK-LABEL: @zero_pmaddwd_256(
59-
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> [[A0:%.*]], <16 x i16> zeroinitializer)
60-
; CHECK-NEXT: ret <8 x i32> [[TMP1]]
57+
; CHECK-NEXT: ret <8 x i32> zeroinitializer
6158
;
6259
%1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a0, <16 x i16> zeroinitializer)
6360
ret <8 x i32> %1
6461
}
6562

6663
define <8 x i32> @zero_pmaddwd_256_commute(<16 x i16> %a0) {
6764
; CHECK-LABEL: @zero_pmaddwd_256_commute(
68-
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> zeroinitializer, <16 x i16> [[A0:%.*]])
69-
; CHECK-NEXT: ret <8 x i32> [[TMP1]]
65+
; CHECK-NEXT: ret <8 x i32> zeroinitializer
7066
;
7167
%1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> zeroinitializer, <16 x i16> %a0)
7268
ret <8 x i32> %1
7369
}
7470

7571
define <16 x i32> @zero_pmaddwd_512(<32 x i16> %a0) {
7672
; CHECK-LABEL: @zero_pmaddwd_512(
77-
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> [[A0:%.*]], <32 x i16> zeroinitializer)
78-
; CHECK-NEXT: ret <16 x i32> [[TMP1]]
73+
; CHECK-NEXT: ret <16 x i32> zeroinitializer
7974
;
8075
%1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a0, <32 x i16> zeroinitializer)
8176
ret <16 x i32> %1
8277
}
8378

84-
define <16 x i32> @zero_pmaddwd_512_commuite(<32 x i16> %a0) {
85-
; CHECK-LABEL: @zero_pmaddwd_512_commuite(
86-
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> zeroinitializer, <32 x i16> [[A0:%.*]])
87-
; CHECK-NEXT: ret <16 x i32> [[TMP1]]
79+
define <16 x i32> @zero_pmaddwd_512_commute(<32 x i16> %a0) {
80+
; CHECK-LABEL: @zero_pmaddwd_512_commute(
81+
; CHECK-NEXT: ret <16 x i32> zeroinitializer
8882
;
8983
%1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> zeroinitializer, <32 x i16> %a0)
9084
ret <16 x i32> %1

0 commit comments

Comments
 (0)