Skip to content

[WebAssembly] Autovec support for dot #123207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 3, 2025

Conversation

sparker-arm
Copy link
Contributor

Enable the use of partial.reduce.add that we can lower to dot or a tree of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support both v8i16 and v16i8 inputs.

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-backend-webassembly

Author: Sam Parker (sparker-arm)

Changes

Enable the use of partial.reduce.add that we can lower to dot or a tree of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support both v8i16 and v16i8 inputs.


Patch is 28.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123207.diff

7 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISD.def (+1)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+124)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h (+1)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td (+4)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp (+47)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h (+6-1)
  • (added) llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll (+402)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 1cf0d13df1ff6b..378ef2c8f250e6 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -26,6 +26,7 @@ HANDLE_NODETYPE(Wrapper)
 HANDLE_NODETYPE(WrapperREL)
 HANDLE_NODETYPE(BR_IF)
 HANDLE_NODETYPE(BR_TABLE)
+HANDLE_NODETYPE(DOT)
 HANDLE_NODETYPE(SHUFFLE)
 HANDLE_NODETYPE(SWIZZLE)
 HANDLE_NODETYPE(VEC_SHL)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 084aed6eed46d3..4da9e65853b9a4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsWebAssembly.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -177,6 +178,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
 
   // SIMD-specific configuration
   if (Subtarget->hasSIMD128()) {
+
+    // Combine partial.reduce.add before legalization gets confused.
+    setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+
     // Combine vector mask reductions into alltrue/anytrue
     setTargetDAGCombine(ISD::SETCC);
 
@@ -406,6 +411,35 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
   return TargetLowering::getPointerMemTy(DL, AS);
 }
 
+bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
+    const IntrinsicInst *I) const {
+  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
+    return true;
+
+  EVT VT = EVT::getEVT(I->getType());
+  auto Op1 = I->getOperand(1);
+
+  if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
+    if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
+      return true;
+
+    if (isa<Instruction>(InputInst->getOperand(0)) &&
+        isa<Instruction>(InputInst->getOperand(1))) {
+      // dot only supports signed inputs but also support lowering unsigned.
+      if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
+          cast<Instruction>(InputInst->getOperand(1))->getOpcode())
+        return true;
+
+      EVT Op1VT = EVT::getEVT(Op1->getType());
+      if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+          ((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
+           (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
+        return false;
+    }
+  }
+  return true;
+}
+
 TargetLowering::AtomicExpansionKind
 WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
   // We have wasm instructions for these
@@ -2029,6 +2063,94 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
                       MachinePointerInfo(SV));
 }
 
+// Try to lower partial.reduce.add to a dot or fallback to a sequence with
+// extmul and adds.
+SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
+  if (N->getConstantOperandVal(0) !=
+      Intrinsic::experimental_vector_partial_reduce_add)
+    return SDValue();
+
+  assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
+  SDLoc DL(N);
+  SDValue Mul = N->getOperand(2);
+  assert(Mul->getOpcode() == ISD::MUL && "expected mul input");
+
+  SDValue ExtendLHS = Mul->getOperand(0);
+  SDValue ExtendRHS = Mul->getOperand(1);
+  assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
+          ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
+         "expected widening mul");
+  assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
+         "expected mul to use the same extend for both operands");
+
+  SDValue ExtendInLHS = ExtendLHS->getOperand(0);
+  SDValue ExtendInRHS = ExtendRHS->getOperand(0);
+  bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+
+  if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+    if (IsSigned) {
+      // i32x4.dot_i16x8_s
+      SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
+                                ExtendInLHS, ExtendInRHS);
+      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
+    }
+
+    unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
+    unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
+
+    // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
+    SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
+    SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
+    SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
+    SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
+
+    SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
+    SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
+    SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
+    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+  } else {
+    assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+           "expected v16i8 input types");
+    // Lower to a wider tree, using twice the operations compared to above.
+    if (IsSigned) {
+      // Use two dots
+      unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
+      unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
+      SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
+      SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
+      SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
+      SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
+      SDValue DotLHS =
+          DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
+      SDValue DotRHS =
+          DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
+      SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+    }
+
+    unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
+    unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
+    SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
+    SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
+    SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
+    SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
+
+    SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+    SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+
+    SDValue LowLow = DAG.getNode(LowOpc, DL, MVT::v4i32, MulLow);
+    SDValue LowHigh = DAG.getNode(LowOpc, DL, MVT::v4i32, MulHigh);
+    SDValue HighLow = DAG.getNode(HighOpc, DL, MVT::v4i32, MulLow);
+    SDValue HighHigh = DAG.getNode(HighOpc, DL, MVT::v4i32, MulHigh);
+
+    SDValue AddLow = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowLow, HighLow);
+    SDValue AddHigh = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowHigh, HighHigh);
+    SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+  }
+}
+
 SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
                                                   SelectionDAG &DAG) const {
   MachineFunction &MF = DAG.getMachineFunction();
@@ -3125,5 +3247,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
     return performVectorTruncZeroCombine(N, DCI);
   case ISD::TRUNCATE:
     return performTruncateCombine(N, DCI);
+  case ISD::INTRINSIC_WO_CHAIN:
+    return performLowerPartialReduction(N, DCI.DAG);
   }
 }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 454432728ca871..6cc5ef51561c35 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -45,6 +45,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   /// right decision when generating code for different targets.
   const WebAssemblySubtarget *Subtarget;
 
+  bool shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
   AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
   bool shouldScalarizeBinop(SDValue VecOp) const override;
   FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 2c0543842a82bb..14acc623ce24d1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1147,11 +1147,15 @@ def : Pat<(wasm_shr_u
 }
 
 // Widening dot product: i32x4.dot_i16x8_s
+def dot_t : SDTypeProfile<1, 2, [SDTCisVT<0, v4i32>, SDTCisVT<1, v8i16>, SDTCisVT<2, v8i16>]>;
+def wasm_dot : SDNode<"WebAssemblyISD::DOT", dot_t>;
 let isCommutable = 1 in
 defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
                   [(set V128:$dst, (int_wasm_dot V128:$lhs, V128:$rhs))],
                   "i32x4.dot_i16x8_s\t$dst, $lhs, $rhs", "i32x4.dot_i16x8_s",
                   186>;
+def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
+          (DOT $lhs, $rhs)>;
 
 // Extending multiplication: extmul_{low,high}_P, extmul_high
 def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 3d678e53841664..103fdf8587255c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -92,6 +92,53 @@ WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   return Cost;
 }
 
+InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
+    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+    TTI::PartialReductionExtendKind OpBExtend,
+    std::optional<unsigned> BinOp) const {
+  InstructionCost Invalid = InstructionCost::getInvalid();
+  if (!VF.isFixed() || !ST->hasSIMD128())
+    return Invalid;
+
+  InstructionCost Cost(TTI::TCC_Basic);
+
+  // Possible options:
+  // - i16x8.extadd_pairwise_i8x16_sx
+  // - i32x4.extadd_pairwise_i16x8_sx
+  // - i32x4.dot_i16x8_s
+  // Only try to support dot, for now.
+
+  if (Opcode != Instruction::Add)
+    return Invalid;
+
+  if (!BinOp || *BinOp != Instruction::Mul)
+    return Invalid;
+
+  if (InputTypeA != InputTypeB)
+    return Invalid;
+
+  if (OpAExtend != OpBExtend)
+    return Invalid;
+
+  EVT InputEVT = EVT::getEVT(InputTypeA);
+  EVT AccumEVT = EVT::getEVT(AccumType);
+
+  // TODO: Add i64 accumulator.
+  if (AccumEVT != MVT::i32)
+    return Invalid;
+
+  // Signed inputs can lower to dot
+  if (InputEVT == MVT::i16 && VF.getFixedValue() == 8)
+    return OpAExtend == TTI::PR_SignExtend ? Cost : Cost * 2;
+
+  // Double the size of the lowered sequence.
+  if (InputEVT == MVT::i8 && VF.getFixedValue() == 16)
+    return OpAExtend == TTI::PR_SignExtend ? Cost * 2 : Cost * 4;
+
+  return Invalid;
+}
+
 TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
     const IntrinsicInst *II) const {
 
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 9691120b2e531d..ef4cdc76966a3b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -68,7 +68,12 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
   InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index, Value *Op0, Value *Op1);
-
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+                          Type *AccumType, ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp = std::nullopt) const;
   TTI::ReductionShuffle
   getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;
 
diff --git a/llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll b/llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll
new file mode 100644
index 00000000000000..65487cef4d4ee9
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll
@@ -0,0 +1,402 @@
+; RUN: opt -mattr=+simd128 -passes=loop-vectorize %s | llc -mtriple=wasm32 -mattr=+simd128 -verify-machineinstrs -o - | FileCheck %s
+; RUN: opt -mattr=+simd128 -passes=loop-vectorize -vectorizer-maximize-bandwidth %s | llc -mtriple=wasm32 -mattr=+simd128 -verify-machineinstrs -o - | FileCheck %s --check-prefix=MAX-BANDWIDTH
+
+target triple = "wasm32"
+
+define hidden i32 @i32_mac_s8(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i32_mac_s8:
+; CHECK:    v128.load32_zero 0:p2align=0
+; CHECK:    i16x8.extend_low_i8x16_s
+; CHECK:    v128.load32_zero 0:p2align=0
+; CHECK:    i16x8.extend_low_i8x16_s
+; CHECK:    i32x4.extmul_low_i16x8_s
+; CHECK:    i32x4.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i16x8.extend_low_i8x16_s
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i16x8.extend_low_i8x16_s
+; MAX-BANDWIDTH: i32x4.dot_i16x8_s
+; MAX-BANDWIDTH: i16x8.extend_high_i8x16_s
+; MAX-BANDWIDTH: i16x8.extend_high_i8x16_s
+; MAX-BANDWIDTH: i32x4.dot_i16x8_s
+; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
+
+entry:
+  %cmp7.not = icmp eq i32 %N, 0
+  br i1 %cmp7.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  ret i32 %res.0.lcssa
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i8, ptr %a, i32 %i.09
+  %0 = load i8, ptr %arrayidx, align 1
+  %conv = sext i8 %0 to i32
+  %arrayidx1 = getelementptr inbounds i8, ptr %b, i32 %i.09
+  %1 = load i8, ptr %arrayidx1, align 1
+  %conv2 = sext i8 %1 to i32
+  %mul = mul nsw i32 %conv2, %conv
+  %add = add nsw i32 %mul, %res.08
+  %inc = add nuw i32 %i.09, 1
+  %exitcond.not = icmp eq i32 %inc, %N
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i32 @i32_mac_s16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i32_mac_s16:
+; CHECK:    i32x4.load16x4_s 0:p2align=1
+; CHECK:    i32x4.load16x4_s 0:p2align=1
+; CHECK:    i32x4.mul
+; CHECK:    i32x4.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i32x4.dot_i16x8_s
+
+entry:
+  %cmp7.not = icmp eq i32 %N, 0
+  br i1 %cmp7.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  ret i32 %res.0.lcssa
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.09
+  %0 = load i16, ptr %arrayidx, align 2
+  %conv = sext i16 %0 to i32
+  %arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.09
+  %1 = load i16, ptr %arrayidx1, align 2
+  %conv2 = sext i16 %1 to i32
+  %mul = mul nsw i32 %conv2, %conv
+  %add = add nsw i32 %mul, %res.08
+  %inc = add nuw i32 %i.09, 1
+  %exitcond.not = icmp eq i32 %inc, %N
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i64 @i64_mac_s16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i64_mac_s16:
+; CHECK:    v128.load32_zero 0:p2align=1
+; CHECK:    i32x4.extend_low_i16x8_s
+; CHECK:    v128.load32_zero 0:p2align=1
+; CHECK:    i32x4.extend_low_i16x8_s
+; CHECK:    i64x2.extmul_low_i32x4_s
+; CHECK:    i64x2.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i8x16.shuffle	12, 13, 14, 15, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i8x16.shuffle	12, 13, 14, 15, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i8x16.shuffle	8, 9, 10, 11, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i8x16.shuffle	8, 9, 10, 11, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i8x16.shuffle	4, 5, 6, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i8x16.shuffle	4, 5, 6, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+entry:
+  %cmp7.not = icmp eq i32 %N, 0
+  br i1 %cmp7.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %res.0.lcssa
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %res.08 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.09
+  %0 = load i16, ptr %arrayidx, align 2
+  %conv = sext i16 %0 to i64
+  %arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.09
+  %1 = load i16, ptr %arrayidx1, align 2
+  %conv2 = sext i16 %1 to i64
+  %mul = mul nsw i64 %conv2, %conv
+  %add = add nsw i64 %mul, %res.08
+  %inc = add nuw i32 %i.09, 1
+  %exitcond.not = icmp eq i32 %inc, %N
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i64 @i64_mac_s32(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i64_mac_s32:
+; CHECK:    v128.load64_zero 0:p2align=2
+; CHECK:    v128.load64_zero 0:p2align=2
+; CHECK:    i32x4.mul
+; CHECK:    i64x2.extend_low_i32x4_s
+; CHECK:    i64x2.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i32x4.mul
+; MAX-BANDWIDTH: i64x2.extend_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i8x16.shuffle	8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3
+; MAX-BANDWIDTH: i64x2.extend_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+
+entry:
+  %cmp6.not = icmp eq i32 %N, 0
+  br i1 %cmp6.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+  ret i64 %res.0.lcssa
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.08 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %res.07 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %a, i32 %i.08
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i32, ptr %b, i32 %i.08
+  %1 = load i32, ptr %arrayidx1, align 4
+  %mul = mul i32 %1, %0
+  %conv = sext i32 %mul to i64
+  %add = add i64 %res.07, %conv
+  %inc = add nuw i32 %i.08, 1
+  %exitcond.not = icmp eq i32 %inc, %N
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i32 @i32_mac_u8(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i32_mac_u8:
+; CHECK:    v128.load32_zero 0:p2align=0
+; CHECK:    i16x8.extend_low_i8x...
[truncated]

Copy link

github-actions bot commented Jan 16, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@sparker-arm
Copy link
Contributor Author

Ping?

@dschuff
Copy link
Member

dschuff commented Jan 23, 2025

Thanks for the PR, the code looks pretty reasonable. I'm curious about the partial.reduce.add intrinsic since I'm not really familiar with it. Is this something that the autovectorizer will generate by default now (or maybe now that you've implemented getPartialReductionCost?)

@sparker-arm
Copy link
Contributor Author

or maybe now that you've implemented getPartialReductionCost?

Correct, but with the caveat that it won't trigger at the moment due to the vectorization factor selected. For our dot, we need a factor of 8 or 16 but, AFAICT, the vectorizer will not go above 4 as we are using an i32 for the arithmetic. The tests that shows dot being generated are doing so because I've used the -vectorizer-maximize-bandwidth option. This hook can be implemented in the backend too, but I'm not convinced that's the best approach. I think some extra work will be needed in the vectorizer to explore a wider vector factor, based upon the input widths, for these partial reductions.

@sparker-arm
Copy link
Contributor Author

And thanks for taking a look!

Enable the use of partial.reduce.add that we can lower to dot or a
tree of (add (extmul_low_u, extmul_high_u)) for the unsigned
case. We support both v8i16 and v16i8 inputs.
@sparker-arm
Copy link
Contributor Author

Ping?

@sparker-arm sparker-arm merged commit df2de13 into llvm:main Feb 3, 2025
8 checks passed
@sparker-arm sparker-arm deleted the wasm-dot-autovec branch February 3, 2025 08:58
@llvm-ci
Copy link
Collaborator

llvm-ci commented Feb 3, 2025

LLVM Buildbot has detected a new failure on builder openmp-offload-libc-amdgpu-runtime running on omp-vega20-1 while building llvm at step 7 "Add check check-offload".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/73/builds/12914

Here is the relevant piece of the build log for the reference
Step 7 (Add check check-offload) failure: test (failure)
******************** TEST 'libomptarget :: amdgcn-amd-amdhsa :: mapping/declare_mapper_nested_default_mappers.cpp' FAILED ********************
Exit Code: 2

Command Output (stdout):
--
# RUN: at line 1
/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./bin/clang++ -fopenmp    -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/offload/test -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./lib -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src  -nogpulib -Wl,-rpath,/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -Wl,-rpath,/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -Wl,-rpath,/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./lib  -fopenmp-targets=amdgcn-amd-amdhsa /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/offload/test/mapping/declare_mapper_nested_default_mappers.cpp -o /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/mapping/Output/declare_mapper_nested_default_mappers.cpp.tmp -Xoffload-linker -lc -Xoffload-linker -lm /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./lib/libomptarget.devicertl.a && /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/mapping/Output/declare_mapper_nested_default_mappers.cpp.tmp | /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./bin/FileCheck /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/offload/test/mapping/declare_mapper_nested_default_mappers.cpp
# executed command: /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./bin/clang++ -fopenmp -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/offload/test -I /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./lib -L /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -nogpulib -Wl,-rpath,/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -Wl,-rpath,/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -Wl,-rpath,/home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./lib -fopenmp-targets=amdgcn-amd-amdhsa /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/offload/test/mapping/declare_mapper_nested_default_mappers.cpp -o /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/mapping/Output/declare_mapper_nested_default_mappers.cpp.tmp -Xoffload-linker -lc -Xoffload-linker -lm /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./lib/libomptarget.devicertl.a
# executed command: /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/mapping/Output/declare_mapper_nested_default_mappers.cpp.tmp
# note: command had no output on stdout or stderr
# error: command failed with exit status: -11
# executed command: /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./bin/FileCheck /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/offload/test/mapping/declare_mapper_nested_default_mappers.cpp
# .---command stderr------------
# | FileCheck error: '<stdin>' is empty.
# | FileCheck command line:  /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.build/./bin/FileCheck /home/ompworker/bbot/openmp-offload-libc-amdgpu-runtime/llvm.src/offload/test/mapping/declare_mapper_nested_default_mappers.cpp
# `-----------------------------
# error: command failed with exit status: 2

--

********************


Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Enable the use of partial.reduce.add that we can lower to dot or a tree
of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support
both v8i16 and v16i8 inputs.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants