Commit 832cd74
[AArch64] Armv8.6-a Matrix Mult Assembly + Intrinsics
This patch upstreams support for the Armv8.6-a Matrix Multiplication Extension. A summary of the features can be found here:
https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a

This patch includes:
- Assembly support for AArch64 only (no SVE or Neon)
- Intrinsics support for AArch64 Armv8.6-a Matrix Multiplication instructions (no bfloat16 matrix multiplication)

No IR types or C types are needed for this extension.

This is part of a patch series, starting with BFloat16 support and the other components of the Armv8.6-a extension (in previous patches linked in Phabricator).

Based on work by:
- Luke Geeson
- Oliver Stannard
- Luke Cheeseman

Reviewers: ostannard, t.p.northover, rengolin, kmclaughlin

Reviewed By: kmclaughlin

Subscribers: kmclaughlin, kristof.beyls, hiraditya, danielkiss, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D77871
1 parent dc9cff1 commit 832cd74
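As a quick orientation, the following scalar model illustrates what the headline vmmlaq_s32 intrinsic computes. It is a sketch written against the Armv8.6-A SMMLA description, not code from this patch, and the function name smmla_ref is invented:

    #include <stdint.h>

    /* Scalar model of vmmlaq_s32(r, a, b): r, viewed as a 2x2 int32 matrix,
       accumulates the product of a (a 2x8 int8 matrix) with the transpose of
       b (also 2x8 int8). Sketch per the Armv8.6-A SMMLA description. */
    static void smmla_ref(int32_t r[2][2], const int8_t a[2][8],
                          const int8_t b[2][8]) {
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
          for (int k = 0; k < 8; ++k)
            r[i][j] += (int32_t)a[i][k] * (int32_t)b[j][k];
    }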

15 files changed (+561 -18 lines)

clang/include/clang/Basic/arm_neon.td

Lines changed: 33 additions & 1 deletion
@@ -221,6 +221,21 @@ def OP_FMLAL_LN_Hi : Op<(call "vfmlal_high", $p0, $p1,
 def OP_FMLSL_LN_Hi : Op<(call "vfmlsl_high", $p0, $p1,
                          (dup_typed $p1, (call "vget_lane", $p2, $p3)))>;
 
+def OP_USDOT_LN
+    : Op<(call "vusdot", $p0, $p1,
+          (cast "8", "S", (call_mangled "splat_lane", (bitcast "int32x2_t", $p2), $p3)))>;
+def OP_USDOT_LNQ
+    : Op<(call "vusdot", $p0, $p1,
+          (cast "8", "S", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)))>;
+
+// sudot splats the second vector and then calls vusdot
+def OP_SUDOT_LN
+    : Op<(call "vusdot", $p0,
+          (cast "8", "U", (call_mangled "splat_lane", (bitcast "int32x2_t", $p2), $p3)), $p1)>;
+def OP_SUDOT_LNQ
+    : Op<(call "vusdot", $p0,
+          (cast "8", "U", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)), $p1)>;
+
 //===----------------------------------------------------------------------===//
 // Auxiliary Instructions
 //===----------------------------------------------------------------------===//

@@ -1792,6 +1807,23 @@ let ArchGuard = "defined(__ARM_FEATURE_FP16FML) && defined(__aarch64__)" in {
   }
 }
 
+let ArchGuard = "defined(__ARM_FEATURE_MATMUL_INT8)" in {
+  def VMMLA   : SInst<"vmmla", "..(<<)(<<)", "QUiQi">;
+  def VUSMMLA : SInst<"vusmmla", "..(<<U)(<<)", "Qi">;
+
+  def VUSDOT  : SInst<"vusdot", "..(<<U)(<<)", "iQi">;
+
+  def VUSDOT_LANE : SOpInst<"vusdot_lane", "..(<<U)(<<q)I", "iQi", OP_USDOT_LN>;
+  def VSUDOT_LANE : SOpInst<"vsudot_lane", "..(<<)(<<qU)I", "iQi", OP_SUDOT_LN>;
+
+  let ArchGuard = "defined(__aarch64__)" in {
+    let isLaneQ = 1 in {
+      def VUSDOT_LANEQ : SOpInst<"vusdot_laneq", "..(<<U)(<<Q)I", "iQi", OP_USDOT_LNQ>;
+      def VSUDOT_LANEQ : SOpInst<"vsudot_laneq", "..(<<)(<<QU)I", "iQi", OP_SUDOT_LNQ>;
+    }
+  }
+}
+
 // v8.3-A Vector complex addition intrinsics
 let ArchGuard = "defined(__ARM_FEATURE_COMPLEX) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)" in {
   def VCADD_ROT90_FP16 : SInst<"vcadd_rot90", "...", "h">;

@@ -1808,4 +1840,4 @@ let ArchGuard = "defined(__ARM_FEATURE_COMPLEX)" in {
 let ArchGuard = "defined(__ARM_FEATURE_COMPLEX) && defined(__aarch64__)" in {
   def VCADDQ_ROT90_FP64  : SInst<"vcaddq_rot90", "QQQ", "d">;
   def VCADDQ_ROT270_FP64 : SInst<"vcaddq_rot270", "QQQ", "d">;
-}
+}
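The OP_SUDOT_LN/OP_SUDOT_LNQ definitions above implement sudot by splatting the chosen lane of the unsigned vector and calling vusdot with the operand order swapped. A rough C equivalent for lane 0, using the intrinsics this patch adds (the helper name is invented for illustration):

    #include <arm_neon.h>

    /* Rough expansion of vsudot_lane_s32(r, a, b, 0) per OP_SUDOT_LN:
       splat lane 0 of b (viewed as 32-bit lanes), then pass the splat as
       the unsigned operand of vusdot. Illustrative sketch only. */
    int32x2_t sudot_lane0_expansion(int32x2_t r, int8x8_t a, uint8x8_t b) {
      uint32x2_t b32 = vreinterpret_u32_u8(b);
      uint8x8_t bsplat = vreinterpret_u8_u32(vdup_lane_u32(b32, 0));
      return vusdot_s32(r, bsplat, a);
    }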

clang/lib/Basic/Targets/AArch64.cpp

Lines changed: 6 additions & 0 deletions
@@ -280,6 +280,9 @@ void AArch64TargetInfo::getTargetDefines(const LangOptions &Opts,
   if (HasTME)
     Builder.defineMacro("__ARM_FEATURE_TME", "1");
 
+  if (HasMatMul)
+    Builder.defineMacro("__ARM_FEATURE_MATMUL_INT8", "1");
+
   if ((FPU & NeonMode) && HasFP16FML)
     Builder.defineMacro("__ARM_FEATURE_FP16FML", "1");

@@ -356,6 +359,7 @@ bool AArch64TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
   HasFP16FML = false;
   HasMTE = false;
   HasTME = false;
+  HasMatMul = false;
   ArchKind = llvm::AArch64::ArchKind::ARMV8A;
 
   for (const auto &Feature : Features) {

@@ -391,6 +395,8 @@ bool AArch64TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
       HasMTE = true;
     if (Feature == "+tme")
       HasTME = true;
+    if (Feature == "+i8mm")
+      HasMatMul = true;
   }
 
   setDataLayout();

clang/lib/Basic/Targets/AArch64.h

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY AArch64TargetInfo : public TargetInfo {
   bool HasFP16FML;
   bool HasMTE;
   bool HasTME;
+  bool HasMatMul;
 
   llvm::AArch64::ArchKind ArchKind;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 24 additions & 0 deletions
@@ -5009,6 +5009,7 @@ static const ARMVectorIntrinsicInfo AArch64SIMDIntrinsicMap[] = {
   NEONMAP1(vld1q_x2_v, aarch64_neon_ld1x2, 0),
   NEONMAP1(vld1q_x3_v, aarch64_neon_ld1x3, 0),
   NEONMAP1(vld1q_x4_v, aarch64_neon_ld1x4, 0),
+  NEONMAP2(vmmlaq_v, aarch64_neon_ummla, aarch64_neon_smmla, 0),
   NEONMAP0(vmovl_v),
   NEONMAP0(vmovn_v),
   NEONMAP1(vmul_v, aarch64_neon_pmul, Add1ArgType),

@@ -5091,6 +5092,9 @@ static const ARMVectorIntrinsicInfo AArch64SIMDIntrinsicMap[] = {
   NEONMAP0(vsubhn_v),
   NEONMAP0(vtst_v),
   NEONMAP0(vtstq_v),
+  NEONMAP1(vusdot_v, aarch64_neon_usdot, 0),
+  NEONMAP1(vusdotq_v, aarch64_neon_usdot, 0),
+  NEONMAP1(vusmmlaq_v, aarch64_neon_usmmla, 0),
 };
 
 static const ARMVectorIntrinsicInfo AArch64SISDIntrinsicMap[] = {

@@ -6076,6 +6080,26 @@ Value *CodeGenFunction::EmitCommonNeonBuiltinExpr(
     llvm::Type *Tys[2] = { Ty, InputTy };
     return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vfmlsl_high");
   }
+  case NEON::BI__builtin_neon_vmmlaq_v: {
+    llvm::Type *InputTy =
+        llvm::VectorType::get(Int8Ty, Ty->getPrimitiveSizeInBits() / 8);
+    llvm::Type *Tys[2] = { Ty, InputTy };
+    Int = Usgn ? LLVMIntrinsic : AltLLVMIntrinsic;
+    return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vmmla");
+  }
+  case NEON::BI__builtin_neon_vusmmlaq_v: {
+    llvm::Type *InputTy =
+        llvm::VectorType::get(Int8Ty, Ty->getPrimitiveSizeInBits() / 8);
+    llvm::Type *Tys[2] = { Ty, InputTy };
+    return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vusmmla");
+  }
+  case NEON::BI__builtin_neon_vusdot_v:
+  case NEON::BI__builtin_neon_vusdotq_v: {
+    llvm::Type *InputTy =
+        llvm::VectorType::get(Int8Ty, Ty->getPrimitiveSizeInBits() / 8);
+    llvm::Type *Tys[2] = { Ty, InputTy };
+    return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vusdot");
+  }
   }
 
   assert(Int && "Expected valid intrinsic number");
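Note how the vmmlaq_v case above resolves one builtin to two intrinsics: NEONMAP2 records the ummla/smmla pair, and Usgn selects between them from the builtin's type signature. In user code (expected IR per the tests below; the wrapper names are illustrative):

    #include <arm_neon.h>

    /* The same __builtin_neon_vmmlaq_v lowers by signedness. */
    uint32x4_t mmla_u(uint32x4_t r, uint8x16_t a, uint8x16_t b) {
      return vmmlaq_u32(r, a, b); /* -> @llvm.aarch64.neon.ummla.v4i32.v16i8 */
    }
    int32x4_t mmla_s(int32x4_t r, int8x16_t a, int8x16_t b) {
      return vmmlaq_s32(r, a, b); /* -> @llvm.aarch64.neon.smmla.v4i32.v16i8 */
    }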

clang/test/CodeGen/aarch64-matmul.cpp

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+// RUN: %clang_cc1 -triple aarch64-eabi -target-feature +neon -target-feature +i8mm -S -emit-llvm %s -o - | FileCheck %s
+
+#ifdef __ARM_FEATURE_MATMUL_INT8
+extern "C" void arm_feature_matmulint8_defined() {}
+#endif
+// CHECK: define void @arm_feature_matmulint8_defined()
+
+
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+// RUN: %clang_cc1 -triple arm64-none-linux-gnu -target-feature +neon -target-feature +fullfp16 -target-feature +v8.6a -target-feature +i8mm \
+// RUN: -fallow-half-arguments-and-returns -S -disable-O0-optnone -emit-llvm -o - %s \
+// RUN: | opt -S -mem2reg -sroa \
+// RUN: | FileCheck %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_neon.h>
+
+// CHECK-LABEL: test_vmmlaq_s32
+// CHECK: [[VAL:%.*]] = call <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
+// CHECK: ret <4 x i32> [[VAL]]
+int32x4_t test_vmmlaq_s32(int32x4_t r, int8x16_t a, int8x16_t b) {
+  return vmmlaq_s32(r, a, b);
+}
+
+// CHECK-LABEL: test_vmmlaq_u32
+// CHECK: [[VAL:%.*]] = call <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
+// CHECK: ret <4 x i32> [[VAL]]
+uint32x4_t test_vmmlaq_u32(uint32x4_t r, uint8x16_t a, uint8x16_t b) {
+  return vmmlaq_u32(r, a, b);
+}
+
+// CHECK-LABEL: test_vusmmlaq_s32
+// CHECK: [[VAL:%.*]] = call <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
+// CHECK: ret <4 x i32> [[VAL]]
+int32x4_t test_vusmmlaq_s32(int32x4_t r, uint8x16_t a, int8x16_t b) {
+  return vusmmlaq_s32(r, a, b);
+}
+
+// CHECK-LABEL: test_vusdot_s32
+// CHECK: [[VAL:%.*]] = call <2 x i32> @llvm.aarch64.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b)
+// CHECK: ret <2 x i32> [[VAL]]
+int32x2_t test_vusdot_s32(int32x2_t r, uint8x8_t a, int8x8_t b) {
+  return vusdot_s32(r, a, b);
+}
+
+// CHECK-LABEL: test_vusdot_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <2 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <2 x i32> [[LANE]] to <8 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <2 x i32> %r to <8 x i8>
+// CHECK: [[OP:%.*]] = call <2 x i32> @llvm.aarch64.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> [[TMP4]])
+// CHECK: ret <2 x i32> [[OP]]
+int32x2_t test_vusdot_lane_s32(int32x2_t r, uint8x8_t a, int8x8_t b) {
+  return vusdot_lane_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vsudot_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <2 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <2 x i32> [[LANE]] to <8 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <2 x i32> %r to <8 x i8>
+// CHECK: [[OP:%.*]] = call <2 x i32> @llvm.aarch64.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> [[TMP4]], <8 x i8> %a)
+// CHECK: ret <2 x i32> [[OP]]
+int32x2_t test_vsudot_lane_s32(int32x2_t r, int8x8_t a, uint8x8_t b) {
+  return vsudot_lane_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vusdot_laneq_s32
+// CHECK: [[TMP0:%.*]] = bitcast <16 x i8> %b to <4 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <4 x i32> [[TMP0]] to <16 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <4 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> [[TMP2]], <2 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <2 x i32> [[LANE]] to <8 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <2 x i32> %r to <8 x i8>
+// CHECK: [[OP:%.*]] = call <2 x i32> @llvm.aarch64.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> [[TMP4]])
+// CHECK: ret <2 x i32> [[OP]]
+int32x2_t test_vusdot_laneq_s32(int32x2_t r, uint8x8_t a, int8x16_t b) {
+  return vusdot_laneq_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vsudot_laneq_s32
+// CHECK: [[TMP0:%.*]] = bitcast <16 x i8> %b to <4 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <4 x i32> [[TMP0]] to <16 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <4 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> [[TMP2]], <2 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <2 x i32> [[LANE]] to <8 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <2 x i32> %r to <8 x i8>
+// CHECK: [[OP:%.*]] = call <2 x i32> @llvm.aarch64.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> [[TMP4]], <8 x i8> %a)
+// CHECK: ret <2 x i32> [[OP]]
+int32x2_t test_vsudot_laneq_s32(int32x2_t r, int8x8_t a, uint8x16_t b) {
+  return vsudot_laneq_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vusdotq_s32
+// CHECK: [[VAL:%.*]] = call <4 x i32> @llvm.aarch64.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
+// CHECK: ret <4 x i32> [[VAL]]
+int32x4_t test_vusdotq_s32(int32x4_t r, uint8x16_t a, int8x16_t b) {
+  return vusdotq_s32(r, a, b);
+}
+
+// CHECK-LABEL: test_vusdotq_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <4 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <4 x i32> [[LANE]] to <16 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <4 x i32> %r to <16 x i8>
+// CHECK: [[OP:%.*]] = call <4 x i32> @llvm.aarch64.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> [[TMP4]])
+// CHECK: ret <4 x i32> [[OP]]
+int32x4_t test_vusdotq_lane_s32(int32x4_t r, uint8x16_t a, int8x8_t b) {
+  return vusdotq_lane_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vsudotq_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <4 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <4 x i32> [[LANE]] to <16 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <4 x i32> %r to <16 x i8>
+// CHECK: [[OP:%.*]] = call <4 x i32> @llvm.aarch64.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> [[TMP4]], <16 x i8> %a)
+// CHECK: ret <4 x i32> [[OP]]
+int32x4_t test_vsudotq_lane_s32(int32x4_t r, int8x16_t a, uint8x8_t b) {
+  return vsudotq_lane_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vusdotq_laneq_s32
+// CHECK: [[TMP0:%.*]] = bitcast <16 x i8> %b to <4 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <4 x i32> [[TMP0]] to <16 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <4 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> [[TMP2]], <4 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <4 x i32> [[LANE]] to <16 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <4 x i32> %r to <16 x i8>
+// CHECK: [[OP:%.*]] = call <4 x i32> @llvm.aarch64.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> [[TMP4]])
+// CHECK: ret <4 x i32> [[OP]]
+int32x4_t test_vusdotq_laneq_s32(int32x4_t r, uint8x16_t a, int8x16_t b) {
+  return vusdotq_laneq_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vsudotq_laneq_s32
+// CHECK: [[TMP0:%.*]] = bitcast <16 x i8> %b to <4 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <4 x i32> [[TMP0]] to <16 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <4 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> [[TMP2]], <4 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <4 x i32> [[LANE]] to <16 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <4 x i32> %r to <16 x i8>
+// CHECK: [[OP:%.*]] = call <4 x i32> @llvm.aarch64.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> [[TMP4]], <16 x i8> %a)
+// CHECK: ret <4 x i32> [[OP]]
+int32x4_t test_vsudotq_laneq_s32(int32x4_t r, int8x16_t a, uint8x16_t b) {
+  return vsudotq_laneq_s32(r, a, b, 0);
+}
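For reference, the dot-product tests above follow the USDOT semantics: each 32-bit lane of the accumulator absorbs a 4-way dot product of unsigned bytes from a with signed bytes from b. A scalar model (a sketch from the Armv8.6-A USDOT description; usdot_ref is an invented name, not code from this patch):

    #include <stdint.h>

    /* Scalar model of vusdot_s32(r, a, b): per 32-bit lane i, accumulate
       the dot product of four unsigned bytes of a with four signed bytes
       of b. */
    static void usdot_ref(int32_t r[2], const uint8_t a[8], const int8_t b[8]) {
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 4; ++j)
          r[i] += (int32_t)a[4 * i + j] * (int32_t)b[4 * i + j];
    }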

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 11 additions & 0 deletions
@@ -173,6 +173,11 @@ let TargetPrefix = "aarch64" in { // All intrinsics start with "llvm.aarch64.".
     : Intrinsic<[llvm_anyvector_ty],
                 [LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
                 [IntrNoMem]>;
+
+  class AdvSIMD_MatMul_Intrinsic
+    : Intrinsic<[llvm_anyvector_ty],
+                [LLVMMatchType<0>, llvm_anyvector_ty, LLVMMatchType<1>],
+                [IntrNoMem]>;
 }

@@ -449,6 +454,12 @@ let TargetPrefix = "aarch64", IntrProperties = [IntrNoMem] in {
   def int_aarch64_neon_udot : AdvSIMD_Dot_Intrinsic;
   def int_aarch64_neon_sdot : AdvSIMD_Dot_Intrinsic;
 
+  // v8.6-A Matrix Multiply Intrinsics
+  def int_aarch64_neon_ummla : AdvSIMD_MatMul_Intrinsic;
+  def int_aarch64_neon_smmla : AdvSIMD_MatMul_Intrinsic;
+  def int_aarch64_neon_usmmla : AdvSIMD_MatMul_Intrinsic;
+  def int_aarch64_neon_usdot : AdvSIMD_Dot_Intrinsic;
+
   // v8.2-A FP16 Fused Multiply-Add Long
   def int_aarch64_neon_fmlal : AdvSIMD_FP16FML_Intrinsic;
   def int_aarch64_neon_fmlsl : AdvSIMD_FP16FML_Intrinsic;
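AdvSIMD_MatMul_Intrinsic mirrors the shape of AdvSIMD_Dot_Intrinsic: the result vector type and the twice-matched input vector type are both overloaded, so a single def covers both register widths. For example, the one int_aarch64_neon_usdot def serves both calls below (mangled names as seen in the tests above; the wrapper names are illustrative):

    #include <arm_neon.h>

    /* One overloaded intrinsic def, instantiated per width:
       64-bit:  @llvm.aarch64.neon.usdot.v2i32.v8i8
       128-bit: @llvm.aarch64.neon.usdot.v4i32.v16i8 */
    int32x2_t dot64(int32x2_t r, uint8x8_t a, int8x8_t b) {
      return vusdot_s32(r, a, b);
    }
    int32x4_t dot128(int32x4_t r, uint8x16_t a, int8x16_t b) {
      return vusdotq_s32(r, a, b);
    }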

llvm/lib/Target/AArch64/AArch64.td

Lines changed: 10 additions & 2 deletions
@@ -373,14 +373,22 @@ def FeatureTaggedGlobals : SubtargetFeature<"tagged-globals",
 def FeatureBF16 : SubtargetFeature<"bf16", "HasBF16",
     "true", "Enable BFloat16 Extension" >;
 
+def FeatureMatMulInt8 : SubtargetFeature<"i8mm", "HasMatMulInt8",
+    "true", "Enable Matrix Multiply Int8 Extension">;
+
+def FeatureMatMulFP32 : SubtargetFeature<"f32mm", "HasMatMulFP32",
+    "true", "Enable Matrix Multiply FP32 Extension", [FeatureSVE]>;
+
+def FeatureMatMulFP64 : SubtargetFeature<"f64mm", "HasMatMulFP64",
+    "true", "Enable Matrix Multiply FP64 Extension", [FeatureSVE]>;
+
 def FeatureFineGrainedTraps : SubtargetFeature<"fgt", "HasFineGrainedTraps",
     "true", "Enable fine grained virtualization traps extension">;
 
 def FeatureEnhancedCounterVirtualization :
     SubtargetFeature<"ecv", "HasEnhancedCounterVirtualization",
     "true", "Enable enhanced counter virtualization extension">;
 
-
 //===----------------------------------------------------------------------===//
 // Architectures.
 //

@@ -413,7 +421,7 @@ def HasV8_6aOps : SubtargetFeature<
   "v8.6a", "HasV8_6aOps", "true", "Support ARM v8.6a instructions",
 
   [HasV8_5aOps, FeatureAMVS, FeatureBF16, FeatureFineGrainedTraps,
-   FeatureEnhancedCounterVirtualization]>;
+   FeatureEnhancedCounterVirtualization, FeatureMatMulInt8]>;
 
 //===----------------------------------------------------------------------===//
 // Register File Description
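Since FeatureMatMulInt8 is now part of HasV8_6aOps, compiling with -march=armv8.6-a should enable the extension (and, via the AArch64.cpp change above, define the guard macro) without an explicit +i8mm modifier. A sketch of macro-gated usage under that assumption:

    /* Builds the accelerated path only when the target enables i8mm,
       e.g. under -march=armv8.6-a. Illustrative sketch, not from the patch. */
    #ifdef __ARM_FEATURE_MATMUL_INT8
    #include <arm_neon.h>
    int32x4_t mm_accumulate(int32x4_t acc, int8x16_t a, int8x16_t b) {
      return vmmlaq_s32(acc, a, b);
    }
    #endif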
