Skip to content

Commit 9396663

Browse files
authored
[SDAG] Add partial_reduce_sumla node (#141267)
We have recently added the partial_reduce_smla and partial_reduce_umla nodes to represent Acc += ext(a) * ext(b), where the two extends must have the same source type and the same extend kind. For riscv64 w/zvqdotq, we have the vqdot and vqdotu instructions, which correspond to the existing nodes, but we also have vqdotsu, which represents the case where the two extends are sign and zero respectively (i.e. not the same kind of extend). This patch adds a partial_reduce_sumla node, which has sign extension for A and zero extension for B. The addition is somewhat mechanical.
1 parent f32b756 commit 9396663

File tree

12 files changed

+510
-305
lines changed

12 files changed

+510
-305
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,8 +1493,9 @@ enum NodeType {
14931493
VECREDUCE_UMIN,
14941494

14951495
// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
1496-
// The partial reduction nodes sign or zero extend Input1 and Input2 to the
1497-
// element type of Accumulator before multiplying their results.
1496+
// The partial reduction nodes sign or zero extend Input1 and Input2
1497+
// (with the extension kind noted below) to the element type of
1498+
// Accumulator before multiplying their results.
14981499
// This result is concatenated to the Accumulator, and this is then reduced,
14991500
// using addition, to the result type.
15001501
// The output is only expected to either be given to another partial reduction
@@ -1506,8 +1507,9 @@ enum NodeType {
15061507
// multiple of the number of elements in the Accumulator / output type.
15071508
// Input1 and Input2 must have an element type which is the same as or smaller
15081509
// than the element type of the Accumulator and output.
1509-
PARTIAL_REDUCE_SMLA,
1510-
PARTIAL_REDUCE_UMLA,
1510+
PARTIAL_REDUCE_SMLA, // sext, sext
1511+
PARTIAL_REDUCE_UMLA, // zext, zext
1512+
PARTIAL_REDUCE_SUMLA, // sext, zext
15111513

15121514
// The `llvm.experimental.stackmap` intrinsic.
15131515
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,7 +1661,8 @@ class LLVM_ABI TargetLoweringBase {
16611661
/// target has a custom expander for it.
16621662
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
16631663
EVT InputVT) const {
1664-
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
1664+
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
1665+
Opc == ISD::PARTIAL_REDUCE_SUMLA);
16651666
PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
16661667
InputVT.getSimpleVT().SimpleTy};
16671668
auto It = PartialReduceMLAActions.find(Key);
@@ -2759,7 +2760,8 @@ class LLVM_ABI TargetLoweringBase {
27592760
/// sequence, or the target has a custom expander for it.
27602761
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
27612762
LegalizeAction Action) {
2762-
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
2763+
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
2764+
Opc == ISD::PARTIAL_REDUCE_SUMLA);
27632765
assert(AccVT.isValid() && InputVT.isValid() &&
27642766
"setPartialReduceMLAAction types aren't valid");
27652767
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
19921992
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
19931993
case ISD::PARTIAL_REDUCE_SMLA:
19941994
case ISD::PARTIAL_REDUCE_UMLA:
1995+
case ISD::PARTIAL_REDUCE_SUMLA:
19951996
return visitPARTIAL_REDUCE_MLA(N);
19961997
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
19971998
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -12737,26 +12738,27 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1273712738
SDValue LHSExtOp = LHS->getOperand(0);
1273812739
EVT LHSExtOpVT = LHSExtOp.getValueType();
1273912740

12740-
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12741-
unsigned NewOpcode =
12742-
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12743-
12744-
// Only perform these combines if the target supports folding
12745-
// the extends into the operation.
12746-
if (!TLI.isPartialReduceMLALegalOrCustom(
12747-
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12748-
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12749-
return SDValue();
12750-
1275112741
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1275212742
// -> partial_reduce_*mla(acc, x, C)
1275312743
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
12744+
// TODO: Make use of partial_reduce_sumla here
1275412745
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
1275512746
unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
1275612747
if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
1275712748
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
1275812749
return SDValue();
1275912750

12751+
unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
12752+
? ISD::PARTIAL_REDUCE_SMLA
12753+
: ISD::PARTIAL_REDUCE_UMLA;
12754+
12755+
// Only perform these combines if the target supports folding
12756+
// the extends into the operation.
12757+
if (!TLI.isPartialReduceMLALegalOrCustom(
12758+
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12759+
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12760+
return SDValue();
12761+
1276012762
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
1276112763
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
1276212764
}
@@ -12766,26 +12768,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1276612768
return SDValue();
1276712769

1276812770
SDValue RHSExtOp = RHS->getOperand(0);
12769-
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
12771+
if (LHSExtOpVT != RHSExtOp.getValueType())
12772+
return SDValue();
12773+
12774+
unsigned NewOpc;
12775+
if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
12776+
NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12777+
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12778+
NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12779+
else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12780+
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12781+
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
12782+
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12783+
std::swap(LHSExtOp, RHSExtOp);
12784+
} else
1277012785
return SDValue();
12771-
12772-
// For a 2-stage extend the signedness of both of the extends must be the
12773-
// same. This is so the node can be folded into only a signed or unsigned
12774-
// node.
12775-
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12786+
// For a 2-stage extend the signedness of both of the extends must match.
12787+
// If the mul has the same type, there is no outer extend, and thus we
12788+
// can simply use the inner extends to pick the result node.
12789+
// TODO: extend to handle nonneg zext as sext
1277612790
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12777-
if (ExtIsSigned != NodeIsSigned &&
12778-
Op1.getValueType().getVectorElementType() != AccElemVT)
12791+
if (Op1.getValueType().getVectorElementType() != AccElemVT &&
12792+
NewOpc != N->getOpcode())
1277912793
return SDValue();
1278012794

12781-
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
12782-
RHSExtOp);
12795+
// Only perform these combines if the target supports folding
12796+
// the extends into the operation.
12797+
if (!TLI.isPartialReduceMLALegalOrCustom(
12798+
NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12799+
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12800+
return SDValue();
12801+
12802+
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
1278312803
}
1278412804

1278512805
// partial.reduce.umla(acc, zext(op), splat(1))
1278612806
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
1278712807
// partial.reduce.smla(acc, sext(op), splat(1))
1278812808
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12809+
// partial.reduce.sumla(acc, sext(op), splat(1))
12810+
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
1278912811
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1279012812
SDLoc DL(N);
1279112813
SDValue Acc = N->getOperand(0);
@@ -12802,7 +12824,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1280212824
return SDValue();
1280312825

1280412826
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12805-
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12827+
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
1280612828
EVT AccElemVT = Acc.getValueType().getVectorElementType();
1280712829
if (Op1IsSigned != NodeIsSigned &&
1280812830
Op1.getValueType().getVectorElementType() != AccElemVT)

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
166166

167167
case ISD::PARTIAL_REDUCE_UMLA:
168168
case ISD::PARTIAL_REDUCE_SMLA:
169+
case ISD::PARTIAL_REDUCE_SUMLA:
169170
Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
170171
break;
171172

@@ -2093,6 +2094,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
20932094
break;
20942095
case ISD::PARTIAL_REDUCE_UMLA:
20952096
case ISD::PARTIAL_REDUCE_SMLA:
2097+
case ISD::PARTIAL_REDUCE_SUMLA:
20962098
Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
20972099
break;
20982100
}
@@ -2886,12 +2888,21 @@ SDValue DAGTypeLegalizer::PromoteIntOp_GET_ACTIVE_LANE_MASK(SDNode *N) {
28862888

28872889
SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
28882890
SmallVector<SDValue, 1> NewOps(N->ops());
2889-
if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
2891+
switch (N->getOpcode()) {
2892+
case ISD::PARTIAL_REDUCE_SMLA:
28902893
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
28912894
NewOps[2] = SExtPromotedInteger(N->getOperand(2));
2892-
} else {
2895+
break;
2896+
case ISD::PARTIAL_REDUCE_UMLA:
28932897
NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
28942898
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
2899+
break;
2900+
case ISD::PARTIAL_REDUCE_SUMLA:
2901+
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
2902+
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
2903+
break;
2904+
default:
2905+
llvm_unreachable("unexpected opcode");
28952906
}
28962907
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
28972908
}

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
530530
}
531531
case ISD::PARTIAL_REDUCE_UMLA:
532532
case ISD::PARTIAL_REDUCE_SMLA:
533+
case ISD::PARTIAL_REDUCE_SUMLA:
533534
Action =
534535
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
535536
Node->getOperand(1).getValueType());
@@ -1211,6 +1212,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
12111212
return;
12121213
case ISD::PARTIAL_REDUCE_UMLA:
12131214
case ISD::PARTIAL_REDUCE_SMLA:
1215+
case ISD::PARTIAL_REDUCE_SUMLA:
12141216
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
12151217
return;
12161218
case ISD::VECREDUCE_SEQ_FADD:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
13871387
break;
13881388
case ISD::PARTIAL_REDUCE_UMLA:
13891389
case ISD::PARTIAL_REDUCE_SMLA:
1390+
case ISD::PARTIAL_REDUCE_SUMLA:
13901391
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
13911392
break;
13921393
case ISD::GET_ACTIVE_LANE_MASK:
@@ -3473,6 +3474,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
34733474
break;
34743475
case ISD::PARTIAL_REDUCE_UMLA:
34753476
case ISD::PARTIAL_REDUCE_SMLA:
3477+
case ISD::PARTIAL_REDUCE_SUMLA:
34763478
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
34773479
break;
34783480
}

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7981,7 +7981,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
79817981
break;
79827982
}
79837983
case ISD::PARTIAL_REDUCE_UMLA:
7984-
case ISD::PARTIAL_REDUCE_SMLA: {
7984+
case ISD::PARTIAL_REDUCE_SMLA:
7985+
case ISD::PARTIAL_REDUCE_SUMLA: {
79857986
[[maybe_unused]] EVT AccVT = N1.getValueType();
79867987
[[maybe_unused]] EVT Input1VT = N2.getValueType();
79877988
[[maybe_unused]] EVT Input2VT = N3.getValueType();

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
585585
return "partial_reduce_umla";
586586
case ISD::PARTIAL_REDUCE_SMLA:
587587
return "partial_reduce_smla";
588+
case ISD::PARTIAL_REDUCE_SUMLA:
589+
return "partial_reduce_sumla";
588590

589591
// Vector Predication
590592
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11891,13 +11891,17 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1189111891
EVT ExtMulOpVT =
1189211892
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
1189311893
MulOpVT.getVectorElementCount());
11894-
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
11895-
? ISD::SIGN_EXTEND
11896-
: ISD::ZERO_EXTEND;
11894+
11895+
unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
11896+
? ISD::ZERO_EXTEND
11897+
: ISD::SIGN_EXTEND;
11898+
unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
11899+
? ISD::SIGN_EXTEND
11900+
: ISD::ZERO_EXTEND;
1189711901

1189811902
if (ExtMulOpVT != MulOpVT) {
11899-
MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
11900-
MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
11903+
MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
11904+
MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
1190111905
}
1190211906
SDValue Input = MulLHS;
1190311907
APInt ConstantOne;

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,7 +1574,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15741574
// zve32x is broken for partial_reduce_umla, but let's not make it worse.
15751575
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
15761576
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1577-
ISD::PARTIAL_REDUCE_UMLA};
1577+
ISD::PARTIAL_REDUCE_UMLA,
1578+
ISD::PARTIAL_REDUCE_SUMLA};
15781579
setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
15791580
setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
15801581
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
@@ -8318,6 +8319,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
83188319
return lowerADJUST_TRAMPOLINE(Op, DAG);
83198320
case ISD::PARTIAL_REDUCE_UMLA:
83208321
case ISD::PARTIAL_REDUCE_SMLA:
8322+
case ISD::PARTIAL_REDUCE_SUMLA:
83218323
return lowerPARTIAL_REDUCE_MLA(Op, DAG);
83228324
}
83238325
}
@@ -8534,8 +8536,20 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
85348536
B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
85358537
}
85368538

8537-
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
8538-
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
8539+
unsigned Opc;
8540+
switch (Op.getOpcode()) {
8541+
case ISD::PARTIAL_REDUCE_SMLA:
8542+
Opc = RISCVISD::VQDOT_VL;
8543+
break;
8544+
case ISD::PARTIAL_REDUCE_UMLA:
8545+
Opc = RISCVISD::VQDOTU_VL;
8546+
break;
8547+
case ISD::PARTIAL_REDUCE_SUMLA:
8548+
Opc = RISCVISD::VQDOTSU_VL;
8549+
break;
8550+
default:
8551+
llvm_unreachable("Unexpected opcode");
8552+
}
85398553
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
85408554
SDValue Res = DAG.getNode(Opc, DL, ContainerVT, {A, B, Accum, Mask, VL});
85418555
if (VT.isFixedLengthVector())

0 commit comments

Comments
 (0)