Skip to content

Commit c7e0f1e

Browse files
authored
[X86] Allow input vector extracted from larger vector when combining to VPMADDUBSW (#89584)
Failed on main trunk: https://godbolt.org/z/edWMz8chE
1 parent 35b292e commit c7e0f1e

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51841,6 +51841,17 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
5184151841
return SDValue();
5184251842
}
5184351843

51844+
auto ExtractVec = [&DAG, &DL, NumElems](SDValue &Ext) {
51845+
EVT ExtVT = Ext.getValueType();
51846+
if (ExtVT.getVectorNumElements() != NumElems * 2) {
51847+
MVT NVT = MVT::getVectorVT(MVT::i8, NumElems * 2);
51848+
Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, Ext,
51849+
DAG.getIntPtrConstant(0, DL));
51850+
}
51851+
};
51852+
ExtractVec(ZExtIn);
51853+
ExtractVec(SExtIn);
51854+
5184451855
auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
5184551856
ArrayRef<SDValue> Ops) {
5184651857
// Shrink by adding truncate nodes and let DAGCombine fold with the

llvm/test/CodeGen/X86/pmaddubsw.ll

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,41 @@ define <8 x i16> @pmaddubsw_bad_indices(ptr %Aptr, ptr %Bptr) {
469469
%trunc = trunc <8 x i32> %min to <8 x i16>
470470
ret <8 x i16> %trunc
471471
}
472+
473+
define <8 x i16> @pmaddubsw_large_vector(ptr %p1, ptr %p2) {
474+
; SSE-LABEL: pmaddubsw_large_vector:
475+
; SSE: # %bb.0:
476+
; SSE-NEXT: movdqa (%rdi), %xmm0
477+
; SSE-NEXT: pmaddubsw (%rsi), %xmm0
478+
; SSE-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
479+
; SSE-NEXT: retq
480+
;
481+
; AVX-LABEL: pmaddubsw_large_vector:
482+
; AVX: # %bb.0:
483+
; AVX-NEXT: vmovdqa (%rdi), %xmm0
484+
; AVX-NEXT: vpmaddubsw (%rsi), %xmm0, %xmm0
485+
; AVX-NEXT: vpxor %xmm1, %xmm1, %xmm1
486+
; AVX-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1],xmm0[2],xmm1[3,4],xmm0[5],xmm1[6],xmm0[7]
487+
; AVX-NEXT: retq
488+
%1 = load <64 x i8>, ptr %p1, align 64
489+
%2 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
490+
%3 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
491+
%4 = load <32 x i8>, ptr %p2, align 64
492+
%5 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
493+
%6 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
494+
%7 = sext <8 x i8> %5 to <8 x i32>
495+
%8 = zext <8 x i8> %2 to <8 x i32>
496+
%9 = mul nsw <8 x i32> %7, %8
497+
%10 = sext <8 x i8> %6 to <8 x i32>
498+
%11 = zext <8 x i8> %3 to <8 x i32>
499+
%12 = mul nsw <8 x i32> %10, %11
500+
%13 = add nsw <8 x i32> %9, %12
501+
%14 = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %13, <8 x i32> <i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767>)
502+
%15 = tail call <8 x i32> @llvm.smax.v8i32(<8 x i32> %14, <8 x i32> <i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768>)
503+
%16 = trunc <8 x i32> %15 to <8 x i16>
504+
%17 = shufflevector <8 x i16> zeroinitializer, <8 x i16> %16, <8 x i32> <i32 0, i32 1, i32 10, i32 3, i32 4, i32 13, i32 6, i32 15>
505+
ret <8 x i16> %17
506+
}
507+
508+
declare <8 x i32> @llvm.smin.v8i32(<8 x i32>, <8 x i32>)
509+
declare <8 x i32> @llvm.smax.v8i32(<8 x i32>, <8 x i32>)

0 commit comments

Comments
 (0)