Skip to content

[RISCV] Lower PARTIAL_REDUCE_[S/U]MLA via zvqdotq #140950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal);
}

// zve32x is broken for partial_reduce_umla, but let's not make it worse.
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nxv1i32 isn't legal without Zve64/V. Marking it custom will cause the type legalizer to call replaceNodeResults with Zve32 which will assert.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or at least it would for normal operations, but maybe setPartialReduceMLAAction uses a different table that type legalization doesn't know about?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guarded this block, but FYI, partial_reduce_umla doesn't appear to work with zve32x at all.

$./llc -mtriple=riscv64 -mattr=+zve32x -verify-machineinstrs < test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
WidenVectorResult #0: t13: nxv1i32 = partial_reduce_umla t10, t8, t12

LLVM ERROR: Do not know how to widen the result of this operator!

Copy link
Collaborator

@topperc topperc May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. This is going to be tricky to widen. We can't put any real data into the extra result elements added by widening. They won't be consumed by the widened receiving node.

It should work for nvx2i32 though since that doesn't need to widen.

setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
}

// Function alignments.
const Align FunctionAlignment(Subtarget.hasStdExtCOrZca() ? 2 : 4);
setMinFunctionAlignment(FunctionAlignment);
Expand Down Expand Up @@ -8229,6 +8238,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerINIT_TRAMPOLINE(Op, DAG);
case ISD::ADJUST_TRAMPOLINE:
return lowerADJUST_TRAMPOLINE(Op, DAG);
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
return lowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}

Expand Down Expand Up @@ -8364,6 +8376,27 @@ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
return Op.getOperand(0);
}

SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
Copy link
Contributor

@lukel97 lukel97 May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there's tablegen nodes defined for partial_reduce_{u,s}mla, could we mark the node as legal instead of custom and patterns in tablegen instead? (At least for scalable vectors, we'll still need this for fixed?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I briefly looked at doing this via tablegen, but decided to share the code with the reduce pattern matching. Once I get hat migrated over to the partial_reduce_mla infrastructure, I may revisit the tablegen question.

SelectionDAG &DAG) const {
// Currently, only the vqdot and vqdotu case (from zvqdotq) should be legal.
// TODO: There are many other sub-cases we could potentially lower, are
// any of them worthwhile? Ex: via vredsum, vwredsum, vwwmaccu, etc..
// TODO: PARTIAL_REDUCE_*MLA can't represent a vqdotsu currently.
SDLoc DL(Op);
MVT VT = Op.getSimpleValueType();
SDValue Accum = Op.getOperand(0);
assert(Accum.getSimpleValueType() == VT &&
VT.getVectorElementType() == MVT::i32);
SDValue A = Op.getOperand(1);
SDValue B = Op.getOperand(2);
assert(A.getSimpleValueType() == B.getSimpleValueType() &&
A.getSimpleValueType().getVectorElementType() == MVT::i8);
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
}

static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
SelectionDAG &DAG, unsigned Flags) {
return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ class RISCVTargetLowering : public TargetLowering {

SDValue lowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;

bool isEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
Expand Down
Loading
Loading