-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[WebAssembly] Autovec support for dot #123207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-webassembly Author: Sam Parker (sparker-arm) ChangesEnable the use of partial.reduce.add that we can lower to dot or a tree of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support both v8i16 and v16i8 inputs. Patch is 28.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123207.diff 7 Files Affected:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 1cf0d13df1ff6b..378ef2c8f250e6 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -26,6 +26,7 @@ HANDLE_NODETYPE(Wrapper)
HANDLE_NODETYPE(WrapperREL)
HANDLE_NODETYPE(BR_IF)
HANDLE_NODETYPE(BR_TABLE)
+HANDLE_NODETYPE(DOT)
HANDLE_NODETYPE(SHUFFLE)
HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 084aed6eed46d3..4da9e65853b9a4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -29,6 +29,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/Support/ErrorHandling.h"
@@ -177,6 +178,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
+
+ // Combine partial.reduce.add before legalization gets confused.
+ setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -406,6 +411,35 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
return TargetLowering::getPointerMemTy(DL, AS);
}
+bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
+ const IntrinsicInst *I) const {
+ if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
+ return true;
+
+ EVT VT = EVT::getEVT(I->getType());
+ auto Op1 = I->getOperand(1);
+
+ if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
+ if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
+ return true;
+
+ if (isa<Instruction>(InputInst->getOperand(0)) &&
+ isa<Instruction>(InputInst->getOperand(1))) {
+ // dot only supports signed inputs but also support lowering unsigned.
+ if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
+ cast<Instruction>(InputInst->getOperand(1))->getOpcode())
+ return true;
+
+ EVT Op1VT = EVT::getEVT(Op1->getType());
+ if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+ ((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
+ (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
+ return false;
+ }
+ }
+ return true;
+}
+
TargetLowering::AtomicExpansionKind
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// We have wasm instructions for these
@@ -2029,6 +2063,94 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
MachinePointerInfo(SV));
}
+// Try to lower partial.reduce.add to a dot or fallback to a sequence with
+// extmul and adds.
+SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
+ if (N->getConstantOperandVal(0) !=
+ Intrinsic::experimental_vector_partial_reduce_add)
+ return SDValue();
+
+ assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
+ SDLoc DL(N);
+ SDValue Mul = N->getOperand(2);
+ assert(Mul->getOpcode() == ISD::MUL && "expected mul input");
+
+ SDValue ExtendLHS = Mul->getOperand(0);
+ SDValue ExtendRHS = Mul->getOperand(1);
+ assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
+ ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
+ "expected widening mul");
+ assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
+ "expected mul to use the same extend for both operands");
+
+ SDValue ExtendInLHS = ExtendLHS->getOperand(0);
+ SDValue ExtendInRHS = ExtendRHS->getOperand(0);
+ bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+
+ if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+ if (IsSigned) {
+ // i32x4.dot_i16x8_s
+ SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
+ ExtendInLHS, ExtendInRHS);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
+ }
+
+ unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
+ unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
+
+ // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
+ SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
+ SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
+ SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
+ SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
+
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ } else {
+ assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+ "expected v16i8 input types");
+ // Lower to a wider tree, using twice the operations compared to above.
+ if (IsSigned) {
+ // Use two dots
+ unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
+ unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
+ SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
+ SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
+ SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
+ SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
+ SDValue DotLHS =
+ DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
+ SDValue DotRHS =
+ DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ }
+
+ unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
+ unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
+ SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
+ SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
+ SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
+ SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
+
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+
+ SDValue LowLow = DAG.getNode(LowOpc, DL, MVT::v4i32, MulLow);
+ SDValue LowHigh = DAG.getNode(LowOpc, DL, MVT::v4i32, MulHigh);
+ SDValue HighLow = DAG.getNode(HighOpc, DL, MVT::v4i32, MulLow);
+ SDValue HighHigh = DAG.getNode(HighOpc, DL, MVT::v4i32, MulHigh);
+
+ SDValue AddLow = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowLow, HighLow);
+ SDValue AddHigh = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowHigh, HighHigh);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ }
+}
+
SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
@@ -3125,5 +3247,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
+ case ISD::INTRINSIC_WO_CHAIN:
+ return performLowerPartialReduction(N, DCI.DAG);
}
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 454432728ca871..6cc5ef51561c35 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -45,6 +45,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
/// right decision when generating code for different targets.
const WebAssemblySubtarget *Subtarget;
+ bool shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
bool shouldScalarizeBinop(SDValue VecOp) const override;
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 2c0543842a82bb..14acc623ce24d1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1147,11 +1147,15 @@ def : Pat<(wasm_shr_u
}
// Widening dot product: i32x4.dot_i16x8_s
+def dot_t : SDTypeProfile<1, 2, [SDTCisVT<0, v4i32>, SDTCisVT<1, v8i16>, SDTCisVT<2, v8i16>]>;
+def wasm_dot : SDNode<"WebAssemblyISD::DOT", dot_t>;
let isCommutable = 1 in
defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
[(set V128:$dst, (int_wasm_dot V128:$lhs, V128:$rhs))],
"i32x4.dot_i16x8_s\t$dst, $lhs, $rhs", "i32x4.dot_i16x8_s",
186>;
+def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
+ (DOT $lhs, $rhs)>;
// Extending multiplication: extmul_{low,high}_P, extmul_high
def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 3d678e53841664..103fdf8587255c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -92,6 +92,53 @@ WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
return Cost;
}
+InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend,
+ std::optional<unsigned> BinOp) const {
+ InstructionCost Invalid = InstructionCost::getInvalid();
+ if (!VF.isFixed() || !ST->hasSIMD128())
+ return Invalid;
+
+ InstructionCost Cost(TTI::TCC_Basic);
+
+ // Possible options:
+ // - i16x8.extadd_pairwise_i8x16_sx
+ // - i32x4.extadd_pairwise_i16x8_sx
+ // - i32x4.dot_i16x8_s
+ // Only try to support dot, for now.
+
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ if (!BinOp || *BinOp != Instruction::Mul)
+ return Invalid;
+
+ if (InputTypeA != InputTypeB)
+ return Invalid;
+
+ if (OpAExtend != OpBExtend)
+ return Invalid;
+
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ EVT AccumEVT = EVT::getEVT(AccumType);
+
+ // TODO: Add i64 accumulator.
+ if (AccumEVT != MVT::i32)
+ return Invalid;
+
+ // Signed inputs can lower to dot
+ if (InputEVT == MVT::i16 && VF.getFixedValue() == 8)
+ return OpAExtend == TTI::PR_SignExtend ? Cost : Cost * 2;
+
+ // Double the size of the lowered sequence.
+ if (InputEVT == MVT::i8 && VF.getFixedValue() == 16)
+ return OpAExtend == TTI::PR_SignExtend ? Cost * 2 : Cost * 4;
+
+ return Invalid;
+}
+
TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
const IntrinsicInst *II) const {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 9691120b2e531d..ef4cdc76966a3b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -68,7 +68,12 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
unsigned Index, Value *Op0, Value *Op1);
-
+ InstructionCost
+ getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+ Type *AccumType, ElementCount VF,
+ TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend,
+ std::optional<unsigned> BinOp = std::nullopt) const;
TTI::ReductionShuffle
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;
diff --git a/llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll b/llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll
new file mode 100644
index 00000000000000..65487cef4d4ee9
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll
@@ -0,0 +1,402 @@
+; RUN: opt -mattr=+simd128 -passes=loop-vectorize %s | llc -mtriple=wasm32 -mattr=+simd128 -verify-machineinstrs -o - | FileCheck %s
+; RUN: opt -mattr=+simd128 -passes=loop-vectorize -vectorizer-maximize-bandwidth %s | llc -mtriple=wasm32 -mattr=+simd128 -verify-machineinstrs -o - | FileCheck %s --check-prefix=MAX-BANDWIDTH
+
+target triple = "wasm32"
+
+define hidden i32 @i32_mac_s8(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i32_mac_s8:
+; CHECK: v128.load32_zero 0:p2align=0
+; CHECK: i16x8.extend_low_i8x16_s
+; CHECK: v128.load32_zero 0:p2align=0
+; CHECK: i16x8.extend_low_i8x16_s
+; CHECK: i32x4.extmul_low_i16x8_s
+; CHECK: i32x4.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i16x8.extend_low_i8x16_s
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i16x8.extend_low_i8x16_s
+; MAX-BANDWIDTH: i32x4.dot_i16x8_s
+; MAX-BANDWIDTH: i16x8.extend_high_i8x16_s
+; MAX-BANDWIDTH: i16x8.extend_high_i8x16_s
+; MAX-BANDWIDTH: i32x4.dot_i16x8_s
+; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
+
+entry:
+ %cmp7.not = icmp eq i32 %N, 0
+ br i1 %cmp7.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup: ; preds = %for.body, %entry
+ %res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
+ ret i32 %res.0.lcssa
+
+for.body: ; preds = %entry, %for.body
+ %i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+ %res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i8, ptr %a, i32 %i.09
+ %0 = load i8, ptr %arrayidx, align 1
+ %conv = sext i8 %0 to i32
+ %arrayidx1 = getelementptr inbounds i8, ptr %b, i32 %i.09
+ %1 = load i8, ptr %arrayidx1, align 1
+ %conv2 = sext i8 %1 to i32
+ %mul = mul nsw i32 %conv2, %conv
+ %add = add nsw i32 %mul, %res.08
+ %inc = add nuw i32 %i.09, 1
+ %exitcond.not = icmp eq i32 %inc, %N
+ br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i32 @i32_mac_s16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i32_mac_s16:
+; CHECK: i32x4.load16x4_s 0:p2align=1
+; CHECK: i32x4.load16x4_s 0:p2align=1
+; CHECK: i32x4.mul
+; CHECK: i32x4.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i32x4.dot_i16x8_s
+
+entry:
+ %cmp7.not = icmp eq i32 %N, 0
+ br i1 %cmp7.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup: ; preds = %for.body, %entry
+ %res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
+ ret i32 %res.0.lcssa
+
+for.body: ; preds = %entry, %for.body
+ %i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+ %res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.09
+ %0 = load i16, ptr %arrayidx, align 2
+ %conv = sext i16 %0 to i32
+ %arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.09
+ %1 = load i16, ptr %arrayidx1, align 2
+ %conv2 = sext i16 %1 to i32
+ %mul = mul nsw i32 %conv2, %conv
+ %add = add nsw i32 %mul, %res.08
+ %inc = add nuw i32 %i.09, 1
+ %exitcond.not = icmp eq i32 %inc, %N
+ br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i64 @i64_mac_s16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i64_mac_s16:
+; CHECK: v128.load32_zero 0:p2align=1
+; CHECK: i32x4.extend_low_i16x8_s
+; CHECK: v128.load32_zero 0:p2align=1
+; CHECK: i32x4.extend_low_i16x8_s
+; CHECK: i64x2.extmul_low_i32x4_s
+; CHECK: i64x2.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i8x16.shuffle 12, 13, 14, 15, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i8x16.shuffle 12, 13, 14, 15, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i8x16.shuffle 8, 9, 10, 11, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i8x16.shuffle 8, 9, 10, 11, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i8x16.shuffle 4, 5, 6, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i8x16.shuffle 4, 5, 6, 7, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i32x4.extend_low_i16x8_s
+; MAX-BANDWIDTH: i64x2.extmul_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+entry:
+ %cmp7.not = icmp eq i32 %N, 0
+ br i1 %cmp7.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup: ; preds = %for.body, %entry
+ %res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+ ret i64 %res.0.lcssa
+
+for.body: ; preds = %entry, %for.body
+ %i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+ %res.08 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.09
+ %0 = load i16, ptr %arrayidx, align 2
+ %conv = sext i16 %0 to i64
+ %arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.09
+ %1 = load i16, ptr %arrayidx1, align 2
+ %conv2 = sext i16 %1 to i64
+ %mul = mul nsw i64 %conv2, %conv
+ %add = add nsw i64 %mul, %res.08
+ %inc = add nuw i32 %i.09, 1
+ %exitcond.not = icmp eq i32 %inc, %N
+ br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i64 @i64_mac_s32(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i64_mac_s32:
+; CHECK: v128.load64_zero 0:p2align=2
+; CHECK: v128.load64_zero 0:p2align=2
+; CHECK: i32x4.mul
+; CHECK: i64x2.extend_low_i32x4_s
+; CHECK: i64x2.add
+
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i32x4.mul
+; MAX-BANDWIDTH: i64x2.extend_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+; MAX-BANDWIDTH: i8x16.shuffle 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3
+; MAX-BANDWIDTH: i64x2.extend_low_i32x4_s
+; MAX-BANDWIDTH: i64x2.add
+
+entry:
+ %cmp6.not = icmp eq i32 %N, 0
+ br i1 %cmp6.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup: ; preds = %for.body, %entry
+ %res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
+ ret i64 %res.0.lcssa
+
+for.body: ; preds = %entry, %for.body
+ %i.08 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+ %res.07 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i32, ptr %a, i32 %i.08
+ %0 = load i32, ptr %arrayidx, align 4
+ %arrayidx1 = getelementptr inbounds i32, ptr %b, i32 %i.08
+ %1 = load i32, ptr %arrayidx1, align 4
+ %mul = mul i32 %1, %0
+ %conv = sext i32 %mul to i64
+ %add = add i64 %res.07, %conv
+ %inc = add nuw i32 %i.08, 1
+ %exitcond.not = icmp eq i32 %inc, %N
+ br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+define hidden i32 @i32_mac_u8(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: i32_mac_u8:
+; CHECK: v128.load32_zero 0:p2align=0
+; CHECK: i16x8.extend_low_i8x...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Ping? |
Thanks for the PR, the code looks pretty reasonable. I'm curious about the partial.reduce.add intrinsic since I'm not really familiar with it. Is this something that the autovectorizer will generate by default now (or maybe now that you've implemented |
Correct, but with the caveat that it won't trigger at the moment due to the vectorization factor selected. For our dot, we need a factor of 8 or 16 but, AFAICT, the vectorizer will not go above 4 as we are using an i32 for the arithmetic. The tests that shows dot being generated are doing so because I've used the |
And thanks for taking a look! |
Enable the use of partial.reduce.add that we can lower to dot or a tree of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support both v8i16 and v16i8 inputs.
dc22605
to
ec82a32
Compare
Ping? |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/73/builds/12914 Here is the relevant piece of the build log for the reference
|
Enable the use of partial.reduce.add that we can lower to dot or a tree of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support both v8i16 and v16i8 inputs.
Enable the use of partial.reduce.add that we can lower to dot or a tree of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support both v8i16 and v16i8 inputs.