@@ -1124,7 +1124,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1124
1124
setTargetDAGCombine(
1125
1125
{ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
1126
1126
1127
- setTargetDAGCombine({ISD::PARTIAL_REDUCE_SADD , ISD::PARTIAL_REDUCE_UADD });
1127
+ setTargetDAGCombine({ISD::PARTIAL_REDUCE_SMLA , ISD::PARTIAL_REDUCE_UMLA });
1128
1128
1129
1129
setTargetDAGCombine(ISD::FP_EXTEND);
1130
1130
@@ -1842,14 +1842,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1842
1842
}
1843
1843
1844
1844
for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
1845
- setOperationAction(ISD::PARTIAL_REDUCE_UADD , VT, Custom);
1846
- setOperationAction(ISD::PARTIAL_REDUCE_SADD , VT, Custom);
1845
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA , VT, Custom);
1846
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA , VT, Custom);
1847
1847
}
1848
1848
}
1849
1849
1850
1850
for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
1851
- setOperationAction(ISD::PARTIAL_REDUCE_UADD , VT, Custom);
1852
- setOperationAction(ISD::PARTIAL_REDUCE_SADD , VT, Custom);
1851
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA , VT, Custom);
1852
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA , VT, Custom);
1853
1853
}
1854
1854
1855
1855
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
@@ -7606,9 +7606,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
7606
7606
return LowerFLDEXP(Op, DAG);
7607
7607
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
7608
7608
return LowerVECTOR_HISTOGRAM(Op, DAG);
7609
- case ISD::PARTIAL_REDUCE_UADD :
7610
- case ISD::PARTIAL_REDUCE_SADD :
7611
- return LowerPARTIAL_REDUCE_ADD (Op, DAG);
7609
+ case ISD::PARTIAL_REDUCE_UMLA :
7610
+ case ISD::PARTIAL_REDUCE_SMLA :
7611
+ return LowerPARTIAL_REDUCE_MLA (Op, DAG);
7612
7612
}
7613
7613
}
7614
7614
@@ -22070,9 +22070,8 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22070
22070
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22071
22071
22072
22072
unsigned NewOpcode =
22073
- AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
22074
- auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
22075
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
22073
+ AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22074
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
22076
22075
}
22077
22076
22078
22077
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
@@ -22094,9 +22093,10 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22094
22093
return SDValue();
22095
22094
22096
22095
unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22097
- ? ISD::PARTIAL_REDUCE_SADD
22098
- : ISD::PARTIAL_REDUCE_UADD;
22099
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
22096
+ ? ISD::PARTIAL_REDUCE_SMLA
22097
+ : ISD::PARTIAL_REDUCE_UMLA;
22098
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
22099
+ DAG.getConstant(1, DL, InputVT));
22100
22100
}
22101
22101
22102
22102
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
@@ -26412,8 +26412,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
26412
26412
case ISD::MSCATTER:
26413
26413
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
26414
26414
return performMaskedGatherScatterCombine(N, DCI, DAG);
26415
- case ISD::PARTIAL_REDUCE_UADD :
26416
- case ISD::PARTIAL_REDUCE_SADD :
26415
+ case ISD::PARTIAL_REDUCE_UMLA :
26416
+ case ISD::PARTIAL_REDUCE_SMLA :
26417
26417
return performPartialReduceAddCombine(N, DAG, Subtarget);
26418
26418
case ISD::FP_EXTEND:
26419
26419
return performFPExtendCombine(N, DAG, DCI, Subtarget);
@@ -29176,39 +29176,29 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29176
29176
}
29177
29177
29178
29178
SDValue
29179
- AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD (SDValue Op,
29179
+ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA (SDValue Op,
29180
29180
SelectionDAG &DAG) const {
29181
29181
SDLoc DL(Op);
29182
29182
SDValue Acc = Op.getOperand(0);
29183
- SDValue Input = Op.getOperand(1);
29183
+ SDValue Input1 = Op.getOperand(1);
29184
+ SDValue Input2 = Op.getOperand(2);
29184
29185
29185
29186
EVT AccVT = Acc.getValueType();
29186
- EVT InputVT = Input .getValueType();
29187
+ EVT InputVT = Input1 .getValueType();
29187
29188
29188
29189
unsigned Opcode = Op.getOpcode();
29189
29190
29190
- // If the following condition is true and the input opcode was not ISD::MUL
29191
- // during the DAG-combine, it is already expanded. So this condition means the
29192
- // input opcode must have been ISD::MUL.
29193
29191
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
29194
- unsigned IndexAdd = 0;
29195
- // ISD::MUL may have already been lowered, meaning the operands would be in
29196
- // different positions.
29197
- if (Input.getOpcode() != ISD::MUL)
29198
- IndexAdd = 1;
29199
- auto A = Input.getOperand(IndexAdd);
29200
- auto B = Input.getOperand(IndexAdd + 1);
29201
-
29202
- unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
29192
+ unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
29203
29193
: AArch64ISD::UDOT;
29204
- return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B );
29194
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2 );
29205
29195
}
29206
- bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD ;
29196
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA ;
29207
29197
unsigned BottomOpcode =
29208
29198
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
29209
29199
unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
29210
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input );
29211
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input );
29200
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1 );
29201
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1 );
29212
29202
}
29213
29203
29214
29204
SDValue
0 commit comments