Skip to content

Commit b43db72

Browse files
Change from adding ISD::PARTIAL_REDUCE_S/UADD to adding ISD::PARTIAL_REDUCE_S/UMLA.
This makes the lowering function simpler, as it no longer needs to worry about whether the MUL has already been lowered: the MUL's two operands are passed directly as operands of the new node instead of the MUL itself. If there is no MUL instruction and just one input operand, the other operand is a vector of ones (for value types eligible for wide add lowering).
1 parent 9231804 commit b43db72

File tree

6 files changed

+34
-44
lines changed

6 files changed

+34
-44
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,8 +1455,8 @@ enum NodeType {
14551455
// unsigned).
14561456
// Operands: Accumulator, Input
14571457
// Outputs: Output
1458-
PARTIAL_REDUCE_SADD,
1459-
PARTIAL_REDUCE_UADD,
1458+
PARTIAL_REDUCE_SMLA,
1459+
PARTIAL_REDUCE_UMLA,
14601460

14611461
// The `llvm.experimental.stackmap` intrinsic.
14621462
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1602,7 +1602,7 @@ class SelectionDAG {
16021602
/// the target's desired shift amount type.
16031603
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
16041604

1605-
/// Expands PARTIAL_REDUCE_S/UADD nodes.
1605+
/// Expands PARTIAL_REDUCE_S/UMLA nodes.
16061606
/// \p Op1 Accumulator for where the result is stored for the partial
16071607
/// reduction operation.
16081608
/// \p Op2 Input for the partial reduction operation.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8142,7 +8142,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81428142

81438143
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
81448144
setValue(&I,
8145-
DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
8145+
DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input));
81468146
return;
81478147
}
81488148
setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,10 +567,10 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
567567
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
568568
return "histogram";
569569

570-
case ISD::PARTIAL_REDUCE_UADD:
571-
return "partial_reduce_uadd";
572-
case ISD::PARTIAL_REDUCE_SADD:
573-
return "partial_reduce_sadd";
570+
case ISD::PARTIAL_REDUCE_UMLA:
571+
return "partial_reduce_umla";
572+
case ISD::PARTIAL_REDUCE_SMLA:
573+
return "partial_reduce_smla";
574574

575575
// Vector Predication
576576
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11241124
setTargetDAGCombine(
11251125
{ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
11261126

1127-
setTargetDAGCombine({ISD::PARTIAL_REDUCE_SADD, ISD::PARTIAL_REDUCE_UADD});
1127+
setTargetDAGCombine({ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA});
11281128

11291129
setTargetDAGCombine(ISD::FP_EXTEND);
11301130

@@ -1842,14 +1842,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18421842
}
18431843

18441844
for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
1845-
setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
1846-
setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
1845+
setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Custom);
1846+
setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Custom);
18471847
}
18481848
}
18491849

18501850
for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
1851-
setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
1852-
setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
1851+
setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Custom);
1852+
setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Custom);
18531853
}
18541854

18551855
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
@@ -7606,9 +7606,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
76067606
return LowerFLDEXP(Op, DAG);
76077607
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
76087608
return LowerVECTOR_HISTOGRAM(Op, DAG);
7609-
case ISD::PARTIAL_REDUCE_UADD:
7610-
case ISD::PARTIAL_REDUCE_SADD:
7611-
return LowerPARTIAL_REDUCE_ADD(Op, DAG);
7609+
case ISD::PARTIAL_REDUCE_UMLA:
7610+
case ISD::PARTIAL_REDUCE_SMLA:
7611+
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
76127612
}
76137613
}
76147614

@@ -22070,9 +22070,8 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2207022070
return DAG.expandPartialReduceAdd(DL, Acc, Input);
2207122071

2207222072
unsigned NewOpcode =
22073-
AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
22074-
auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
22075-
return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
22073+
AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22074+
return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
2207622075
}
2207722076

2207822077
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
@@ -22094,9 +22093,10 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2209422093
return SDValue();
2209522094

2209622095
unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22097-
? ISD::PARTIAL_REDUCE_SADD
22098-
: ISD::PARTIAL_REDUCE_UADD;
22099-
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
22096+
? ISD::PARTIAL_REDUCE_SMLA
22097+
: ISD::PARTIAL_REDUCE_UMLA;
22098+
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
22099+
DAG.getConstant(1, DL, InputVT));
2210022100
}
2210122101

2210222102
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
@@ -26412,8 +26412,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2641226412
case ISD::MSCATTER:
2641326413
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
2641426414
return performMaskedGatherScatterCombine(N, DCI, DAG);
26415-
case ISD::PARTIAL_REDUCE_UADD:
26416-
case ISD::PARTIAL_REDUCE_SADD:
26415+
case ISD::PARTIAL_REDUCE_UMLA:
26416+
case ISD::PARTIAL_REDUCE_SMLA:
2641726417
return performPartialReduceAddCombine(N, DAG, Subtarget);
2641826418
case ISD::FP_EXTEND:
2641926419
return performFPExtendCombine(N, DAG, DCI, Subtarget);
@@ -29176,39 +29176,29 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2917629176
}
2917729177

2917829178
SDValue
29179-
AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
29179+
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2918029180
SelectionDAG &DAG) const {
2918129181
SDLoc DL(Op);
2918229182
SDValue Acc = Op.getOperand(0);
29183-
SDValue Input = Op.getOperand(1);
29183+
SDValue Input1 = Op.getOperand(1);
29184+
SDValue Input2 = Op.getOperand(2);
2918429185

2918529186
EVT AccVT = Acc.getValueType();
29186-
EVT InputVT = Input.getValueType();
29187+
EVT InputVT = Input1.getValueType();
2918729188

2918829189
unsigned Opcode = Op.getOpcode();
2918929190

29190-
// If the following condition is true and the input opcode was not ISD::MUL
29191-
// during the DAG-combine, it is already expanded. So this condition means the
29192-
// input opcode must have been ISD::MUL.
2919329191
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
29194-
unsigned IndexAdd = 0;
29195-
// ISD::MUL may have already been lowered, meaning the operands would be in
29196-
// different positions.
29197-
if (Input.getOpcode() != ISD::MUL)
29198-
IndexAdd = 1;
29199-
auto A = Input.getOperand(IndexAdd);
29200-
auto B = Input.getOperand(IndexAdd + 1);
29201-
29202-
unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
29192+
unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
2920329193
: AArch64ISD::UDOT;
29204-
return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
29194+
return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
2920529195
}
29206-
bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
29196+
bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
2920729197
unsigned BottomOpcode =
2920829198
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
2920929199
unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
29210-
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
29211-
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
29200+
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
29201+
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
2921229202
}
2921329203

2921429204
SDValue

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,7 @@ class AArch64TargetLowering : public TargetLowering {
11841184
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
11851185
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
11861186
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
1187-
SDValue LowerPARTIAL_REDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
1187+
SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
11881188
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
11891189
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
11901190
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)