[WebAssembly] Lower wide SIMD i8 muls #130785
Merged
Conversation
Currently, 'wide' i32 SIMD multiplication with extended i8 elements performs the multiplication in i32. So, for IR like the following:

  %wide.a = sext <8 x i8> %a to <8 x i32>
  %wide.b = sext <8 x i8> %a to <8 x i32>
  %mul = mul <8 x i32> %wide.a, %wide.b
  ret <8 x i32> %mul

we would generate the following sequence:

  i16x8.extend_low_i8x16_s $push6=, $1
  local.tee $push5=, $3=, $pop6
  i32x4.extmul_low_i16x8_s $push0=, $pop5, $3
  v128.store 0($0), $pop0
  i8x16.shuffle $push1=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
  i16x8.extend_low_i8x16_s $push4=, $pop1
  local.tee $push3=, $1=, $pop4
  i32x4.extmul_low_i16x8_s $push2=, $pop3, $1
  v128.store 16($0), $pop2
  return

But now we perform the multiplication in i16 (an i8 x i8 product always fits in i16, so widening the i16 product to i32 afterwards loses nothing), resulting in:

  i16x8.extmul_low_i8x16_s $push3=, $1, $1
  local.tee $push2=, $1=, $pop3
  i32x4.extend_high_i16x8_s $push0=, $pop2
  v128.store 16($0), $pop0
  i32x4.extend_low_i16x8_s $push1=, $1
  v128.store 0($0), $pop1
  return
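As an illustration (not taken from the patch), this widening-multiply pattern typically comes from source like the C++ loop below; the function name and signature are hypothetical, but a vectorizer targeting wasm SIMD will widen each i8 operand to i32 and produce IR similar to the snippet above.

  #include <cstdint>

  // Hypothetical example: each i8 element is extended to i32 before the
  // multiply, which is the mul-of-extends pattern the new DAG combine
  // looks for after vectorization.
  void widen_mul_i8(int32_t *dst, const int8_t *a, const int8_t *b, int n) {
    for (int i = 0; i < n; ++i)
      dst[i] = static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
  }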
@llvm/pr-subscribers-backend-webassembly
Author: Sam Parker (sparker-arm)
Changes: as in the description above, wide i32 SIMD multiplications of extended i8 elements are now performed in i16 instead of i32.
Full diff: https://github.com/llvm/llvm-project/pull/130785.diff (2 files affected):
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index b24a45c2d8898..9ae46e709d823 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -183,6 +183,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+ // Combine wide-vector muls, with extend inputs, to extmul_half.
+ setTargetDAGCombine(ISD::MUL);
+
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -1461,8 +1464,7 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
bool WebAssemblyTargetLowering::CanLowerReturn(
CallingConv::ID /*CallConv*/, MachineFunction & /*MF*/, bool /*IsVarArg*/,
- const SmallVectorImpl<ISD::OutputArg> &Outs,
- LLVMContext & /*Context*/,
+ const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext & /*Context*/,
const Type *RetTy) const {
// WebAssembly can only handle returning tuples with multivalue enabled
return WebAssembly::canLowerReturn(Outs.size(), Subtarget);
@@ -3254,6 +3256,93 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::MUL);
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::v8i32 && VT != MVT::v16i32)
+ return SDValue();
+
+ // Mul with extending inputs.
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ if (LHS.getOpcode() != RHS.getOpcode())
+ return SDValue();
+
+ if (LHS.getOpcode() != ISD::SIGN_EXTEND &&
+ LHS.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+
+ if (LHS->getOperand(0).getValueType() != RHS->getOperand(0).getValueType())
+ return SDValue();
+
+ EVT FromVT = LHS->getOperand(0).getValueType();
+ EVT EltTy = FromVT.getVectorElementType();
+ if (EltTy != MVT::i8)
+ return SDValue();
+
+ // For an input DAG that looks like this
+ // %a = input_type
+ // %b = input_type
+ // %lhs = extend %a to output_type
+ // %rhs = extend %b to output_type
+ // %mul = mul %lhs, %rhs
+
+ // input_type | output_type | instructions
+ // v16i8 | v16i32 | %low = i16x8.extmul_low_i8x16_ %a, %b
+ // | | %high = i16x8.extmul_high_i8x16_ %a, %b
+ // | | %low_low = i32x4.ext_low_i16x8_ %low
+ // | | %low_high = i32x4.ext_high_i16x8_ %low
+ // | | %high_low = i32x4.ext_low_i16x8_ %high
+ // | | %high_high = i32x4.ext_high_i16x8_ %high
+ // | | %res = concat_vector(...)
+ // v8i8 | v8i32 | %low = i16x8.extmul_low_i8x16_ %a, %b
+ // | | %low_low = i32x4.ext_low_i16x8_ %low
+ // | | %low_high = i32x4.ext_high_i16x8_ %low
+ // | | %res = concat_vector(%low_low, %low_high)
+
+ SDLoc DL(N);
+ unsigned NumElts = VT.getVectorNumElements();
+ SDValue ExtendInLHS = LHS->getOperand(0);
+ SDValue ExtendInRHS = RHS->getOperand(0);
+ bool IsSigned = LHS->getOpcode() == ISD::SIGN_EXTEND;
+ unsigned ExtendLowOpc =
+ IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
+ unsigned ExtendHighOpc =
+ IsSigned ? WebAssemblyISD::EXTEND_HIGH_S : WebAssemblyISD::EXTEND_HIGH_U;
+
+ auto GetExtendLow = [&DAG, &DL, &ExtendLowOpc](EVT VT, SDValue Op) {
+ return DAG.getNode(ExtendLowOpc, DL, VT, Op);
+ };
+ auto GetExtendHigh = [&DAG, &DL, &ExtendHighOpc](EVT VT, SDValue Op) {
+ return DAG.getNode(ExtendHighOpc, DL, VT, Op);
+ };
+
+ if (NumElts == 16) {
+ SDValue LowLHS = GetExtendLow(MVT::v8i16, ExtendInLHS);
+ SDValue LowRHS = GetExtendLow(MVT::v8i16, ExtendInRHS);
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+ SDValue HighLHS = GetExtendHigh(MVT::v8i16, ExtendInLHS);
+ SDValue HighRHS = GetExtendHigh(MVT::v8i16, ExtendInRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+ SDValue SubVectors[] = {
+ GetExtendLow(MVT::v4i32, MulLow),
+ GetExtendHigh(MVT::v4i32, MulLow),
+ GetExtendLow(MVT::v4i32, MulHigh),
+ GetExtendHigh(MVT::v4i32, MulHigh),
+ };
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SubVectors);
+ } else {
+ assert(NumElts == 8);
+ SDValue LowLHS = DAG.getNode(LHS->getOpcode(), DL, MVT::v8i16, ExtendInLHS);
+ SDValue LowRHS = DAG.getNode(RHS->getOpcode(), DL, MVT::v8i16, ExtendInRHS);
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+ SDValue Lo = GetExtendLow(MVT::v4i32, MulLow);
+ SDValue Hi = GetExtendHigh(MVT::v4i32, MulLow);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+ }
+ return SDValue();
+}
+
SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -3281,5 +3370,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN:
return performLowerPartialReduction(N, DCI.DAG);
+ case ISD::MUL:
+ return performMulCombine(N, DCI.DAG);
}
}
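For the v16i8 to v16i32 path described in the comment table inside performMulCombine above, the lane bookkeeping can be easier to follow as a scalar model. The sketch below is an illustration only (plain C++, not part of the patch): the two extmul halves produce v8i16 products, and the four extend_low/extend_high results concatenate back into the sixteen i32 lanes in their original order.

  #include <array>
  #include <cstdint>

  // Scalar model of extmul_low/high on i8 lanes followed by extend_low/high
  // to i32, concatenated as in the combine: res[i] == a[i] * b[i] for all i.
  std::array<int32_t, 16> model_v16(const std::array<int8_t, 16> &a,
                                    const std::array<int8_t, 16> &b) {
    std::array<int16_t, 8> mul_low, mul_high;
    for (int i = 0; i < 8; ++i) {
      mul_low[i] = static_cast<int16_t>(a[i] * b[i]);          // extmul_low
      mul_high[i] = static_cast<int16_t>(a[i + 8] * b[i + 8]); // extmul_high
    }
    std::array<int32_t, 16> res;
    for (int i = 0; i < 4; ++i) {
      res[i] = mul_low[i];           // extend_low(mul_low)
      res[i + 4] = mul_low[i + 4];   // extend_high(mul_low)
      res[i + 8] = mul_high[i];      // extend_low(mul_high)
      res[i + 12] = mul_high[i + 4]; // extend_high(mul_high)
    }
    return res;
  }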
diff --git a/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll b/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll
new file mode 100644
index 0000000000000..94aa197bfd564
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll
@@ -0,0 +1,197 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=wasm32 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+
+define <8 x i32> @sext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sext_mul_v8i8:
+; CHECK: .functype sext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push3=, $1, $1
+; CHECK-NEXT: local.tee $push2=, $1=, $pop3
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop2
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = sext <8 x i8> %a to <8 x i32>
+ %wide.b = sext <8 x i8> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <16 x i32> @sext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sext_mul_v16i8:
+; CHECK: .functype sext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_high_i8x16_s $push7=, $1, $1
+; CHECK-NEXT: local.tee $push6=, $3=, $pop7
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop6
+; CHECK-NEXT: v128.store 48($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $3
+; CHECK-NEXT: v128.store 32($0), $pop1
+; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push5=, $1, $1
+; CHECK-NEXT: local.tee $push4=, $1=, $pop5
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push2=, $pop4
+; CHECK-NEXT: v128.store 16($0), $pop2
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
+; CHECK-NEXT: v128.store 0($0), $pop3
+; CHECK-NEXT: return
+ %wide.a = sext <16 x i8> %a to <16 x i32>
+ %wide.b = sext <16 x i8> %a to <16 x i32>
+ %mul = mul <16 x i32> %wide.a, %wide.b
+ ret <16 x i32> %mul
+}
+
+define <8 x i32> @sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: sext_mul_v8i16:
+; CHECK: .functype sext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.extmul_high_i16x8_s $push0=, $1, $1
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extmul_low_i16x8_s $push1=, $1, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = sext <8 x i16> %a to <8 x i32>
+ %wide.b = sext <8 x i16> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <8 x i32> @zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: zext_mul_v8i8:
+; CHECK: .functype zext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push3=, $1, $1
+; CHECK-NEXT: local.tee $push2=, $1=, $pop3
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop2
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = zext <8 x i8> %a to <8 x i32>
+ %wide.b = zext <8 x i8> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <16 x i32> @zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: zext_mul_v16i8:
+; CHECK: .functype zext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_high_i8x16_u $push7=, $1, $1
+; CHECK-NEXT: local.tee $push6=, $3=, $pop7
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop6
+; CHECK-NEXT: v128.store 48($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $3
+; CHECK-NEXT: v128.store 32($0), $pop1
+; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push5=, $1, $1
+; CHECK-NEXT: local.tee $push4=, $1=, $pop5
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push2=, $pop4
+; CHECK-NEXT: v128.store 16($0), $pop2
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push3=, $1
+; CHECK-NEXT: v128.store 0($0), $pop3
+; CHECK-NEXT: return
+ %wide.a = zext <16 x i8> %a to <16 x i32>
+ %wide.b = zext <16 x i8> %a to <16 x i32>
+ %mul = mul <16 x i32> %wide.a, %wide.b
+ ret <16 x i32> %mul
+}
+
+define <8 x i32> @zext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: zext_mul_v8i16:
+; CHECK: .functype zext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.extmul_high_i16x8_u $push0=, $1, $1
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extmul_low_i16x8_u $push1=, $1, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = zext <8 x i16> %a to <8 x i32>
+ %wide.b = zext <8 x i16> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <8 x i32> @sext_zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sext_zext_mul_v8i8:
+; CHECK: .functype sext_zext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
+; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
+; CHECK-NEXT: v128.store 0($0), $pop4
+; CHECK-NEXT: i8x16.shuffle $push11=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push10=, $1=, $pop11
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop10
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
+; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
+; CHECK-NEXT: v128.store 16($0), $pop9
+; CHECK-NEXT: return
+ %wide.a = sext <8 x i8> %a to <8 x i32>
+ %wide.b = zext <8 x i8> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <16 x i32> @sext_zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sext_zext_mul_v16i8:
+; CHECK: .functype sext_zext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
+; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
+; CHECK-NEXT: v128.store 0($0), $pop4
+; CHECK-NEXT: i8x16.shuffle $push25=, $1, $1, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push24=, $3=, $pop25
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop24
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $3
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
+; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
+; CHECK-NEXT: v128.store 48($0), $pop9
+; CHECK-NEXT: i8x16.shuffle $push23=, $1, $1, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push22=, $3=, $pop23
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push12=, $pop22
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push13=, $pop12
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push10=, $3
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push11=, $pop10
+; CHECK-NEXT: i32x4.mul $push14=, $pop13, $pop11
+; CHECK-NEXT: v128.store 32($0), $pop14
+; CHECK-NEXT: i8x16.shuffle $push21=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push20=, $1=, $pop21
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push17=, $pop20
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push18=, $pop17
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push15=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push16=, $pop15
+; CHECK-NEXT: i32x4.mul $push19=, $pop18, $pop16
+; CHECK-NEXT: v128.store 16($0), $pop19
+; CHECK-NEXT: return
+ %wide.a = sext <16 x i8> %a to <16 x i32>
+ %wide.b = zext <16 x i8> %a to <16 x i32>
+ %mul = mul <16 x i32> %wide.a, %wide.b
+ ret <16 x i32> %mul
+}
+
+define <8 x i32> @zext_sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: zext_sext_mul_v8i16:
+; CHECK: .functype zext_sext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push1=, $1
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $1
+; CHECK-NEXT: i32x4.mul $push2=, $pop1, $pop0
+; CHECK-NEXT: v128.store 16($0), $pop2
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push4=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
+; CHECK-NEXT: i32x4.mul $push5=, $pop4, $pop3
+; CHECK-NEXT: v128.store 0($0), $pop5
+; CHECK-NEXT: return
+ %wide.a = zext <8 x i16> %a to <8 x i32>
+ %wide.b = sext <8 x i16> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
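As a quick sanity check on the range argument in the description, here is a minimal standalone C++ sketch (not part of the patch or the test suite) that exhaustively verifies an i8 x i8 product always fits in i16, and a u8 x u8 product in u16, so performing the multiply at i16 width and extending afterwards matches multiplying in i32 directly:

  #include <cassert>
  #include <cstdint>

  int main() {
    // Signed: the extremes are -128 * -128 = 16384 and 127 * -128 = -16256,
    // both inside the i16 range, so extmul at i16 followed by extend is exact.
    for (int a = -128; a <= 127; ++a)
      for (int b = -128; b <= 127; ++b) {
        int16_t narrow = static_cast<int16_t>(a * b);
        assert(static_cast<int32_t>(narrow) == a * b);
      }
    // Unsigned: the extreme is 255 * 255 = 65025, inside the u16 range.
    for (unsigned a = 0; a <= 255; ++a)
      for (unsigned b = 0; b <= 255; ++b) {
        uint16_t narrow = static_cast<uint16_t>(a * b);
        assert(static_cast<uint32_t>(narrow) == a * b);
      }
    return 0;
  }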
Ping?
dschuff approved these changes on Mar 20, 2025.