Skip to content

Commit 415a8dc

Browse files
JamesChesterman authored and NickGuy-Arm committed
[SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA
Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes. This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions when the new ISD nodes are used.
1 parent 364835d commit 415a8dc

File tree

3 files changed

+41
-83
lines changed

3 files changed

+41
-83
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1664,6 +1664,12 @@ class TargetLoweringBase {
16641664
getPartialReduceMLAAction(AccVT, InputVT) == Custom;
16651665
}
16661666

1667+
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
1668+
/// legal for this target.
1669+
bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
1670+
return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
1671+
}
1672+
16671673
/// If the action for this operation is to promote, this method returns the
16681674
/// ValueType to promote to.
16691675
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3220,8 +3220,26 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
32203220
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
32213221
SDValue &Hi) {
32223222
SDLoc DL(N);
3223-
SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
3224-
std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
3223+
SDValue Acc = N->getOperand(0);
3224+
SDValue Input1 = N->getOperand(1);
3225+
3226+
// If the node has not gone through the DAG combine, then do not attempt to
3227+
// legalise, just expand.
3228+
if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
3229+
SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
3230+
std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
3231+
return;
3232+
}
3233+
3234+
SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
3235+
std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
3236+
std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(Input1, DL);
3237+
std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
3238+
unsigned Opcode = N->getOpcode();
3239+
EVT ResultVT = AccLo.getValueType();
3240+
3241+
Lo = DAG.getNode(Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
3242+
Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
32253243
}
32263244

32273245
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4501,7 +4519,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
45014519
}
45024520

45034521
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
4504-
return TLI.expandPartialReduceMLA(N, DAG);
4522+
SDValue Lo, Hi;
4523+
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
4524+
return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), N->getValueType(0), Lo, Hi);
45054525
}
45064526

45074527
//===----------------------------------------------------------------------===//

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 12 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -210,42 +210,8 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
210210
; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z2.b
211211
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
212212
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
213-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
214-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
215-
; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z5.h
216-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
217-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
218-
; CHECK-NEWLOWERING-NEXT: uunpkhi z24.s, z3.h
219-
; CHECK-NEWLOWERING-NEXT: uunpkhi z25.s, z2.h
220-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z3.h
221-
; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z2.h
222-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
223-
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
224-
; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
225-
; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
226-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z4.s
227-
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z5.s
228-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
229-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
230-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z26.d
231-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z24.s
232-
; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
233-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z6.d
234-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z25.s
235-
; CHECK-NEWLOWERING-NEXT: uunpklo z7.d, z3.s
236-
; CHECK-NEWLOWERING-NEXT: mul z27.d, z29.d, z28.d
237-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z2.s
238-
; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
239-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
240-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
241-
; CHECK-NEWLOWERING-NEXT: mul z4.d, z5.d, z4.d
242-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z26.d
243-
; CHECK-NEWLOWERING-NEXT: movprfx z5, z27
244-
; CHECK-NEWLOWERING-NEXT: mla z5.d, p0/m, z28.d, z7.d
245-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z25.d, z24.d
246-
; CHECK-NEWLOWERING-NEXT: mad z2.d, p0/m, z3.d, z4.d
247-
; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
248-
; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
213+
; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
214+
; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
249215
; CHECK-NEWLOWERING-NEXT: ret
250216
entry:
251217
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -273,42 +239,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
273239
; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z2.b
274240
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
275241
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
276-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
277-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
278-
; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z5.h
279-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
280-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
281-
; CHECK-NEWLOWERING-NEXT: sunpkhi z24.s, z3.h
282-
; CHECK-NEWLOWERING-NEXT: sunpkhi z25.s, z2.h
283-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z3.h
284-
; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z2.h
285-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
286-
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
287-
; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
288-
; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
289-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z4.s
290-
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z5.s
291-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
292-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
293-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z26.d
294-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z24.s
295-
; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
296-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z6.d
297-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z25.s
298-
; CHECK-NEWLOWERING-NEXT: sunpklo z7.d, z3.s
299-
; CHECK-NEWLOWERING-NEXT: mul z27.d, z29.d, z28.d
300-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z2.s
301-
; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
302-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
303-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
304-
; CHECK-NEWLOWERING-NEXT: mul z4.d, z5.d, z4.d
305-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z26.d
306-
; CHECK-NEWLOWERING-NEXT: movprfx z5, z27
307-
; CHECK-NEWLOWERING-NEXT: mla z5.d, p0/m, z28.d, z7.d
308-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z25.d, z24.d
309-
; CHECK-NEWLOWERING-NEXT: mad z2.d, p0/m, z3.d, z4.d
310-
; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
311-
; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
242+
; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
243+
; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
312244
; CHECK-NEWLOWERING-NEXT: ret
313245
entry:
314246
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -790,11 +722,11 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
790722
; CHECK-NEWLOWERING-NEXT: and z1.h, z1.h, #0xff
791723
; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
792724
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
793-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
794-
; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z1.h
795-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
725+
; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
726+
; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
796727
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
797-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z4.s, z3.s
728+
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
729+
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
798730
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
799731
; CHECK-NEWLOWERING-NEXT: ret
800732
entry:
@@ -824,11 +756,11 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
824756
; CHECK-NEWLOWERING-NEXT: and z1.s, z1.s, #0xffff
825757
; CHECK-NEWLOWERING-NEXT: and z2.s, z2.s, #0xffff
826758
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
827-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
828-
; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z1.s
829-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
759+
; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z1.s
760+
; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z2.s
830761
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
831-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z4.d, z3.d
762+
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
763+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z3.d, z4.d
832764
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
833765
; CHECK-NEWLOWERING-NEXT: ret
834766
entry:

0 commit comments

Comments
 (0)