Skip to content

[WebAssembly] Lower wide SIMD i8 muls #130785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 93 additions & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);

// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);

Expand Down Expand Up @@ -1461,8 +1464,7 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,

bool WebAssemblyTargetLowering::CanLowerReturn(
CallingConv::ID /*CallConv*/, MachineFunction & /*MF*/, bool /*IsVarArg*/,
const SmallVectorImpl<ISD::OutputArg> &Outs,
LLVMContext & /*Context*/,
const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext & /*Context*/,
const Type *RetTy) const {
// WebAssembly can only handle returning tuples with multivalue enabled
return WebAssembly::canLowerReturn(Outs.size(), Subtarget);
Expand Down Expand Up @@ -3254,6 +3256,93 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}

/// Rewrite a v8i32/v16i32 multiply whose two operands are both sign-extended
/// or both zero-extended from i8 vectors so that the multiply itself is done
/// on v8i16 values and only the products are widened to v4i32.
///
/// This is exact because the product of two 8-bit values always fits in 16
/// bits, so widening the 16-bit product with the same signedness yields the
/// same value as multiplying in 32 bits.
///
/// \param N   the ISD::MUL node being combined.
/// \param DAG the DAG used to build the replacement nodes.
/// \returns a CONCAT_VECTORS of v4i32 pieces with the same type as \p N, or an
/// empty SDValue when the pattern does not match.
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
  assert(N->getOpcode() == ISD::MUL);
  EVT VT = N->getValueType(0);
  if (VT != MVT::v8i32 && VT != MVT::v16i32)
    return SDValue();

  // Mul with extending inputs: both operands must use the same kind of
  // extend, and that extend must be a sign- or zero-extend.
  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  unsigned ExtOpc = LHS.getOpcode();
  if (ExtOpc != RHS.getOpcode())
    return SDValue();

  if (ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND)
    return SDValue();

  SDValue ExtendInLHS = LHS->getOperand(0);
  SDValue ExtendInRHS = RHS->getOperand(0);
  EVT FromVT = ExtendInLHS.getValueType();
  if (FromVT != ExtendInRHS.getValueType())
    return SDValue();

  // Only i8 sources are handled here.
  if (FromVT.getVectorElementType() != MVT::i8)
    return SDValue();

  // For an input DAG that looks like this
  //   %a = input_type
  //   %b = input_type
  //   %lhs = extend %a to output_type
  //   %rhs = extend %b to output_type
  //   %mul = mul %lhs, %rhs
  //
  // the intended instructions are ({s,u} chosen by the kind of extend):
  //
  // input_type | output_type | instructions
  // v16i8      | v16i32      | %low = i16x8.extmul_low_i8x16_{s,u} %a, %b
  //            |             | %high = i16x8.extmul_high_i8x16_{s,u} %a, %b
  //            |             | %low_low = i32x4.extend_low_i16x8_{s,u} %low
  //            |             | %low_high = i32x4.extend_high_i16x8_{s,u} %low
  //            |             | %high_low = i32x4.extend_low_i16x8_{s,u} %high
  //            |             | %high_high = i32x4.extend_high_i16x8_{s,u} %high
  //            |             | %res = concat_vector(...)
  // v8i8       | v8i32       | %low = i16x8.extmul_low_i8x16_{s,u} %a, %b
  //            |             | %low_low = i32x4.extend_low_i16x8_{s,u} %low
  //            |             | %low_high = i32x4.extend_high_i16x8_{s,u} %low
  //            |             | %res = concat_vector(%low_low, %low_high)

  SDLoc DL(N);
  unsigned NumElts = VT.getVectorNumElements();
  bool IsSigned = ExtOpc == ISD::SIGN_EXTEND;
  unsigned ExtendLowOpc =
      IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
  unsigned ExtendHighOpc =
      IsSigned ? WebAssemblyISD::EXTEND_HIGH_S : WebAssemblyISD::EXTEND_HIGH_U;

  // Helpers that widen the low/high half of Op into a vector of type ResVT.
  auto GetExtendLow = [&DAG, &DL, ExtendLowOpc](EVT ResVT, SDValue Op) {
    return DAG.getNode(ExtendLowOpc, DL, ResVT, Op);
  };
  auto GetExtendHigh = [&DAG, &DL, ExtendHighOpc](EVT ResVT, SDValue Op) {
    return DAG.getNode(ExtendHighOpc, DL, ResVT, Op);
  };

  if (NumElts == 16) {
    // Multiply the low and high halves of the v16i8 inputs separately as
    // v8i16, then widen each half of each product to v4i32.
    SDValue LowLHS = GetExtendLow(MVT::v8i16, ExtendInLHS);
    SDValue LowRHS = GetExtendLow(MVT::v8i16, ExtendInRHS);
    SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
    SDValue HighLHS = GetExtendHigh(MVT::v8i16, ExtendInLHS);
    SDValue HighRHS = GetExtendHigh(MVT::v8i16, ExtendInRHS);
    SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
    SDValue SubVectors[] = {
        GetExtendLow(MVT::v4i32, MulLow),
        GetExtendHigh(MVT::v4i32, MulLow),
        GetExtendLow(MVT::v4i32, MulHigh),
        GetExtendHigh(MVT::v4i32, MulHigh),
    };
    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SubVectors);
  }

  // The only other type accepted above is v8i32, i.e. v8i8 inputs. Re-emit
  // the original extend at v8i16, multiply there, and widen both halves of
  // the product.
  assert(NumElts == 8 && "expected a v8i32 multiply");
  SDValue LowLHS = DAG.getNode(ExtOpc, DL, MVT::v8i16, ExtendInLHS);
  SDValue LowRHS = DAG.getNode(ExtOpc, DL, MVT::v8i16, ExtendInRHS);
  SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
  SDValue Lo = GetExtendLow(MVT::v4i32, MulLow);
  SDValue Hi = GetExtendHigh(MVT::v4i32, MulLow);
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
}

SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
Expand Down Expand Up @@ -3281,5 +3370,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN:
return performLowerPartialReduction(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
}
}
197 changes: 197 additions & 0 deletions llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=wasm32 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s

; Multiply of two values sign-extended from <8 x i8> to <8 x i32>.
; The assertions below expect one i16x8.extmul_low_i8x16_s whose result is
; widened with i32x4.extend_{low,high}_i16x8_s and stored as two v128 halves.
; NOTE(review): %b is never used; both extends read %a, which is presumably
; why the extmul operands are `$1, $1`. Confirm whether %wide.b was meant to
; extend %b.
define <8 x i32> @sext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
; CHECK-LABEL: sext_mul_v8i8:
; CHECK: .functype sext_mul_v8i8 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push3=, $1, $1
; CHECK-NEXT: local.tee $push2=, $1=, $pop3
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop2
; CHECK-NEXT: v128.store 16($0), $pop0
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $1
; CHECK-NEXT: v128.store 0($0), $pop1
; CHECK-NEXT: return
%wide.a = sext <8 x i8> %a to <8 x i32>
%wide.b = sext <8 x i8> %a to <8 x i32>
%mul = mul <8 x i32> %wide.a, %wide.b
ret <8 x i32> %mul
}

; Multiply of two values sign-extended from <16 x i8> to <16 x i32>.
; The assertions below expect i16x8.extmul_{low,high}_i8x16_s, with each
; product widened by i32x4.extend_{low,high}_i16x8_s into four v128 stores.
; NOTE(review): %b is never used; both extends read %a, which is presumably
; why the extmul operands are `$1, $1`. Confirm whether %wide.b was meant to
; extend %b.
define <16 x i32> @sext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
; CHECK-LABEL: sext_mul_v16i8:
; CHECK: .functype sext_mul_v16i8 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i16x8.extmul_high_i8x16_s $push7=, $1, $1
; CHECK-NEXT: local.tee $push6=, $3=, $pop7
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop6
; CHECK-NEXT: v128.store 48($0), $pop0
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $3
; CHECK-NEXT: v128.store 32($0), $pop1
; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push5=, $1, $1
; CHECK-NEXT: local.tee $push4=, $1=, $pop5
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push2=, $pop4
; CHECK-NEXT: v128.store 16($0), $pop2
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
; CHECK-NEXT: v128.store 0($0), $pop3
; CHECK-NEXT: return
%wide.a = sext <16 x i8> %a to <16 x i32>
%wide.b = sext <16 x i8> %a to <16 x i32>
%mul = mul <16 x i32> %wide.a, %wide.b
ret <16 x i32> %mul
}

; Multiply of two values sign-extended from <8 x i16> to <8 x i32>.
; The assertions below expect i32x4.extmul_{low,high}_i16x8_s with no
; separate extend instructions.
; NOTE(review): %b is never used; both extends read %a, which is presumably
; why the extmul operands are `$1, $1`. Confirm whether %wide.b was meant to
; extend %b.
define <8 x i32> @sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: sext_mul_v8i16:
; CHECK: .functype sext_mul_v8i16 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i32x4.extmul_high_i16x8_s $push0=, $1, $1
; CHECK-NEXT: v128.store 16($0), $pop0
; CHECK-NEXT: i32x4.extmul_low_i16x8_s $push1=, $1, $1
; CHECK-NEXT: v128.store 0($0), $pop1
; CHECK-NEXT: return
%wide.a = sext <8 x i16> %a to <8 x i32>
%wide.b = sext <8 x i16> %a to <8 x i32>
%mul = mul <8 x i32> %wide.a, %wide.b
ret <8 x i32> %mul
}

; Multiply of two values zero-extended from <8 x i8> to <8 x i32>.
; The assertions below expect one i16x8.extmul_low_i8x16_u whose result is
; widened with i32x4.extend_{low,high}_i16x8_u and stored as two v128 halves.
; NOTE(review): %b is never used; both extends read %a, which is presumably
; why the extmul operands are `$1, $1`. Confirm whether %wide.b was meant to
; extend %b.
define <8 x i32> @zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
; CHECK-LABEL: zext_mul_v8i8:
; CHECK: .functype zext_mul_v8i8 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push3=, $1, $1
; CHECK-NEXT: local.tee $push2=, $1=, $pop3
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop2
; CHECK-NEXT: v128.store 16($0), $pop0
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $1
; CHECK-NEXT: v128.store 0($0), $pop1
; CHECK-NEXT: return
%wide.a = zext <8 x i8> %a to <8 x i32>
%wide.b = zext <8 x i8> %a to <8 x i32>
%mul = mul <8 x i32> %wide.a, %wide.b
ret <8 x i32> %mul
}

; Multiply of two values zero-extended from <16 x i8> to <16 x i32>.
; The assertions below expect i16x8.extmul_{low,high}_i8x16_u, with each
; product widened by i32x4.extend_{low,high}_i16x8_u into four v128 stores.
; NOTE(review): %b is never used; both extends read %a, which is presumably
; why the extmul operands are `$1, $1`. Confirm whether %wide.b was meant to
; extend %b.
define <16 x i32> @zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
; CHECK-LABEL: zext_mul_v16i8:
; CHECK: .functype zext_mul_v16i8 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i16x8.extmul_high_i8x16_u $push7=, $1, $1
; CHECK-NEXT: local.tee $push6=, $3=, $pop7
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop6
; CHECK-NEXT: v128.store 48($0), $pop0
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $3
; CHECK-NEXT: v128.store 32($0), $pop1
; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push5=, $1, $1
; CHECK-NEXT: local.tee $push4=, $1=, $pop5
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push2=, $pop4
; CHECK-NEXT: v128.store 16($0), $pop2
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push3=, $1
; CHECK-NEXT: v128.store 0($0), $pop3
; CHECK-NEXT: return
%wide.a = zext <16 x i8> %a to <16 x i32>
%wide.b = zext <16 x i8> %a to <16 x i32>
%mul = mul <16 x i32> %wide.a, %wide.b
ret <16 x i32> %mul
}

; Multiply of two values zero-extended from <8 x i16> to <8 x i32>.
; The assertions below expect i32x4.extmul_{low,high}_i16x8_u with no
; separate extend instructions.
; NOTE(review): %b is never used; both extends read %a, which is presumably
; why the extmul operands are `$1, $1`. Confirm whether %wide.b was meant to
; extend %b.
define <8 x i32> @zext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: zext_mul_v8i16:
; CHECK: .functype zext_mul_v8i16 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i32x4.extmul_high_i16x8_u $push0=, $1, $1
; CHECK-NEXT: v128.store 16($0), $pop0
; CHECK-NEXT: i32x4.extmul_low_i16x8_u $push1=, $1, $1
; CHECK-NEXT: v128.store 0($0), $pop1
; CHECK-NEXT: return
%wide.a = zext <8 x i16> %a to <8 x i32>
%wide.b = zext <8 x i16> %a to <8 x i32>
%mul = mul <8 x i32> %wide.a, %wide.b
ret <8 x i32> %mul
}

; Negative test: one operand is sign-extended and the other zero-extended
; from <8 x i8>. The assertions below expect no extmul; each half is extended
; separately (signed and unsigned) and multiplied with i32x4.mul.
; NOTE(review): %b is never used; both extends read %a. Confirm whether
; %wide.b was meant to extend %b.
define <8 x i32> @sext_zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
; CHECK-LABEL: sext_zext_mul_v8i8:
; CHECK: .functype sext_zext_mul_v8i8 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
; CHECK-NEXT: v128.store 0($0), $pop4
; CHECK-NEXT: i8x16.shuffle $push11=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
; CHECK-NEXT: local.tee $push10=, $1=, $pop11
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop10
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $1
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
; CHECK-NEXT: v128.store 16($0), $pop9
; CHECK-NEXT: return
%wide.a = sext <8 x i8> %a to <8 x i32>
%wide.b = zext <8 x i8> %a to <8 x i32>
%mul = mul <8 x i32> %wide.a, %wide.b
ret <8 x i32> %mul
}

; Negative test: one operand is sign-extended and the other zero-extended
; from <16 x i8>. The assertions below expect no extmul; each quarter is
; shuffled into place, extended separately (signed and unsigned) and
; multiplied with i32x4.mul.
; NOTE(review): %b is never used; both extends read %a. Confirm whether
; %wide.b was meant to extend %b.
define <16 x i32> @sext_zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
; CHECK-LABEL: sext_zext_mul_v16i8:
; CHECK: .functype sext_zext_mul_v16i8 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
; CHECK-NEXT: v128.store 0($0), $pop4
; CHECK-NEXT: i8x16.shuffle $push25=, $1, $1, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
; CHECK-NEXT: local.tee $push24=, $3=, $pop25
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop24
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $3
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
; CHECK-NEXT: v128.store 48($0), $pop9
; CHECK-NEXT: i8x16.shuffle $push23=, $1, $1, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0
; CHECK-NEXT: local.tee $push22=, $3=, $pop23
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push12=, $pop22
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push13=, $pop12
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push10=, $3
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push11=, $pop10
; CHECK-NEXT: i32x4.mul $push14=, $pop13, $pop11
; CHECK-NEXT: v128.store 32($0), $pop14
; CHECK-NEXT: i8x16.shuffle $push21=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
; CHECK-NEXT: local.tee $push20=, $1=, $pop21
; CHECK-NEXT: i16x8.extend_low_i8x16_s $push17=, $pop20
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push18=, $pop17
; CHECK-NEXT: i16x8.extend_low_i8x16_u $push15=, $1
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push16=, $pop15
; CHECK-NEXT: i32x4.mul $push19=, $pop18, $pop16
; CHECK-NEXT: v128.store 16($0), $pop19
; CHECK-NEXT: return
%wide.a = sext <16 x i8> %a to <16 x i32>
%wide.b = zext <16 x i8> %a to <16 x i32>
%mul = mul <16 x i32> %wide.a, %wide.b
ret <16 x i32> %mul
}

; Negative test: one operand is zero-extended and the other sign-extended
; from <8 x i16>. The assertions below expect no extmul; each half is widened
; with i32x4.extend_{low,high}_i16x8_{u,s} and multiplied with i32x4.mul.
; NOTE(review): %b is never used; both extends read %a. Confirm whether
; %wide.b was meant to extend %b.
define <8 x i32> @zext_sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: zext_sext_mul_v8i16:
; CHECK: .functype zext_sext_mul_v8i16 (i32, v128, v128) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i32x4.extend_high_i16x8_u $push1=, $1
; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $1
; CHECK-NEXT: i32x4.mul $push2=, $pop1, $pop0
; CHECK-NEXT: v128.store 16($0), $pop2
; CHECK-NEXT: i32x4.extend_low_i16x8_u $push4=, $1
; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
; CHECK-NEXT: i32x4.mul $push5=, $pop4, $pop3
; CHECK-NEXT: v128.store 0($0), $pop5
; CHECK-NEXT: return
%wide.a = zext <8 x i16> %a to <8 x i32>
%wide.b = sext <8 x i16> %a to <8 x i32>
%mul = mul <8 x i32> %wide.a, %wide.b
ret <8 x i32> %mul
}