Skip to content

Commit 7923fdc

Browse files
SamTebbs33 authored and tmsri committed
[AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot (llvm#107566)
This PR adds lowering for partial reductions of a mix of sign/zero extended inputs to the usdot intrinsic.
1 parent def587d commit 7923fdc

File tree

6 files changed

+304
-39
lines changed

6 files changed

+304
-39
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2701,6 +2701,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
27012701
MAKE_CASE(AArch64ISD::SADDLV)
27022702
MAKE_CASE(AArch64ISD::SDOT)
27032703
MAKE_CASE(AArch64ISD::UDOT)
2704+
MAKE_CASE(AArch64ISD::USDOT)
27042705
MAKE_CASE(AArch64ISD::SMINV)
27052706
MAKE_CASE(AArch64ISD::UMINV)
27062707
MAKE_CASE(AArch64ISD::SMAXV)
@@ -6114,6 +6115,11 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
61146115
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
61156116
Op.getOperand(2), Op.getOperand(3));
61166117
}
6118+
case Intrinsic::aarch64_neon_usdot:
6119+
case Intrinsic::aarch64_sve_usdot: {
6120+
return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(),
6121+
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
6122+
}
61176123
case Intrinsic::get_active_lane_mask: {
61186124
SDValue ID =
61196125
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
@@ -21849,37 +21855,50 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2184921855

2185021856
auto ExtA = MulOp->getOperand(0);
2185121857
auto ExtB = MulOp->getOperand(1);
21852-
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
21853-
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
21854-
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
21858+
21859+
if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
21860+
!ISD::isExtOpcode(ExtB->getOpcode()))
2185521861
return SDValue();
21862+
bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
21863+
bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
2185621864

2185721865
auto A = ExtA->getOperand(0);
2185821866
auto B = ExtB->getOperand(0);
2185921867
if (A.getValueType() != B.getValueType())
2186021868
return SDValue();
2186121869

21862-
unsigned Opcode = 0;
21863-
21864-
if (IsSExt)
21865-
Opcode = AArch64ISD::SDOT;
21866-
else if (IsZExt)
21867-
Opcode = AArch64ISD::UDOT;
21868-
21869-
assert(Opcode != 0 && "Unexpected dot product case encountered.");
21870-
2187121870
EVT ReducedType = N->getValueType(0);
2187221871
EVT MulSrcType = A.getValueType();
2187321872

2187421873
// Dot products operate on chunks of four elements so there must be four times
2187521874
// as many elements in the wide type
21876-
if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
21877-
(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
21878-
(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
21879-
(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21880-
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
21875+
if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
21876+
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
21877+
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
21878+
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21879+
return SDValue();
2188121880

21882-
return SDValue();
21881+
// If the extensions are mixed, we should lower it to a usdot instead
21882+
unsigned Opcode = 0;
21883+
if (AIsSigned != BIsSigned) {
21884+
if (!Subtarget->hasMatMulInt8())
21885+
return SDValue();
21886+
21887+
bool Scalable = N->getValueType(0).isScalableVT();
21888+
// There's no nxv2i64 version of usdot
21889+
if (Scalable && ReducedType != MVT::nxv4i32)
21890+
return SDValue();
21891+
21892+
Opcode = AArch64ISD::USDOT;
21893+
// USDOT expects the signed operand to be last
21894+
if (!BIsSigned)
21895+
std::swap(A, B);
21896+
} else if (AIsSigned)
21897+
Opcode = AArch64ISD::SDOT;
21898+
else
21899+
Opcode = AArch64ISD::UDOT;
21900+
21901+
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
2188321902
}
2188421903

2188521904
static SDValue performIntrinsicCombine(SDNode *N,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,10 @@ enum NodeType : unsigned {
280280
SADDLP,
281281
UADDLP,
282282

283-
// udot/sdot instructions
283+
// udot/sdot/usdot instructions
284284
UDOT,
285285
SDOT,
286+
USDOT,
286287

287288
// Vector across-lanes min/max
288289
// Only the lower result lane is defined.

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ def AArch64frsqrts : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>;
857857

858858
def AArch64sdot : SDNode<"AArch64ISD::SDOT", SDT_AArch64Dot>;
859859
def AArch64udot : SDNode<"AArch64ISD::UDOT", SDT_AArch64Dot>;
860+
def AArch64usdot : SDNode<"AArch64ISD::USDOT", SDT_AArch64Dot>;
860861

861862
def AArch64saddv : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
862863
def AArch64uaddv : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>;
@@ -1419,8 +1420,8 @@ let Predicates = [HasMatMulInt8] in {
14191420
def SMMLA : SIMDThreeSameVectorMatMul<0, 0, "smmla", int_aarch64_neon_smmla>;
14201421
def UMMLA : SIMDThreeSameVectorMatMul<0, 1, "ummla", int_aarch64_neon_ummla>;
14211422
def USMMLA : SIMDThreeSameVectorMatMul<1, 0, "usmmla", int_aarch64_neon_usmmla>;
1422-
defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", int_aarch64_neon_usdot>;
1423-
defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", int_aarch64_neon_usdot>;
1423+
defm USDOT : SIMDThreeSameVectorDot<0, 1, "usdot", AArch64usdot>;
1424+
defm USDOTlane : SIMDThreeSameVectorDotIndex<0, 1, 0b10, "usdot", AArch64usdot>;
14241425

14251426
// sudot lane has a pattern where usdot is expected (there is no sudot).
14261427
// The second operand is used in the dup operation to repeat the indexed
@@ -1432,7 +1433,7 @@ class BaseSIMDSUDOTIndex<bit Q, string dst_kind, string lhs_kind,
14321433
lhs_kind, rhs_kind, RegType, AccumType,
14331434
InputType, null_frag> {
14341435
let Pattern = [(set (AccumType RegType:$dst),
1435-
(AccumType (int_aarch64_neon_usdot (AccumType RegType:$Rd),
1436+
(AccumType (AArch64usdot (AccumType RegType:$Rd),
14361437
(InputType (bitconvert (AccumType
14371438
(AArch64duplane32 (v4i32 V128:$Rm),
14381439
VectorIndexS:$idx)))),

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3433,7 +3433,7 @@ let Predicates = [HasSVE, HasMatMulInt8] in {
34333433
} // End HasSVE, HasMatMulInt8
34343434

34353435
let Predicates = [HasSVEorSME, HasMatMulInt8] in {
3436-
defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", int_aarch64_sve_usdot>;
3436+
defm USDOT_ZZZ : sve_int_dot_mixed<"usdot", AArch64usdot>;
34373437
defm USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot", int_aarch64_sve_usdot_lane>;
34383438
defm SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot", int_aarch64_sve_sudot_lane>;
34393439
} // End HasSVEorSME, HasMatMulInt8

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
3-
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
2+
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
3+
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
4+
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
45

56
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
67
; CHECK-DOT-LABEL: udot:
@@ -102,7 +103,115 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
102103
ret <2 x i32> %partial.reduce
103104
}
104105

105-
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
106+
define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
107+
; CHECK-NOI8MM-LABEL: usdot:
108+
; CHECK-NOI8MM: // %bb.0:
109+
; CHECK-NOI8MM-NEXT: ushll v3.8h, v1.8b, #0
110+
; CHECK-NOI8MM-NEXT: ushll2 v1.8h, v1.16b, #0
111+
; CHECK-NOI8MM-NEXT: sshll v4.8h, v2.8b, #0
112+
; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
113+
; CHECK-NOI8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
114+
; CHECK-NOI8MM-NEXT: smull v5.4s, v2.4h, v1.4h
115+
; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
116+
; CHECK-NOI8MM-NEXT: smlal2 v5.4s, v4.8h, v3.8h
117+
; CHECK-NOI8MM-NEXT: add v0.4s, v5.4s, v0.4s
118+
; CHECK-NOI8MM-NEXT: ret
119+
;
120+
; CHECK-I8MM-LABEL: usdot:
121+
; CHECK-I8MM: // %bb.0:
122+
; CHECK-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
123+
; CHECK-I8MM-NEXT: ret
124+
%u.wide = zext <16 x i8> %u to <16 x i32>
125+
%s.wide = sext <16 x i8> %s to <16 x i32>
126+
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
127+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
128+
ret <4 x i32> %partial.reduce
129+
}
130+
131+
define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
132+
; CHECK-NOI8MM-LABEL: usdot_narrow:
133+
; CHECK-NOI8MM: // %bb.0:
134+
; CHECK-NOI8MM-NEXT: ushll v1.8h, v1.8b, #0
135+
; CHECK-NOI8MM-NEXT: sshll v2.8h, v2.8b, #0
136+
; CHECK-NOI8MM-NEXT: // kill: def $d0 killed $d0 def $q0
137+
; CHECK-NOI8MM-NEXT: smull v3.4s, v2.4h, v1.4h
138+
; CHECK-NOI8MM-NEXT: smull2 v4.4s, v2.8h, v1.8h
139+
; CHECK-NOI8MM-NEXT: ext v5.16b, v1.16b, v1.16b, #8
140+
; CHECK-NOI8MM-NEXT: ext v6.16b, v2.16b, v2.16b, #8
141+
; CHECK-NOI8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
142+
; CHECK-NOI8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
143+
; CHECK-NOI8MM-NEXT: ext v1.16b, v4.16b, v4.16b, #8
144+
; CHECK-NOI8MM-NEXT: smlal v3.4s, v6.4h, v5.4h
145+
; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
146+
; CHECK-NOI8MM-NEXT: add v0.2s, v3.2s, v0.2s
147+
; CHECK-NOI8MM-NEXT: ret
148+
;
149+
; CHECK-I8MM-LABEL: usdot_narrow:
150+
; CHECK-I8MM: // %bb.0:
151+
; CHECK-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
152+
; CHECK-I8MM-NEXT: ret
153+
%u.wide = zext <8 x i8> %u to <8 x i32>
154+
%s.wide = sext <8 x i8> %s to <8 x i32>
155+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
156+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
157+
ret <2 x i32> %partial.reduce
158+
}
159+
160+
define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
161+
; CHECK-NOI8MM-LABEL: sudot:
162+
; CHECK-NOI8MM: // %bb.0:
163+
; CHECK-NOI8MM-NEXT: sshll v3.8h, v1.8b, #0
164+
; CHECK-NOI8MM-NEXT: sshll2 v1.8h, v1.16b, #0
165+
; CHECK-NOI8MM-NEXT: ushll v4.8h, v2.8b, #0
166+
; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
167+
; CHECK-NOI8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
168+
; CHECK-NOI8MM-NEXT: smull v5.4s, v2.4h, v1.4h
169+
; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
170+
; CHECK-NOI8MM-NEXT: smlal2 v5.4s, v4.8h, v3.8h
171+
; CHECK-NOI8MM-NEXT: add v0.4s, v5.4s, v0.4s
172+
; CHECK-NOI8MM-NEXT: ret
173+
;
174+
; CHECK-I8MM-LABEL: sudot:
175+
; CHECK-I8MM: // %bb.0:
176+
; CHECK-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
177+
; CHECK-I8MM-NEXT: ret
178+
%u.wide = sext <16 x i8> %u to <16 x i32>
179+
%s.wide = zext <16 x i8> %s to <16 x i32>
180+
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
181+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
182+
ret <4 x i32> %partial.reduce
183+
}
184+
185+
define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
186+
; CHECK-NOI8MM-LABEL: sudot_narrow:
187+
; CHECK-NOI8MM: // %bb.0:
188+
; CHECK-NOI8MM-NEXT: sshll v1.8h, v1.8b, #0
189+
; CHECK-NOI8MM-NEXT: ushll v2.8h, v2.8b, #0
190+
; CHECK-NOI8MM-NEXT: // kill: def $d0 killed $d0 def $q0
191+
; CHECK-NOI8MM-NEXT: smull v3.4s, v2.4h, v1.4h
192+
; CHECK-NOI8MM-NEXT: smull2 v4.4s, v2.8h, v1.8h
193+
; CHECK-NOI8MM-NEXT: ext v5.16b, v1.16b, v1.16b, #8
194+
; CHECK-NOI8MM-NEXT: ext v6.16b, v2.16b, v2.16b, #8
195+
; CHECK-NOI8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
196+
; CHECK-NOI8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
197+
; CHECK-NOI8MM-NEXT: ext v1.16b, v4.16b, v4.16b, #8
198+
; CHECK-NOI8MM-NEXT: smlal v3.4s, v6.4h, v5.4h
199+
; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
200+
; CHECK-NOI8MM-NEXT: add v0.2s, v3.2s, v0.2s
201+
; CHECK-NOI8MM-NEXT: ret
202+
;
203+
; CHECK-I8MM-LABEL: sudot_narrow:
204+
; CHECK-I8MM: // %bb.0:
205+
; CHECK-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
206+
; CHECK-I8MM-NEXT: ret
207+
%u.wide = sext <8 x i8> %u to <8 x i32>
208+
%s.wide = zext <8 x i8> %s to <8 x i32>
209+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
210+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
211+
ret <2 x i32> %partial.reduce
212+
}
213+
214+
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
106215
; CHECK-LABEL: not_udot:
107216
; CHECK: // %bb.0:
108217
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b

0 commit comments

Comments
 (0)