Skip to content

Commit a27a811

Browse files
JamesChesterman authored and NickGuy-Arm committed
[SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA
Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. With this change, the udot_8to64 and sdot_8to64 tests generate dot-product instructions when the new ISD nodes are used.
1 parent a1f369e commit a27a811

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,12 @@ class TargetLoweringBase {
16681668
return Action == Legal || Action == Custom;
16691669
}
16701670

1671+
/// Return true if the action for a PARTIAL_REDUCE_U/SMLA node with accumulator
1672+
/// type \p AccVT and input type \p InputVT is Legal for this target. Only the
/// Legal action is accepted; any other action makes this return false.
1673+
bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
1674+
return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
1675+
}
1676+
16711677
/// If the action for this operation is to promote, this method returns the
16721678
/// ValueType to promote to.
16731679
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3220,8 +3220,26 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
32203220
// Split the result of a PARTIAL_REDUCE_MLA node \p N into \p Lo and \p Hi.
// Operand 0 is treated as the accumulator and operands 1 and 2 as the inputs.
// When the (accumulator, input) type pair is not legal for the target, the
// node is expanded and the expansion is split; otherwise each operand is split
// and one node with the same opcode is built per half.
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
32213221
SDValue &Hi) {
32223222
SDLoc DL(N);
3223-
SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
3224-
std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
3223+
SDValue Acc = N->getOperand(0);
3224+
SDValue Input1 = N->getOperand(1);
3225+
3226+
// If the target does not report this accumulator/input type pair as legal
3227+
// (e.g. the node has not gone through the DAG combine), do not attempt to
// legalise; just expand and split the expanded value.
3228+
if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
3229+
SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
3230+
std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
3231+
return;
3232+
}
3233+
3234+
// Split the accumulator and both inputs into low and high halves.
SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
3235+
std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
3236+
std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(Input1, DL);
3237+
std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
3238+
unsigned Opcode = N->getOpcode();
3239+
// Both halves use the low accumulator half's type as their result type.
EVT ResultVT = AccLo.getValueType();
3240+
3241+
Lo = DAG.getNode(Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
3242+
Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
32253243
}
32263244

32273245
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4501,7 +4519,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
45014519
}
45024520

45034521
// Split the operands of a PARTIAL_REDUCE_MLA node \p N by reusing the
// result-splitting routine to obtain Lo/Hi halves and concatenating them back
// into the node's original result type.
// NOTE(review): SplitVecRes_PARTIAL_REDUCE_MLA also splits the accumulator
// operand; confirm that is intended on the operand-splitting path.
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
4504-
return TLI.expandPartialReduceMLA(N, DAG);
4522+
SDValue Lo, Hi;
4523+
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
4524+
return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), N->getValueType(0), Lo, Hi);
45054525
}
45064526

45074527
//===----------------------------------------------------------------------===//

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
259259
; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
260260
; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
261261
; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
262+
; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
263+
; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
262264
; CHECK-NEWLOWERING-NEXT: ret
263265
entry:
264266
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -293,6 +295,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
293295
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
294296
; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
295297
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
298+
; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
299+
; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
296300
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
297301
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
298302
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h

0 commit comments

Comments
 (0)