-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[RISCV] Lower PARTIAL_REDUCE_[S/U]MLA via zvqdotq #140950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1571,6 +1571,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, | |
setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal); | ||
} | ||
|
||
// zve32x is broken for partial_reduce_umla, but let's not make it worse. | ||
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) { | ||
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom); | ||
setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom); | ||
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom); | ||
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom); | ||
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom); | ||
} | ||
|
||
// Function alignments. | ||
const Align FunctionAlignment(Subtarget.hasStdExtCOrZca() ? 2 : 4); | ||
setMinFunctionAlignment(FunctionAlignment); | ||
|
@@ -8229,6 +8238,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, | |
return lowerINIT_TRAMPOLINE(Op, DAG); | ||
case ISD::ADJUST_TRAMPOLINE: | ||
return lowerADJUST_TRAMPOLINE(Op, DAG); | ||
case ISD::PARTIAL_REDUCE_UMLA: | ||
case ISD::PARTIAL_REDUCE_SMLA: | ||
return lowerPARTIAL_REDUCE_MLA(Op, DAG); | ||
} | ||
} | ||
|
||
|
@@ -8364,6 +8376,27 @@ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op, | |
return Op.getOperand(0); | ||
} | ||
|
||
SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like there's tablegen nodes defined for partial_reduce_{u,s}mla, could we mark the node as legal instead of custom and patterns in tablegen instead? (At least for scalable vectors, we'll still need this for fixed?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I briefly looked at doing this via tablegen, but decided to share the code with the reduce pattern matching. Once I get hat migrated over to the partial_reduce_mla infrastructure, I may revisit the tablegen question. |
||
SelectionDAG &DAG) const { | ||
// Currently, only the vqdot and vqdotu case (from zvqdotq) should be legal. | ||
// TODO: There are many other sub-cases we could potentially lower, are | ||
// any of them worthwhile? Ex: via vredsum, vwredsum, vwwmaccu, etc.. | ||
// TODO: PARTIAL_REDUCE_*MLA can't represent a vqdotsu currently. | ||
SDLoc DL(Op); | ||
MVT VT = Op.getSimpleValueType(); | ||
SDValue Accum = Op.getOperand(0); | ||
assert(Accum.getSimpleValueType() == VT && | ||
VT.getVectorElementType() == MVT::i32); | ||
SDValue A = Op.getOperand(1); | ||
SDValue B = Op.getOperand(2); | ||
assert(A.getSimpleValueType() == B.getSimpleValueType() && | ||
A.getSimpleValueType().getVectorElementType() == MVT::i8); | ||
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA; | ||
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL; | ||
auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget); | ||
return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL}); | ||
} | ||
|
||
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty, | ||
SelectionDAG &DAG, unsigned Flags) { | ||
return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nxv1i32 isn't legal without Zve64/V. Marking it custom will cause the type legalizer to call replaceNodeResults with Zve32 which will assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or at least it would for normal operations, but maybe setPartialReduceMLAAction uses a different table that type legalization doesn't know about?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guarded this block, but FYI, partial_reduce_umla doesn't appear to work with zve32x at all.
$./llc -mtriple=riscv64 -mattr=+zve32x -verify-machineinstrs < test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
WidenVectorResult #0: t13: nxv1i32 = partial_reduce_umla t10, t8, t12
LLVM ERROR: Do not know how to widen the result of this operator!
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. This is going to be tricky to widen. We can't put any real data into the extra result elements added by widening. They won't be consumed by the widened receiving node.
It should work for nvx2i32 though since that doesn't need to widen.