Skip to content

[SDAG] Split the partial reduce legalize table by opcode [nfc] #141970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1659,17 +1659,20 @@ class LLVM_ABI TargetLoweringBase {
/// InputVT should be treated. Either it's legal, needs to be promoted to a
/// larger size, needs to be expanded to some other code sequence, or the
/// target has a custom expander for it.
LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
InputVT.getSimpleVT().SimpleTy};
auto It = PartialReduceMLAActions.find(TypePair);
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
                                         EVT InputVT) const {
  // The action table is keyed per opcode, so only the partial reduce MLA
  // opcodes are accepted here.
  assert((Opc == ISD::PARTIAL_REDUCE_SMLA ||
          Opc == ISD::PARTIAL_REDUCE_UMLA) &&
         "getPartialReduceMLAAction expects a partial reduce MLA opcode");
  // Lookup key: (opcode, accumulator simple type, input simple type).
  PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
                                  InputVT.getSimpleVT().SimpleTy};
  auto It = PartialReduceMLAActions.find(Key);
  // Combinations with no recorded entry fall back to Expand.
  return It != PartialReduceMLAActions.end() ? It->second : Expand;
}

/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified opcode and
/// types is legal or custom for this target.
bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT,
EVT InputVT) const {
LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT);
return Action == Legal || Action == Custom;
}

Expand Down Expand Up @@ -2754,12 +2757,18 @@ class LLVM_ABI TargetLoweringBase {
/// type InputVT should be treated by the target. Either it's legal, needs to
/// be promoted to a larger size, needs to be expanded to some other code
/// sequence, or the target has a custom expander for it.
void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
LegalizeAction Action) {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
assert(AccVT.isValid() && InputVT.isValid() &&
"setPartialReduceMLAAction types aren't valid");
PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
PartialReduceMLAActions[TypePair] = Action;
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
PartialReduceMLAActions[Key] = Action;
}
/// Convenience overload: record \p Action for the (\p AccVT, \p InputVT)
/// pair under every opcode listed in \p Opcodes by forwarding to the
/// single-opcode overload.
void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT,
                               MVT InputVT, LegalizeAction Action) {
  for (const unsigned *I = Opcodes.begin(), *E = Opcodes.end(); I != E; ++I)
    setPartialReduceMLAAction(*I, AccVT, InputVT, Action);
}

/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
Expand Down Expand Up @@ -3751,10 +3760,10 @@ class LLVM_ABI TargetLoweringBase {
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];

using PartialReduceActionTypes =
std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
/// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
/// nodes, keep a LegalizeAction which indicates how instruction selection
/// should deal with this operation.
std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>;
/// For each partial reduce opcode, result type and input type combination,
/// keep a LegalizeAction which indicates how instruction selection should
/// deal with this operation.
DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;

ValueTypeActionImpl ValueTypeActions;
Expand Down
27 changes: 14 additions & 13 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();

bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;

// Only perform these combines if the target supports folding
// the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();

bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;

// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
Expand Down Expand Up @@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!ISD::isExtOpcode(Op1Opcode))
return SDValue();

SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
auto *Context = DAG.getContext();
if (!TLI.isPartialReduceMLALegalOrCustom(
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();

bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
Expand All @@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {

unsigned NewOpcode =
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;

SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
auto *Context = DAG.getContext();
if (!TLI.isPartialReduceMLALegalOrCustom(
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();

return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
DAG.getConstant(1, DL, UnextOp1VT));
}
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
Node->getOperand(1).getValueType());
Action =
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
Node->getOperand(1).getValueType());
break;

#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
Expand Down
23 changes: 14 additions & 9 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FADD, VT, Custom);

if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
ISD::PARTIAL_REDUCE_UMLA};

setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
}

} else /* !isNeonAvailable */ {
Expand Down Expand Up @@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
// Other pairs will default to 'Expand'.
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
ISD::PARTIAL_REDUCE_UMLA};
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal);

setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);

// Wide add types
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
}
}

Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1573,11 +1573,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

// zve32x is broken for partial_reduce_umla, but let's not make it worse.
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
ISD::PARTIAL_REDUCE_UMLA};
setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom);
setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom);

if (Subtarget.useRVVForFixedLengthVectors()) {
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
Expand All @@ -1586,7 +1588,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
continue;
ElementCount EC = VT.getVectorElementCount();
MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
setPartialReduceMLAAction(VT, ArgVT, Custom);
setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom);
}
}
}
Expand Down