Skip to content

Commit 1651aa2

Browse files
authored
[SDAG] Split the partial reduce legalize table by opcode [nfc] (#141970)
On its own, this change should be non-functional. This is a preparatory change for #141267, which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.
1 parent beb6972 commit 1651aa2

File tree

5 files changed

+61
-43
lines changed

5 files changed

+61
-43
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,17 +1659,20 @@ class LLVM_ABI TargetLoweringBase {
16591659
/// InputVT should be treated. Either it's legal, needs to be promoted to a
16601660
/// larger size, needs to be expanded to some other code sequence, or the
16611661
/// target has a custom expander for it.
1662-
LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
1663-
PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
1664-
InputVT.getSimpleVT().SimpleTy};
1665-
auto It = PartialReduceMLAActions.find(TypePair);
1662+
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
1663+
EVT InputVT) const {
1664+
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
1665+
PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
1666+
InputVT.getSimpleVT().SimpleTy};
1667+
auto It = PartialReduceMLAActions.find(Key);
16661668
return It != PartialReduceMLAActions.end() ? It->second : Expand;
16671669
}
16681670

16691671
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
16701672
/// legal or custom for this target.
1671-
bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
1672-
LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
1673+
bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT,
1674+
EVT InputVT) const {
1675+
LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT);
16731676
return Action == Legal || Action == Custom;
16741677
}
16751678

@@ -2754,12 +2757,18 @@ class LLVM_ABI TargetLoweringBase {
27542757
/// type InputVT should be treated by the target. Either it's legal, needs to
27552758
/// be promoted to a larger size, needs to be expanded to some other code
27562759
/// sequence, or the target has a custom expander for it.
2757-
void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
2760+
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
27582761
LegalizeAction Action) {
2762+
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
27592763
assert(AccVT.isValid() && InputVT.isValid() &&
27602764
"setPartialReduceMLAAction types aren't valid");
2761-
PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
2762-
PartialReduceMLAActions[TypePair] = Action;
2765+
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
2766+
PartialReduceMLAActions[Key] = Action;
2767+
}
2768+
void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT,
2769+
MVT InputVT, LegalizeAction Action) {
2770+
for (unsigned Opc : Opcodes)
2771+
setPartialReduceMLAAction(Opc, AccVT, InputVT, Action);
27632772
}
27642773

27652774
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
@@ -3751,10 +3760,10 @@ class LLVM_ABI TargetLoweringBase {
37513760
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
37523761

37533762
using PartialReduceActionTypes =
3754-
std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
3755-
/// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
3756-
/// nodes, keep a LegalizeAction which indicates how instruction selection
3757-
/// should deal with this operation.
3763+
std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>;
3764+
/// For each partial reduce opcode, result type and input type combination,
3765+
/// keep a LegalizeAction which indicates how instruction selection should
3766+
/// deal with this operation.
37583767
DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;
37593768

37603769
ValueTypeActionImpl ValueTypeActions;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1267312673
SDValue LHSExtOp = LHS->getOperand(0);
1267412674
EVT LHSExtOpVT = LHSExtOp.getValueType();
1267512675

12676+
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12677+
unsigned NewOpcode =
12678+
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12679+
1267612680
// Only perform these combines if the target supports folding
1267712681
// the extends into the operation.
1267812682
if (!TLI.isPartialReduceMLALegalOrCustom(
12679-
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12683+
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
1268012684
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
1268112685
return SDValue();
1268212686

12683-
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12684-
unsigned NewOpcode =
12685-
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12686-
1268712687
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1268812688
// -> partial_reduce_*mla(acc, x, C)
1268912689
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
@@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1273712737
if (!ISD::isExtOpcode(Op1Opcode))
1273812738
return SDValue();
1273912739

12740-
SDValue UnextOp1 = Op1.getOperand(0);
12741-
EVT UnextOp1VT = UnextOp1.getValueType();
12742-
auto *Context = DAG.getContext();
12743-
if (!TLI.isPartialReduceMLALegalOrCustom(
12744-
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12745-
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
12746-
return SDValue();
12747-
1274812740
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
1274912741
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
1275012742
EVT AccElemVT = Acc.getValueType().getVectorElementType();
@@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1275412746

1275512747
unsigned NewOpcode =
1275612748
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12749+
12750+
SDValue UnextOp1 = Op1.getOperand(0);
12751+
EVT UnextOp1VT = UnextOp1.getValueType();
12752+
auto *Context = DAG.getContext();
12753+
if (!TLI.isPartialReduceMLALegalOrCustom(
12754+
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12755+
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
12756+
return SDValue();
12757+
1275712758
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
1275812759
DAG.getConstant(1, DL, UnextOp1VT));
1275912760
}

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
530530
}
531531
case ISD::PARTIAL_REDUCE_UMLA:
532532
case ISD::PARTIAL_REDUCE_SMLA:
533-
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
534-
Node->getOperand(1).getValueType());
533+
Action =
534+
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
535+
Node->getOperand(1).getValueType());
535536
break;
536537

537538
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14581458
setOperationAction(ISD::FADD, VT, Custom);
14591459

14601460
if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
1461-
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1462-
setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
1463-
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1461+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1462+
ISD::PARTIAL_REDUCE_UMLA};
1463+
1464+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
1465+
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
1466+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
14641467
}
14651468

14661469
} else /* !isNeonAvailable */ {
@@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18811884
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
18821885
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
18831886
// Other pairs will default to 'Expand'.
1884-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
1885-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
1887+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1888+
ISD::PARTIAL_REDUCE_UMLA};
1889+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal);
1890+
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal);
18861891

1887-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
1892+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
18881893

18891894
// Wide add types
18901895
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
1891-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
1892-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
1893-
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
1896+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
1897+
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
1898+
setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
18941899
}
18951900
}
18961901

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,11 +1575,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15751575

15761576
// zve32x is broken for partial_reduce_umla, but let's not make it worse.
15771577
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
1578-
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
1579-
setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
1580-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
1581-
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
1582-
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
1578+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1579+
ISD::PARTIAL_REDUCE_UMLA};
1580+
setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
1581+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
1582+
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
1583+
setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom);
1584+
setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom);
15831585

15841586
if (Subtarget.useRVVForFixedLengthVectors()) {
15851587
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
@@ -1588,7 +1590,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15881590
continue;
15891591
ElementCount EC = VT.getVectorElementCount();
15901592
MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
1591-
setPartialReduceMLAAction(VT, ArgVT, Custom);
1593+
setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom);
15921594
}
15931595
}
15941596
}

0 commit comments

Comments
 (0)