[RISCV] Lower PARTIAL_REDUCE_[S/U]MLA via zvqdotq #140950

Merged
merged 3 commits into llvm:main from preames:pr-zvqdot-partial-reduce-scalable-lowering on May 22, 2025

Conversation

preames
Collaborator

@preames preames commented May 21, 2025

The semantics of PARTIAL_REDUCE_SMLA with an i32 result element and i8 sources correspond to vqdot. Analogously, PARTIAL_REDUCE_UMLA corresponds to vqdotu. There is currently no vqdotsu equivalent.

This patch is a starting place. We can extend this quite a bit more; I plan to take a look at the fixed-vector lowering, the TTI hook to drive the loop vectorizer, and integrating the reduction-based lowering I'd added for zvqdotq into this flow.
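In scalar terms, the case handled here has each 32-bit accumulator lane absorb the dot product of the corresponding group of four i8 elements, which is what a single vqdot.vv (or vqdotu.vv for the unsigned variant) computes per lane. A minimal scalar sketch of the signed case, illustrative only and not part of the patch, assuming the straightforward four-to-one lane grouping that vqdot uses:

#include <cstddef>
#include <cstdint>

// Illustrative reference model for PARTIAL_REDUCE_SMLA with an i32
// accumulator and i8 sources: acc has NAcc lanes, while a and b have
// 4*NAcc elements.  Lane i of the result is acc[i] plus the dot product
// of the i-th group of four sign-extended elements.
void partial_reduce_smla_ref(int32_t *acc, const int8_t *a, const int8_t *b,
                             size_t NAcc) {
  for (size_t i = 0; i < NAcc; ++i)
    for (size_t j = 0; j < 4; ++j)
      acc[i] += int32_t(a[4 * i + j]) * int32_t(b[4 * i + j]);
}

The unsigned (PARTIAL_REDUCE_UMLA / vqdotu) variant is identical except that the i8 inputs are zero-extended rather than sign-extended.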

@llvmbot
Member

llvmbot commented May 21, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes

The semantics of PARTIAL_REDUCE_SMLA with an i32 result element and i8 sources correspond to vqdot. Analogously, PARTIAL_REDUCE_UMLA corresponds to vqdotu. There is currently no vqdotsu equivalent.

This patch is a starting place. We can extend this quite a bit more; I plan to take a look at the fixed-vector lowering, the TTI hook to drive the loop vectorizer, and integrating the reduction-based lowering I'd added for zvqdotq into this flow.


Patch is 25.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140950.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+32)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+1)
  • (modified) llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll (+322-209)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d69e04a9912a2..59f43761b4105 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1571,6 +1571,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal);
   }
 
+  if (Subtarget.hasStdExtZvqdotq()) {
+    setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
+    setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
+    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
+    setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
+    setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
+  }
+
   // Function alignments.
   const Align FunctionAlignment(Subtarget.hasStdExtCOrZca() ? 2 : 4);
   setMinFunctionAlignment(FunctionAlignment);
@@ -8229,6 +8237,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerINIT_TRAMPOLINE(Op, DAG);
   case ISD::ADJUST_TRAMPOLINE:
     return lowerADJUST_TRAMPOLINE(Op, DAG);
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    return lowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
 
@@ -8364,6 +8375,27 @@ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
   return Op.getOperand(0);
 }
 
+SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                                     SelectionDAG &DAG) const {
+  // Currently, only the vqdot and vqdotu cases (from zvqdotq) should be legal.
+  // TODO: There are many other sub-cases we could potentially lower, are
+  // any of them worthwhile?  Ex: via vredsum, vwredsum, vwwmaccu, etc..
+  // TODO: PARTIAL_REDUCE_*MLA can't represent a vqdotsu currently.
+  SDLoc DL(Op);
+  MVT VT = Op.getSimpleValueType();
+  SDValue Accum = Op.getOperand(0);
+  assert(Accum.getSimpleValueType() == VT &&
+         VT.getVectorElementType() == MVT::i32);
+  SDValue A = Op.getOperand(1);
+  SDValue B = Op.getOperand(2);
+  assert(A.getSimpleValueType() == B.getSimpleValueType() &&
+         A.getSimpleValueType().getVectorElementType() == MVT::i8);
+  bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+  auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+  return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
+}
+
 static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
                              SelectionDAG &DAG, unsigned Flags) {
   return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index fc8d8b8ce1b56..78f2044ba83a7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -552,6 +552,7 @@ class RISCVTargetLowering : public TargetLowering {
 
   SDValue lowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
 
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 6df628e3bd812..2bd2ef2878fd5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -524,22 +524,30 @@ entry:
 
 
 define <vscale x 1 x i32> @partial_reduce_nf2(<vscale x 4 x i8> %a, <vscale x 4 x i8> %b) {
-; CHECK-LABEL: partial_reduce_nf2:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vsext.vf2 v11, v9
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    vwmul.vv v8, v10, v11
-; CHECK-NEXT:    srli a0, a0, 3
-; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vx v10, v9, a0
-; CHECK-NEXT:    vslidedown.vx v11, v8, a0
-; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    vadd.vv v9, v11, v9
-; CHECK-NEXT:    vadd.vv v8, v9, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_nf2:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vsext.vf2 v11, v9
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    vwmul.vv v8, v10, v11
+; NODOT-NEXT:    srli a0, a0, 3
+; NODOT-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
+; NODOT-NEXT:    vslidedown.vx v10, v9, a0
+; NODOT-NEXT:    vslidedown.vx v11, v8, a0
+; NODOT-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v10, v8
+; NODOT-NEXT:    vadd.vv v9, v11, v9
+; NODOT-NEXT:    vadd.vv v8, v9, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_nf2:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vmv1r.v v8, v10
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 4 x i8> %a to <vscale x 4 x i32>
   %b.sext = sext <vscale x 4 x i8> %b to <vscale x 4 x i32>
@@ -549,17 +557,25 @@ entry:
 }
 
 define <vscale x 2 x i32> @partial_reduce_m1(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m1:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vsext.vf2 v14, v9
-; CHECK-NEXT:    vwmul.vv v8, v12, v14
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v11, v8
-; CHECK-NEXT:    vadd.vv v9, v9, v10
-; CHECK-NEXT:    vadd.vv v8, v9, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_m1:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vsext.vf2 v14, v9
+; NODOT-NEXT:    vwmul.vv v8, v12, v14
+; NODOT-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v11, v8
+; NODOT-NEXT:    vadd.vv v9, v9, v10
+; NODOT-NEXT:    vadd.vv v8, v9, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_m1:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vmv.v.v v8, v10
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 8 x i8> %a to <vscale x 8 x i32>
   %b.sext = sext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -569,17 +585,25 @@ entry:
 }
 
 define <vscale x 4 x i32> @partial_reduce_m2(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m2:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v16, v8
-; CHECK-NEXT:    vsext.vf2 v20, v10
-; CHECK-NEXT:    vwmul.vv v8, v16, v20
-; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v14, v8
-; CHECK-NEXT:    vadd.vv v10, v10, v12
-; CHECK-NEXT:    vadd.vv v8, v10, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_m2:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vsext.vf2 v20, v10
+; NODOT-NEXT:    vwmul.vv v8, v16, v20
+; NODOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v14, v8
+; NODOT-NEXT:    vadd.vv v10, v10, v12
+; NODOT-NEXT:    vadd.vv v8, v10, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_m2:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdot.vv v12, v8, v10
+; DOT-NEXT:    vmv.v.v v8, v12
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -589,20 +613,28 @@ entry:
 }
 
 define <vscale x 8 x i32> @partial_reduce_m4(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m4:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v24, v8
-; CHECK-NEXT:    vsext.vf2 v16, v10
-; CHECK-NEXT:    vsext.vf2 v28, v12
-; CHECK-NEXT:    vsext.vf2 v20, v14
-; CHECK-NEXT:    vwmul.vv v8, v16, v20
-; CHECK-NEXT:    vwmul.vv v16, v24, v28
-; CHECK-NEXT:    vsetvli a0, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vadd.vv v16, v20, v16
-; CHECK-NEXT:    vadd.vv v8, v12, v8
-; CHECK-NEXT:    vadd.vv v8, v8, v16
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_m4:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v24, v8
+; NODOT-NEXT:    vsext.vf2 v16, v10
+; NODOT-NEXT:    vsext.vf2 v28, v12
+; NODOT-NEXT:    vsext.vf2 v20, v14
+; NODOT-NEXT:    vwmul.vv v8, v16, v20
+; NODOT-NEXT:    vwmul.vv v16, v24, v28
+; NODOT-NEXT:    vsetvli a0, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vadd.vv v16, v20, v16
+; NODOT-NEXT:    vadd.vv v8, v12, v8
+; NODOT-NEXT:    vadd.vv v8, v8, v16
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_m4:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m4, ta, ma
+; DOT-NEXT:    vmv.v.i v16, 0
+; DOT-NEXT:    vqdot.vv v16, v8, v12
+; DOT-NEXT:    vmv.v.v v8, v16
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 32 x i8> %a to <vscale x 32 x i32>
   %b.sext = sext <vscale x 32 x i8> %b to <vscale x 32 x i32>
@@ -612,38 +644,46 @@ entry:
 }
 
 define <vscale x 16 x i32> @partial_reduce_m8(<vscale x 64 x i8> %a, <vscale x 64 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m8:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    addi sp, sp, -16
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 2
-; CHECK-NEXT:    sub sp, sp, a0
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x04, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 4 * vlenb
-; CHECK-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v24, v10
-; CHECK-NEXT:    addi a0, sp, 16
-; CHECK-NEXT:    vs4r.v v24, (a0) # vscale x 32-byte Folded Spill
-; CHECK-NEXT:    vsext.vf2 v0, v8
-; CHECK-NEXT:    vsext.vf2 v8, v18
-; CHECK-NEXT:    vsext.vf2 v4, v16
-; CHECK-NEXT:    vwmul.vv v24, v0, v4
-; CHECK-NEXT:    vl4r.v v16, (a0) # vscale x 32-byte Folded Reload
-; CHECK-NEXT:    vwmacc.vv v24, v16, v8
-; CHECK-NEXT:    vsext.vf2 v8, v12
-; CHECK-NEXT:    vsext.vf2 v16, v20
-; CHECK-NEXT:    vwmacc.vv v24, v8, v16
-; CHECK-NEXT:    vsext.vf2 v8, v14
-; CHECK-NEXT:    vsext.vf2 v12, v22
-; CHECK-NEXT:    vwmacc.vv v24, v8, v12
-; CHECK-NEXT:    vmv8r.v v8, v24
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 2
-; CHECK-NEXT:    add sp, sp, a0
-; CHECK-NEXT:    .cfi_def_cfa sp, 16
-; CHECK-NEXT:    addi sp, sp, 16
-; CHECK-NEXT:    .cfi_def_cfa_offset 0
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_m8:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    addi sp, sp, -16
+; NODOT-NEXT:    .cfi_def_cfa_offset 16
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 2
+; NODOT-NEXT:    sub sp, sp, a0
+; NODOT-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x04, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 4 * vlenb
+; NODOT-NEXT:    vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v24, v10
+; NODOT-NEXT:    addi a0, sp, 16
+; NODOT-NEXT:    vs4r.v v24, (a0) # vscale x 32-byte Folded Spill
+; NODOT-NEXT:    vsext.vf2 v0, v8
+; NODOT-NEXT:    vsext.vf2 v8, v18
+; NODOT-NEXT:    vsext.vf2 v4, v16
+; NODOT-NEXT:    vwmul.vv v24, v0, v4
+; NODOT-NEXT:    vl4r.v v16, (a0) # vscale x 32-byte Folded Reload
+; NODOT-NEXT:    vwmacc.vv v24, v16, v8
+; NODOT-NEXT:    vsext.vf2 v8, v12
+; NODOT-NEXT:    vsext.vf2 v16, v20
+; NODOT-NEXT:    vwmacc.vv v24, v8, v16
+; NODOT-NEXT:    vsext.vf2 v8, v14
+; NODOT-NEXT:    vsext.vf2 v12, v22
+; NODOT-NEXT:    vwmacc.vv v24, v8, v12
+; NODOT-NEXT:    vmv8r.v v8, v24
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 2
+; NODOT-NEXT:    add sp, sp, a0
+; NODOT-NEXT:    .cfi_def_cfa sp, 16
+; NODOT-NEXT:    addi sp, sp, 16
+; NODOT-NEXT:    .cfi_def_cfa_offset 0
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_m8:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT:    vmv.v.i v24, 0
+; DOT-NEXT:    vqdot.vv v24, v8, v16
+; DOT-NEXT:    vmv.v.v v8, v24
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 64 x i8> %a to <vscale x 64 x i32>
   %b.sext = sext <vscale x 64 x i8> %b to <vscale x 64 x i32>
@@ -653,103 +693,161 @@ entry:
 }
 
 define <vscale x 32 x i32> @partial_reduce_m16(<vscale x 128 x i8> %a, <vscale x 128 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m16:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    addi sp, sp, -16
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    mv a2, a1
-; CHECK-NEXT:    slli a1, a1, 1
-; CHECK-NEXT:    add a1, a1, a2
-; CHECK-NEXT:    sub sp, sp, a1
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x18, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 24 * vlenb
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 4
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    addi a1, sp, 16
-; CHECK-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    vl8r.v v16, (a0)
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    add a1, sp, a1
-; CHECK-NEXT:    addi a1, a1, 16
-; CHECK-NEXT:    vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
-; CHECK-NEXT:    vsext.vf2 v4, v8
-; CHECK-NEXT:    vsext.vf2 v0, v16
-; CHECK-NEXT:    vwmul.vv v24, v4, v0
-; CHECK-NEXT:    vsext.vf2 v4, v10
-; CHECK-NEXT:    vsext.vf2 v8, v18
-; CHECK-NEXT:    vwmacc.vv v24, v4, v8
-; CHECK-NEXT:    csrr a1, vlenb
-; CHECK-NEXT:    slli a1, a1, 3
-; CHECK-NEXT:    add a0, a0, a1
-; CHECK-NEXT:    vsext.vf2 v0, v12
-; CHECK-NEXT:    vl8r.v v8, (a0)
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 3
-; CHECK-NEXT:    add a0, sp, a0
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v4, v20
-; CHECK-NEXT:    vwmacc.vv v24, v0, v4
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 4
-; CHECK-NEXT:    add a0, sp, a0
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vl8r.v v0, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v20, v0
-; CHECK-NEXT:    vsext.vf2 v16, v8
-; CHECK-NEXT:    vwmul.vv v0, v20, v16
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 4
-; CHECK-NEXT:    add a0, sp, a0
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v20, v18
-; CHECK-NEXT:    vsext.vf2 v16, v10
-; CHECK-NEXT:    vwmacc.vv v0, v20, v16
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 4
-; CHECK-NEXT:    add a0, sp, a0
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v8, v20
-; CHECK-NEXT:    vsext.vf2 v16, v12
-; CHECK-NEXT:    vwmacc.vv v0, v8, v16
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 4
-; CHECK-NEXT:    add a0, sp, a0
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v8, v22
-; CHECK-NEXT:    vsext.vf2 v16, v14
-; CHECK-NEXT:    vwmacc.vv v0, v8, v16
-; CHECK-NEXT:    addi a0, sp, 16
-; CHECK-NEXT:    vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v8, v14
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 3
-; CHECK-NEXT:    add a0, sp, a0
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT:    vsext.vf2 v12, v22
-; CHECK-NEXT:    vwmacc.vv v24, v8, v12
-; CHECK-NEXT:    vmv8r.v v8, v24
-; CHECK-NEXT:    vmv8r.v v16, v0
-; CHECK-NEXT:    csrr a0, vlenb
-; CHECK-NEXT:    slli a0, a0, 3
-; CHECK-NEXT:    mv a1, a0
-; CHECK-NEXT:    slli a0, a0, 1
-; CHECK-NEXT:    add a0, a0, a1
-; CHECK-NEXT:    add sp, sp, a0
-; CHECK-NEXT:    .cfi_def_cfa sp, 16
-; CHECK-NEXT:    addi sp, sp, 16
-; CHECK-NEXT:    .cfi_def_cfa_offset 0
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_m16:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    addi sp, sp, -16
+; NODOT-NEXT:    .cfi_def_cfa_offset 16
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    mv a2, a1
+; NODOT-NEXT:    slli a1, a1, 1
+; NODOT-NEXT:    add a1, a1, a2
+; NODOT-NEXT:    sub sp, sp, a1
+; NODOT-NEXT:    .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x18, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 24 * vlenb
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 4
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    addi a1, sp, 16
+; NODOT-NEXT:    vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    vl8r.v v16, (a0)
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    add a1, sp, a1
+; NODOT-NEXT:    addi a1, a1, 16
+; NODOT-NEXT:    vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; NODOT-NEXT:    vsext.vf2 v4, v8
+; NODOT-NEXT:    vsext.vf2 v0, v16
+; NODOT-NEXT:    vwmul.vv v24, v4, v0
+; NODOT-NEXT:    vsext.vf2 v4, v10
+; NODOT-NEXT:    vsext.vf2 v8, v18
+; NODOT-NEXT:    vwmacc.vv v24, v4, v8
+; NODOT-NEXT:    csrr a1, vlenb
+; NODOT-NEXT:    slli a1, a1, 3
+; NODOT-NEXT:    add a0, a0, a1
+; NODOT-NEXT:    vsext.vf2 v0, v12
+; NODOT-NEXT:    vl8r.v v8, (a0)
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 3
+; NODOT-NEXT:    add a0, sp, a0
+; NODOT-NEXT:    addi a0, a0, 16
+; NODOT-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v4, v20
+; NODOT-NEXT:    vwmacc.vv v24, v0, v4
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 4
+; NODOT-NEXT:    add a0, sp, a0
+; NODOT-NEXT:    addi a0, a0, 16
+; NODOT-NEXT:    vl8r.v v0, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v20, v0
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vwmul.vv v0, v20, v16
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 4
+; NODOT-NEXT:    add a0, sp, a0
+; NODOT-NEXT:    addi a0, a0, 16
+; NODOT-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v20, v18
+; NODOT-NEXT:    vsext.vf2 v16, v10
+; NODOT-NEXT:    vwmacc.vv v0, v20, v16
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 4
+; NODOT-NEXT:    add a0, sp, a0
+; NODOT-NEXT:    addi a0, a0, 16
+; NODOT-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v8, v20
+; NODOT-NEXT:    vsext.vf2 v16, v12
+; NODOT-NEXT:    vwmacc.vv v0, v8, v16
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 4
+; NODOT-NEXT:    add a0, sp, a0
+; NODOT-NEXT:    addi a0, a0, 16
+; NODOT-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v8, v22
+; NODOT-NEXT:    vsext.vf2 v16, v14
+; NODOT-NEXT:    vwmacc.vv v0, v8, v16
+; NODOT-NEXT:    addi a0, sp, 16
+; NODOT-NEXT:    vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v8, v14
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 3
+; NODOT-NEXT:    add a0, sp, a0
+; NODOT-NEXT:    addi a0, a0, 16
+; NODOT-NEXT:    vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT:    vsext.vf2 v12, v22
+; NODOT-NEXT:    vwmacc.vv v24, v8, v12
+; NODOT-NEXT:    vmv8r.v v8, v24
+; NODOT-NEXT:    vmv8r.v v16, v0
+; NODOT-NEXT:    csrr a0, vlenb
+; NODOT-NEXT:    slli a0, a0, 3
+; NODOT-NEXT:    mv a1, a0
+; NODOT-NEXT:    slli a0, a0, 1
+; NODOT-NEXT:    add a0, a0, a1
+; NODOT-NEXT:    add sp, sp, a0
+; NODOT-NEXT:    .cfi_def_cfa sp, 16
+; NODOT-NEXT:    addi sp, sp, 16
+; NODOT-NEXT:    .cfi_def_cfa_offset 0
+; NODOT-NEXT: ...
[truncated]

@@ -1571,6 +1571,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal);
}

if (Subtarget.hasStdExtZvqdotq()) {
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
Collaborator

nxv1i32 isn't legal without Zve64/V. Marking it custom will cause the type legalizer to call replaceNodeResults with Zve32 which will assert.

Collaborator

Or at least it would for normal operations, but maybe setPartialReduceMLAAction uses a different table that type legalization doesn't know about?

Collaborator Author

I guarded this block, but FYI, partial_reduce_umla doesn't appear to work with zve32x at all.

$ ./llc -mtriple=riscv64 -mattr=+zve32x -verify-machineinstrs < test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
WidenVectorResult #0: t13: nxv1i32 = partial_reduce_umla t10, t8, t12

LLVM ERROR: Do not know how to widen the result of this operator!

Collaborator
@topperc topperc May 21, 2025

Ok. This is going to be tricky to widen. We can't put any real data into the extra result elements added by widening. They won't be consumed by the widened receiving node.

It should work for nxv2i32 though, since that doesn't need to widen.

@@ -1571,6 +1571,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal);
}

// zve32x is broken for partial_reduce_umla, but let's not make it worse.
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getRealMinVLen() >= 64) {
Collaborator

This needs to check ELEN, not VLEN. It's still broken with ELEN=32 VLEN=64.

Collaborator Author

What's the exact check to use here? Is it !Subtarget.hasVInstructionsI64()?

This is hard to check since it fails either way...

Collaborator
@topperc topperc May 22, 2025

If you're going to block everything, the check is Subtarget.getELen() >= 64. If you only want to disable it for specific types, then you have to disable any type where getVectorMinNumElements() < RISCV::RVVBitsPerBlock / Subtarget.getELen().
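
Applying the first of those checks to the block from the patch would look roughly like the sketch below (illustrative only; the exact form that landed may differ):

  // Only enable the custom lowering when ELEN >= 64; the zve32x
  // configurations cannot legalize these partial_reduce nodes (see above).
  if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
    setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
    setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
    setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
    setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
  }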

Collaborator
@topperc topperc left a comment

LGTM

@preames preames merged commit 9b4de7d into llvm:main May 22, 2025
6 of 9 checks passed
@preames preames deleted the pr-zvqdot-partial-reduce-scalable-lowering branch May 22, 2025 15:29
@@ -8364,6 +8376,27 @@ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
return Op.getOperand(0);
}

SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
Contributor
@lukel97 lukel97 May 23, 2025

It looks like there are tablegen nodes defined for partial_reduce_{u,s}mla; could we mark the node as legal instead of custom and add patterns in tablegen instead? (At least for scalable vectors; we'll still need this for fixed?)

Collaborator Author

I briefly looked at doing this via tablegen, but decided to share the code with the reduce pattern matching. Once I get that migrated over to the partial_reduce_mla infrastructure, I may revisit the tablegen question.

sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
ajaden-codes pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 6, 2025