Skip to content

Commit 103119a

Browse files
authored
[WebAssembly] Lower wide SIMD i8 muls (#130785)
Currently, 'wide' i32 simd multiplication, with extended i8 elements, will perform the multiplication with i32 So, for IR like the following: ``` %wide.a = sext <8 x i8> %a to <8 x i32> %wide.b = sext <8 x i8> %a to <8 x i32> %mul = mul <8 x i32> %wide.a, %wide.b ret <8 x i32> %mul ``` We would generate the following sequence: ``` i16x8.extend_low_i8x16_s $push6=, $1 local.tee $push5=, $3=, $pop6 i32x4.extmul_low_i16x8_s $push0=, $pop5, $3 v128.store 0($0), $pop0 i8x16.shuffle $push1=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 i16x8.extend_low_i8x16_s $push4=, $pop1 local.tee $push3=, $1=, $pop4 i32x4.extmul_low_i16x8_s $push2=, $pop3, $1 v128.store 16($0), $pop2 return ``` But now we perform the multiplication with i16, resulting in: ``` i16x8.extmul_low_i8x16_s $push3=, $1, $1 local.tee $push2=, $1=, $pop3 i32x4.extend_high_i16x8_s $push0=, $pop2 v128.store 16($0), $pop0 i32x4.extend_low_i16x8_s $push1=, $1 v128.store 0($0), $pop1 return ```
1 parent 2089b08 commit 103119a

File tree

2 files changed

+290
-2
lines changed

2 files changed

+290
-2
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
183183
// Combine partial.reduce.add before legalization gets confused.
184184
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
185185

186+
// Combine wide-vector muls, with extend inputs, to extmul_half.
187+
setTargetDAGCombine(ISD::MUL);
188+
186189
// Combine vector mask reductions into alltrue/anytrue
187190
setTargetDAGCombine(ISD::SETCC);
188191

@@ -1461,8 +1464,7 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
14611464

14621465
bool WebAssemblyTargetLowering::CanLowerReturn(
14631466
CallingConv::ID /*CallConv*/, MachineFunction & /*MF*/, bool /*IsVarArg*/,
1464-
const SmallVectorImpl<ISD::OutputArg> &Outs,
1465-
LLVMContext & /*Context*/,
1467+
const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext & /*Context*/,
14661468
const Type *RetTy) const {
14671469
// WebAssembly can only handle returning tuples with multivalue enabled
14681470
return WebAssembly::canLowerReturn(Outs.size(), Subtarget);
@@ -3254,6 +3256,93 @@ static SDValue performSETCCCombine(SDNode *N,
32543256
return SDValue();
32553257
}
32563258

3259+
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
3260+
assert(N->getOpcode() == ISD::MUL);
3261+
EVT VT = N->getValueType(0);
3262+
if (VT != MVT::v8i32 && VT != MVT::v16i32)
3263+
return SDValue();
3264+
3265+
// Mul with extending inputs.
3266+
SDValue LHS = N->getOperand(0);
3267+
SDValue RHS = N->getOperand(1);
3268+
if (LHS.getOpcode() != RHS.getOpcode())
3269+
return SDValue();
3270+
3271+
if (LHS.getOpcode() != ISD::SIGN_EXTEND &&
3272+
LHS.getOpcode() != ISD::ZERO_EXTEND)
3273+
return SDValue();
3274+
3275+
if (LHS->getOperand(0).getValueType() != RHS->getOperand(0).getValueType())
3276+
return SDValue();
3277+
3278+
EVT FromVT = LHS->getOperand(0).getValueType();
3279+
EVT EltTy = FromVT.getVectorElementType();
3280+
if (EltTy != MVT::i8)
3281+
return SDValue();
3282+
3283+
// For an input DAG that looks like this
3284+
// %a = input_type
3285+
// %b = input_type
3286+
// %lhs = extend %a to output_type
3287+
// %rhs = extend %b to output_type
3288+
// %mul = mul %lhs, %rhs
3289+
3290+
// input_type | output_type | instructions
3291+
// v16i8 | v16i32 | %low = i16x8.extmul_low_i8x16_ %a, %b
3292+
// | | %high = i16x8.extmul_high_i8x16_, %a, %b
3293+
// | | %low_low = i32x4.ext_low_i16x8_ %low
3294+
// | | %low_high = i32x4.ext_high_i16x8_ %low
3295+
// | | %high_low = i32x4.ext_low_i16x8_ %high
3296+
// | | %high_high = i32x4.ext_high_i16x8_ %high
3297+
// | | %res = concat_vector(...)
3298+
// v8i8 | v8i32 | %low = i16x8.extmul_low_i8x16_ %a, %b
3299+
// | | %low_low = i32x4.ext_low_i16x8_ %low
3300+
// | | %low_high = i32x4.ext_high_i16x8_ %low
3301+
// | | %res = concat_vector(%low_low, %low_high)
3302+
3303+
SDLoc DL(N);
3304+
unsigned NumElts = VT.getVectorNumElements();
3305+
SDValue ExtendInLHS = LHS->getOperand(0);
3306+
SDValue ExtendInRHS = RHS->getOperand(0);
3307+
bool IsSigned = LHS->getOpcode() == ISD::SIGN_EXTEND;
3308+
unsigned ExtendLowOpc =
3309+
IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
3310+
unsigned ExtendHighOpc =
3311+
IsSigned ? WebAssemblyISD::EXTEND_HIGH_S : WebAssemblyISD::EXTEND_HIGH_U;
3312+
3313+
auto GetExtendLow = [&DAG, &DL, &ExtendLowOpc](EVT VT, SDValue Op) {
3314+
return DAG.getNode(ExtendLowOpc, DL, VT, Op);
3315+
};
3316+
auto GetExtendHigh = [&DAG, &DL, &ExtendHighOpc](EVT VT, SDValue Op) {
3317+
return DAG.getNode(ExtendHighOpc, DL, VT, Op);
3318+
};
3319+
3320+
if (NumElts == 16) {
3321+
SDValue LowLHS = GetExtendLow(MVT::v8i16, ExtendInLHS);
3322+
SDValue LowRHS = GetExtendLow(MVT::v8i16, ExtendInRHS);
3323+
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
3324+
SDValue HighLHS = GetExtendHigh(MVT::v8i16, ExtendInLHS);
3325+
SDValue HighRHS = GetExtendHigh(MVT::v8i16, ExtendInRHS);
3326+
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
3327+
SDValue SubVectors[] = {
3328+
GetExtendLow(MVT::v4i32, MulLow),
3329+
GetExtendHigh(MVT::v4i32, MulLow),
3330+
GetExtendLow(MVT::v4i32, MulHigh),
3331+
GetExtendHigh(MVT::v4i32, MulHigh),
3332+
};
3333+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SubVectors);
3334+
} else {
3335+
assert(NumElts == 8);
3336+
SDValue LowLHS = DAG.getNode(LHS->getOpcode(), DL, MVT::v8i16, ExtendInLHS);
3337+
SDValue LowRHS = DAG.getNode(RHS->getOpcode(), DL, MVT::v8i16, ExtendInRHS);
3338+
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
3339+
SDValue Lo = GetExtendLow(MVT::v4i32, MulLow);
3340+
SDValue Hi = GetExtendHigh(MVT::v4i32, MulLow);
3341+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
3342+
}
3343+
return SDValue();
3344+
}
3345+
32573346
SDValue
32583347
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
32593348
DAGCombinerInfo &DCI) const {
@@ -3281,5 +3370,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
32813370
return performTruncateCombine(N, DCI);
32823371
case ISD::INTRINSIC_WO_CHAIN:
32833372
return performLowerPartialReduction(N, DCI.DAG);
3373+
case ISD::MUL:
3374+
return performMulCombine(N, DCI.DAG);
32843375
}
32853376
}
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=wasm32 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
3+
4+
define <8 x i32> @sext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
5+
; CHECK-LABEL: sext_mul_v8i8:
6+
; CHECK: .functype sext_mul_v8i8 (i32, v128, v128) -> ()
7+
; CHECK-NEXT: # %bb.0:
8+
; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push3=, $1, $1
9+
; CHECK-NEXT: local.tee $push2=, $1=, $pop3
10+
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop2
11+
; CHECK-NEXT: v128.store 16($0), $pop0
12+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $1
13+
; CHECK-NEXT: v128.store 0($0), $pop1
14+
; CHECK-NEXT: return
15+
%wide.a = sext <8 x i8> %a to <8 x i32>
16+
%wide.b = sext <8 x i8> %a to <8 x i32>
17+
%mul = mul <8 x i32> %wide.a, %wide.b
18+
ret <8 x i32> %mul
19+
}
20+
21+
define <16 x i32> @sext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
22+
; CHECK-LABEL: sext_mul_v16i8:
23+
; CHECK: .functype sext_mul_v16i8 (i32, v128, v128) -> ()
24+
; CHECK-NEXT: # %bb.0:
25+
; CHECK-NEXT: i16x8.extmul_high_i8x16_s $push7=, $1, $1
26+
; CHECK-NEXT: local.tee $push6=, $3=, $pop7
27+
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop6
28+
; CHECK-NEXT: v128.store 48($0), $pop0
29+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $3
30+
; CHECK-NEXT: v128.store 32($0), $pop1
31+
; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push5=, $1, $1
32+
; CHECK-NEXT: local.tee $push4=, $1=, $pop5
33+
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push2=, $pop4
34+
; CHECK-NEXT: v128.store 16($0), $pop2
35+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
36+
; CHECK-NEXT: v128.store 0($0), $pop3
37+
; CHECK-NEXT: return
38+
%wide.a = sext <16 x i8> %a to <16 x i32>
39+
%wide.b = sext <16 x i8> %a to <16 x i32>
40+
%mul = mul <16 x i32> %wide.a, %wide.b
41+
ret <16 x i32> %mul
42+
}
43+
44+
define <8 x i32> @sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
45+
; CHECK-LABEL: sext_mul_v8i16:
46+
; CHECK: .functype sext_mul_v8i16 (i32, v128, v128) -> ()
47+
; CHECK-NEXT: # %bb.0:
48+
; CHECK-NEXT: i32x4.extmul_high_i16x8_s $push0=, $1, $1
49+
; CHECK-NEXT: v128.store 16($0), $pop0
50+
; CHECK-NEXT: i32x4.extmul_low_i16x8_s $push1=, $1, $1
51+
; CHECK-NEXT: v128.store 0($0), $pop1
52+
; CHECK-NEXT: return
53+
%wide.a = sext <8 x i16> %a to <8 x i32>
54+
%wide.b = sext <8 x i16> %a to <8 x i32>
55+
%mul = mul <8 x i32> %wide.a, %wide.b
56+
ret <8 x i32> %mul
57+
}
58+
59+
define <8 x i32> @zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
60+
; CHECK-LABEL: zext_mul_v8i8:
61+
; CHECK: .functype zext_mul_v8i8 (i32, v128, v128) -> ()
62+
; CHECK-NEXT: # %bb.0:
63+
; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push3=, $1, $1
64+
; CHECK-NEXT: local.tee $push2=, $1=, $pop3
65+
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop2
66+
; CHECK-NEXT: v128.store 16($0), $pop0
67+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $1
68+
; CHECK-NEXT: v128.store 0($0), $pop1
69+
; CHECK-NEXT: return
70+
%wide.a = zext <8 x i8> %a to <8 x i32>
71+
%wide.b = zext <8 x i8> %a to <8 x i32>
72+
%mul = mul <8 x i32> %wide.a, %wide.b
73+
ret <8 x i32> %mul
74+
}
75+
76+
define <16 x i32> @zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
77+
; CHECK-LABEL: zext_mul_v16i8:
78+
; CHECK: .functype zext_mul_v16i8 (i32, v128, v128) -> ()
79+
; CHECK-NEXT: # %bb.0:
80+
; CHECK-NEXT: i16x8.extmul_high_i8x16_u $push7=, $1, $1
81+
; CHECK-NEXT: local.tee $push6=, $3=, $pop7
82+
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop6
83+
; CHECK-NEXT: v128.store 48($0), $pop0
84+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $3
85+
; CHECK-NEXT: v128.store 32($0), $pop1
86+
; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push5=, $1, $1
87+
; CHECK-NEXT: local.tee $push4=, $1=, $pop5
88+
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push2=, $pop4
89+
; CHECK-NEXT: v128.store 16($0), $pop2
90+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push3=, $1
91+
; CHECK-NEXT: v128.store 0($0), $pop3
92+
; CHECK-NEXT: return
93+
%wide.a = zext <16 x i8> %a to <16 x i32>
94+
%wide.b = zext <16 x i8> %a to <16 x i32>
95+
%mul = mul <16 x i32> %wide.a, %wide.b
96+
ret <16 x i32> %mul
97+
}
98+
99+
define <8 x i32> @zext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
100+
; CHECK-LABEL: zext_mul_v8i16:
101+
; CHECK: .functype zext_mul_v8i16 (i32, v128, v128) -> ()
102+
; CHECK-NEXT: # %bb.0:
103+
; CHECK-NEXT: i32x4.extmul_high_i16x8_u $push0=, $1, $1
104+
; CHECK-NEXT: v128.store 16($0), $pop0
105+
; CHECK-NEXT: i32x4.extmul_low_i16x8_u $push1=, $1, $1
106+
; CHECK-NEXT: v128.store 0($0), $pop1
107+
; CHECK-NEXT: return
108+
%wide.a = zext <8 x i16> %a to <8 x i32>
109+
%wide.b = zext <8 x i16> %a to <8 x i32>
110+
%mul = mul <8 x i32> %wide.a, %wide.b
111+
ret <8 x i32> %mul
112+
}
113+
114+
define <8 x i32> @sext_zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
115+
; CHECK-LABEL: sext_zext_mul_v8i8:
116+
; CHECK: .functype sext_zext_mul_v8i8 (i32, v128, v128) -> ()
117+
; CHECK-NEXT: # %bb.0:
118+
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
119+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
120+
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
121+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
122+
; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
123+
; CHECK-NEXT: v128.store 0($0), $pop4
124+
; CHECK-NEXT: i8x16.shuffle $push11=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
125+
; CHECK-NEXT: local.tee $push10=, $1=, $pop11
126+
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop10
127+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
128+
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $1
129+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
130+
; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
131+
; CHECK-NEXT: v128.store 16($0), $pop9
132+
; CHECK-NEXT: return
133+
%wide.a = sext <8 x i8> %a to <8 x i32>
134+
%wide.b = zext <8 x i8> %a to <8 x i32>
135+
%mul = mul <8 x i32> %wide.a, %wide.b
136+
ret <8 x i32> %mul
137+
}
138+
139+
define <16 x i32> @sext_zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
140+
; CHECK-LABEL: sext_zext_mul_v16i8:
141+
; CHECK: .functype sext_zext_mul_v16i8 (i32, v128, v128) -> ()
142+
; CHECK-NEXT: # %bb.0:
143+
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
144+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
145+
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
146+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
147+
; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
148+
; CHECK-NEXT: v128.store 0($0), $pop4
149+
; CHECK-NEXT: i8x16.shuffle $push25=, $1, $1, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
150+
; CHECK-NEXT: local.tee $push24=, $3=, $pop25
151+
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop24
152+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
153+
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $3
154+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
155+
; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
156+
; CHECK-NEXT: v128.store 48($0), $pop9
157+
; CHECK-NEXT: i8x16.shuffle $push23=, $1, $1, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0
158+
; CHECK-NEXT: local.tee $push22=, $3=, $pop23
159+
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push12=, $pop22
160+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push13=, $pop12
161+
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push10=, $3
162+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push11=, $pop10
163+
; CHECK-NEXT: i32x4.mul $push14=, $pop13, $pop11
164+
; CHECK-NEXT: v128.store 32($0), $pop14
165+
; CHECK-NEXT: i8x16.shuffle $push21=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
166+
; CHECK-NEXT: local.tee $push20=, $1=, $pop21
167+
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push17=, $pop20
168+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push18=, $pop17
169+
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push15=, $1
170+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push16=, $pop15
171+
; CHECK-NEXT: i32x4.mul $push19=, $pop18, $pop16
172+
; CHECK-NEXT: v128.store 16($0), $pop19
173+
; CHECK-NEXT: return
174+
%wide.a = sext <16 x i8> %a to <16 x i32>
175+
%wide.b = zext <16 x i8> %a to <16 x i32>
176+
%mul = mul <16 x i32> %wide.a, %wide.b
177+
ret <16 x i32> %mul
178+
}
179+
180+
define <8 x i32> @zext_sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
181+
; CHECK-LABEL: zext_sext_mul_v8i16:
182+
; CHECK: .functype zext_sext_mul_v8i16 (i32, v128, v128) -> ()
183+
; CHECK-NEXT: # %bb.0:
184+
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push1=, $1
185+
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $1
186+
; CHECK-NEXT: i32x4.mul $push2=, $pop1, $pop0
187+
; CHECK-NEXT: v128.store 16($0), $pop2
188+
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push4=, $1
189+
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
190+
; CHECK-NEXT: i32x4.mul $push5=, $pop4, $pop3
191+
; CHECK-NEXT: v128.store 0($0), $pop5
192+
; CHECK-NEXT: return
193+
%wide.a = zext <8 x i16> %a to <8 x i32>
194+
%wide.b = sext <8 x i16> %a to <8 x i32>
195+
%mul = mul <8 x i32> %wide.a, %wide.b
196+
ret <8 x i32> %mul
197+
}

0 commit comments

Comments
 (0)