@@ -1128,7 +1128,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1128
1128
setTargetDAGCombine(
1129
1129
{ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
1130
1130
1131
- setTargetDAGCombine({ISD::PARTIAL_REDUCE_SADD , ISD::PARTIAL_REDUCE_UADD });
1131
+ setTargetDAGCombine({ISD::PARTIAL_REDUCE_SMLA , ISD::PARTIAL_REDUCE_UMLA });
1132
1132
1133
1133
setTargetDAGCombine(ISD::FP_EXTEND);
1134
1134
@@ -1848,14 +1848,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1848
1848
}
1849
1849
1850
1850
for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
1851
- setOperationAction(ISD::PARTIAL_REDUCE_UADD , VT, Custom);
1852
- setOperationAction(ISD::PARTIAL_REDUCE_SADD , VT, Custom);
1851
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA , VT, Custom);
1852
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA , VT, Custom);
1853
1853
}
1854
1854
}
1855
1855
1856
1856
for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
1857
- setOperationAction(ISD::PARTIAL_REDUCE_UADD , VT, Custom);
1858
- setOperationAction(ISD::PARTIAL_REDUCE_SADD , VT, Custom);
1857
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA , VT, Custom);
1858
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA , VT, Custom);
1859
1859
}
1860
1860
1861
1861
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
@@ -7669,9 +7669,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
7669
7669
return LowerFLDEXP(Op, DAG);
7670
7670
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
7671
7671
return LowerVECTOR_HISTOGRAM(Op, DAG);
7672
- case ISD::PARTIAL_REDUCE_UADD :
7673
- case ISD::PARTIAL_REDUCE_SADD :
7674
- return LowerPARTIAL_REDUCE_ADD (Op, DAG);
7672
+ case ISD::PARTIAL_REDUCE_UMLA :
7673
+ case ISD::PARTIAL_REDUCE_SMLA :
7674
+ return LowerPARTIAL_REDUCE_MLA (Op, DAG);
7675
7675
}
7676
7676
}
7677
7677
@@ -22112,9 +22112,8 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22112
22112
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22113
22113
22114
22114
unsigned NewOpcode =
22115
- AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
22116
- auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
22117
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
22115
+ AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22116
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
22118
22117
}
22119
22118
22120
22119
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
@@ -22136,9 +22135,10 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22136
22135
return SDValue();
22137
22136
22138
22137
unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22139
- ? ISD::PARTIAL_REDUCE_SADD
22140
- : ISD::PARTIAL_REDUCE_UADD;
22141
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
22138
+ ? ISD::PARTIAL_REDUCE_SMLA
22139
+ : ISD::PARTIAL_REDUCE_UMLA;
22140
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
22141
+ DAG.getConstant(1, DL, InputVT));
22142
22142
}
22143
22143
22144
22144
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
@@ -26599,8 +26599,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
26599
26599
case ISD::MSCATTER:
26600
26600
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
26601
26601
return performMaskedGatherScatterCombine(N, DCI, DAG);
26602
- case ISD::PARTIAL_REDUCE_UADD :
26603
- case ISD::PARTIAL_REDUCE_SADD :
26602
+ case ISD::PARTIAL_REDUCE_UMLA :
26603
+ case ISD::PARTIAL_REDUCE_SMLA :
26604
26604
return performPartialReduceAddCombine(N, DAG, Subtarget);
26605
26605
case ISD::FP_EXTEND:
26606
26606
return performFPExtendCombine(N, DAG, DCI, Subtarget);
@@ -29372,39 +29372,29 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29372
29372
}
29373
29373
29374
29374
SDValue
29375
- AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD (SDValue Op,
29375
+ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA (SDValue Op,
29376
29376
SelectionDAG &DAG) const {
29377
29377
SDLoc DL(Op);
29378
29378
SDValue Acc = Op.getOperand(0);
29379
- SDValue Input = Op.getOperand(1);
29379
+ SDValue Input1 = Op.getOperand(1);
29380
+ SDValue Input2 = Op.getOperand(2);
29380
29381
29381
29382
EVT AccVT = Acc.getValueType();
29382
- EVT InputVT = Input .getValueType();
29383
+ EVT InputVT = Input1 .getValueType();
29383
29384
29384
29385
unsigned Opcode = Op.getOpcode();
29385
29386
29386
- // If the following condition is true and the input opcode was not ISD::MUL
29387
- // during the DAG-combine, it is already expanded. So this condition means the
29388
- // input opcode must have been ISD::MUL.
29389
29387
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
29390
- unsigned IndexAdd = 0;
29391
- // ISD::MUL may have already been lowered, meaning the operands would be in
29392
- // different positions.
29393
- if (Input.getOpcode() != ISD::MUL)
29394
- IndexAdd = 1;
29395
- auto A = Input.getOperand(IndexAdd);
29396
- auto B = Input.getOperand(IndexAdd + 1);
29397
-
29398
- unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
29388
+ unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
29399
29389
: AArch64ISD::UDOT;
29400
- return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B );
29390
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2 );
29401
29391
}
29402
- bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD ;
29392
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA ;
29403
29393
unsigned BottomOpcode =
29404
29394
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
29405
29395
unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
29406
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input );
29407
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input );
29396
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1 );
29397
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1 );
29408
29398
}
29409
29399
29410
29400
SDValue
0 commit comments