Skip to content

Commit c217243

Browse files
[AArch64] Implements FP8 SVE intrinsics for dot-product (llvm#118125)
This patch adds the following intrinsics:

* 8-bit floating-point dot product to single-precision.

  // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT4) || __ARM_FEATURE_SSVE_FP8DOT4
  svfloat32_t svdot[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm);
  svfloat32_t svdot[_n_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm);

* 8-bit floating-point indexed dot product to single-precision.

  // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT4) || __ARM_FEATURE_SSVE_FP8DOT4
  svfloat32_t svdot_lane[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_3, fpm_t fpm);

* 8-bit floating-point dot product to half-precision.

  // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT2) || __ARM_FEATURE_SSVE_FP8DOT2
  svfloat16_t svdot[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm);
  svfloat16_t svdot[_n_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm);

* 8-bit floating-point indexed dot product to half-precision.

  // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8DOT2) || __ARM_FEATURE_SSVE_FP8DOT2
  svfloat16_t svdot_lane[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_7, fpm_t fpm);
1 parent 75e6d0e commit c217243

File tree

10 files changed

+287
-14
lines changed

10 files changed

+287
-14
lines changed

clang/include/clang/Basic/arm_sve.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,3 +2476,22 @@ let SVETargetGuard = "sve2,fp8", SMETargetGuard = "sme2,fp8" in {
24762476
def SVFCVTNB : SInst<"svcvtnb_mf8[_f32_x2]_fpm", "~2>", "f", MergeNone, "aarch64_sve_fp8_cvtnb", [VerifyRuntimeMode, SetsFPMR]>;
24772477
def SVFCVTNT : SInst<"svcvtnt_mf8[_f32_x2]_fpm", "~~2>", "f", MergeNone, "aarch64_sve_fp8_cvtnt", [VerifyRuntimeMode, SetsFPMR]>;
24782478
}
2479+
2480+
// FP8 dot product accumulating into half-precision ("h" => svfloat16_t).
// Guarded by "sve2,fp8dot2" for SVE targets and "sme,ssve-fp8dot2" for SME
// targets.
//
// Prototype strings, read against the C signatures of these intrinsics
// (return type first, then the arguments in order):
//   d  svfloat16_t here (result and accumulator zda)
//   ~  svmfloat8_t
//   !  mfloat8_t, splat to svmfloat8_t
//   i  constant uint64_t (the lane index)
//   >  fpm_t (always the last argument)
let SVETargetGuard = "sve2,fp8dot2", SMETargetGuard ="sme,ssve-fp8dot2" in {
2481+
// 8-bit floating-point dot product to half-precision (vectors)
// The vector form and the _n (scalar zm) form lower to the same LLVM
// intrinsic, aarch64_sve_fp8_fdot; only the prototype of zm differs.
2482+
def SVFDOT_2WAY : SInst<"svdot[_f16_mf8]_fpm", "dd~~>", "h", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
2483+
def SVFDOT_N_2WAY : SInst<"svdot[_n_f16_mf8]_fpm", "dd~!>", "h", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
2484+
2485+
// 8-bit floating-point dot product to half-precision (indexed)
// ImmCheck<3, ImmCheck0_7>: argument 3 (0-based; the lane index that follows
// zda, zn, zm) must be a constant in the range [0, 7].
2486+
def SVFDOT_LANE_2WAY : SInst<"svdot_lane[_f16_mf8]_fpm", "dd~~i>", "h", MergeNone, "aarch64_sve_fp8_fdot_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_7>]>;
2487+
}
2488+
2489+
// FP8 dot product accumulating into single-precision ("f" => svfloat32_t).
// Guarded by "sve2,fp8dot4" for SVE targets and "sme,ssve-fp8dot4" for SME
// targets.
//
// Prototype strings, read against the C signatures of these intrinsics
// (return type first, then the arguments in order):
//   d  svfloat32_t here (result and accumulator zda)
//   ~  svmfloat8_t
//   !  mfloat8_t, splat to svmfloat8_t
//   i  constant uint64_t (the lane index)
//   >  fpm_t (always the last argument)
let SVETargetGuard = "sve2,fp8dot4", SMETargetGuard ="sme,ssve-fp8dot4" in {
2490+
// 8-bit floating-point dot product to single-precision (vectors)
// The vector form and the _n (scalar zm) form lower to the same LLVM
// intrinsic, aarch64_sve_fp8_fdot; only the prototype of zm differs.
2491+
def SVFDOT_4WAY : SInst<"svdot[_f32_mf8]_fpm", "dd~~>", "f", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
2492+
def SVFDOT_N_4WAY : SInst<"svdot[_n_f32_mf8]_fpm", "dd~!>", "f", MergeNone, "aarch64_sve_fp8_fdot", [VerifyRuntimeMode, SetsFPMR]>;
2493+
2494+
// 8-bit floating-point dot product to single-precision (indexed)
// ImmCheck<3, ImmCheck0_3>: argument 3 (0-based; the lane index that follows
// zda, zn, zm) must be a constant in the range [0, 3].
2495+
def SVFDOT_LANE_4WAY : SInst<"svdot_lane[_f32_mf8]_fpm", "dd~~i>", "f", MergeNone, "aarch64_sve_fp8_fdot_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_3>]>;
2496+
}
2497+

clang/include/clang/Basic/arm_sve_sme_incl.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ include "arm_immcheck_incl.td"
8989
// j: element type promoted to 64bits (splat to vector type)
9090
// K: element type bitcast to a signed integer (splat to vector type)
9191
// L: element type bitcast to an unsigned integer (splat to vector type)
92+
// !: mfloat8_t (splat to svmfloat8_t)
9293
//
9394
// i: constant uint64_t
9495
// k: int32_t

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10719,7 +10719,16 @@ Value *CodeGenFunction::EmitSVEDupX(Value *Scalar, llvm::Type *Ty) {
1071910719
cast<llvm::VectorType>(Ty)->getElementCount(), Scalar);
1072010720
}
1072110721

10722-
Value *CodeGenFunction::EmitSVEDupX(Value* Scalar) {
10722+
// Splat Scalar across an SVE vector by delegating to the two-argument
// overload, with the destination type derived from Scalar's own type.
//
// Scalar may arrive wrapped in a vector type; the only shape accepted is
// <1 x i8> (checked by the assert in debug builds), and it is unwrapped to a
// plain i8 before the splat.
// NOTE(review): the codegen tests accompanying this change show mfloat8_t
// arguments reaching IR as <1 x i8>, which appears to be the case handled
// here -- confirm no other caller passes a vector-typed value.
Value *CodeGenFunction::EmitSVEDupX(Value *Scalar) {
10723+
if (auto *Ty = Scalar->getType(); Ty->isVectorTy()) {
10724+
// VecTy and EC are used only by the assert, so they are compiled out
// together with it when NDEBUG is defined.
#ifndef NDEBUG
10725+
auto *VecTy = cast<llvm::VectorType>(Ty);
10726+
ElementCount EC = VecTy->getElementCount();
10727+
assert(EC.isScalar() && VecTy->getElementType() == Int8Ty &&
10728+
"Only <1 x i8> expected");
10729+
#endif
10730+
// Take lane 0 so that everything below sees a scalar, not a vector.
Scalar = Builder.CreateExtractElement(Scalar, uint64_t(0));
10731+
}
1072310732
// Scalar->getType() is evaluated after the unwrapping above, so the
// destination vector type is chosen from the element type.
return EmitSVEDupX(Scalar, getSVEVectorForElementType(Scalar->getType()));
1072410733
}
1072510734

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sve -target-feature +sve2 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s
3+
// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +fp8 -target-feature +ssve-fp8dot2 -target-feature +ssve-fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s -check-prefix=CHECK-CXX
4+
5+
// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sve -target-feature +sve2 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s
6+
// RUN: %clang_cc1 -x c++ -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +fp8 -target-feature +ssve-fp8dot2 -target-feature +ssve-fp8dot4 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg,instcombine,tailcallelim | FileCheck %s -check-prefix=CHECK-CXX
7+
8+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sve -target-feature +sve2 -target-feature +fp8 -target-feature +fp8dot2 -target-feature +fp8dot4 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s
9+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -target-feature +fp8 -target-feature +ssve-fp8dot2 -target-feature +ssve-fp8dot4 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s
10+
11+
// REQUIRES: aarch64-registered-target
12+
13+
#ifdef __ARM_FEATURE_SME
14+
#include <arm_sme.h>
15+
#else
16+
#include <arm_sve.h>
17+
#endif
18+
19+
#ifdef SVE_OVERLOADED_FORMS
20+
#define SVE_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3
21+
#else
22+
#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
23+
#endif
24+
25+
#ifdef __ARM_FEATURE_SME
26+
#define STREAMING __arm_streaming
27+
#else
28+
#define STREAMING
29+
#endif
30+
31+
// CHECK-LABEL: define dso_local <vscale x 4 x float> @test_svdot_f32_mf8(
32+
// CHECK-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
33+
// CHECK-NEXT: [[ENTRY:.*:]]
34+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
35+
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
36+
// CHECK-NEXT: ret <vscale x 4 x float> [[TMP0]]
37+
//
38+
// CHECK-CXX-LABEL: define dso_local <vscale x 4 x float> @_Z18test_svdot_f32_mf8u13__SVFloat32_tu13__SVMfloat8_tS0_m(
39+
// CHECK-CXX-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
40+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
41+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
42+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
43+
// CHECK-CXX-NEXT: ret <vscale x 4 x float> [[TMP0]]
44+
//
45+
svfloat32_t test_svdot_f32_mf8(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
46+
return SVE_ACLE_FUNC(svdot,_f32_mf8,_fpm)(zda, zn, zm, fpm);
47+
}
48+
49+
// CHECK-LABEL: define dso_local <vscale x 4 x float> @test_svdot_n_f32_mf8(
50+
// CHECK-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
51+
// CHECK-NEXT: [[ENTRY:.*:]]
52+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
53+
// CHECK-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
54+
// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
55+
// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
56+
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
57+
// CHECK-NEXT: ret <vscale x 4 x float> [[TMP1]]
58+
//
59+
// CHECK-CXX-LABEL: define dso_local <vscale x 4 x float> @_Z20test_svdot_n_f32_mf8u13__SVFloat32_tu13__SVMfloat8_tu6__mfp8m(
60+
// CHECK-CXX-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
61+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
62+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
63+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
64+
// CHECK-CXX-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
65+
// CHECK-CXX-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
66+
// CHECK-CXX-NEXT: [[TMP1:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
67+
// CHECK-CXX-NEXT: ret <vscale x 4 x float> [[TMP1]]
68+
//
69+
svfloat32_t test_svdot_n_f32_mf8(svfloat32_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm) STREAMING {
70+
return SVE_ACLE_FUNC(svdot,_n_f32_mf8,_fpm)(zda, zn, zm, fpm);
71+
}
72+
73+
// CHECK-LABEL: define dso_local <vscale x 8 x half> @test_svdot_f16_mf8(
74+
// CHECK-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
75+
// CHECK-NEXT: [[ENTRY:.*:]]
76+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
77+
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
78+
// CHECK-NEXT: ret <vscale x 8 x half> [[TMP0]]
79+
//
80+
// CHECK-CXX-LABEL: define dso_local <vscale x 8 x half> @_Z18test_svdot_f16_mf8u13__SVFloat16_tu13__SVMfloat8_tS0_m(
81+
// CHECK-CXX-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
82+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
83+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
84+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
85+
// CHECK-CXX-NEXT: ret <vscale x 8 x half> [[TMP0]]
86+
//
87+
svfloat16_t test_svdot_f16_mf8(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
88+
return SVE_ACLE_FUNC(svdot,_f16_mf8,_fpm)(zda, zn, zm, fpm);
89+
}
90+
91+
// CHECK-LABEL: define dso_local <vscale x 8 x half> @test_svdot_n_f16_mf8(
92+
// CHECK-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
93+
// CHECK-NEXT: [[ENTRY:.*:]]
94+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
95+
// CHECK-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
96+
// CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
97+
// CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
98+
// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
99+
// CHECK-NEXT: ret <vscale x 8 x half> [[TMP1]]
100+
//
101+
// CHECK-CXX-LABEL: define dso_local <vscale x 8 x half> @_Z20test_svdot_n_f16_mf8u13__SVFloat16_tu13__SVMfloat8_tu6__mfp8m(
102+
// CHECK-CXX-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <1 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
103+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
104+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
105+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = extractelement <1 x i8> [[ZM]], i64 0
106+
// CHECK-CXX-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[TMP0]], i64 0
107+
// CHECK-CXX-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[DOTSPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
108+
// CHECK-CXX-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[DOTSPLAT]])
109+
// CHECK-CXX-NEXT: ret <vscale x 8 x half> [[TMP1]]
110+
//
111+
svfloat16_t test_svdot_n_f16_mf8(svfloat16_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm) STREAMING {
112+
return SVE_ACLE_FUNC(svdot,_n_f16_mf8,_fpm)(zda, zn, zm, fpm);
113+
}
114+
115+
// CHECK-LABEL: define dso_local <vscale x 4 x float> @test_svdot_lane_f32_mf8(
116+
// CHECK-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
117+
// CHECK-NEXT: [[ENTRY:.*:]]
118+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
119+
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.lane.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 3)
120+
// CHECK-NEXT: ret <vscale x 4 x float> [[TMP0]]
121+
//
122+
// CHECK-CXX-LABEL: define dso_local <vscale x 4 x float> @_Z23test_svdot_lane_f32_mf8u13__SVFloat32_tu13__SVMfloat8_tS0_m(
123+
// CHECK-CXX-SAME: <vscale x 4 x float> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
124+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
125+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
126+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 4 x float> @llvm.aarch64.sve.fp8.fdot.lane.nxv4f32(<vscale x 4 x float> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 3)
127+
// CHECK-CXX-NEXT: ret <vscale x 4 x float> [[TMP0]]
128+
//
129+
svfloat32_t test_svdot_lane_f32_mf8(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
130+
return SVE_ACLE_FUNC(svdot_lane,_f32_mf8,_fpm)(zda, zn, zm, 3, fpm);
131+
}
132+
133+
// CHECK-LABEL: define dso_local <vscale x 8 x half> @test_svdot_lane_f16_mf8(
134+
// CHECK-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
135+
// CHECK-NEXT: [[ENTRY:.*:]]
136+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
137+
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.lane.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 7)
138+
// CHECK-NEXT: ret <vscale x 8 x half> [[TMP0]]
139+
//
140+
// CHECK-CXX-LABEL: define dso_local <vscale x 8 x half> @_Z23test_svdot_lane_f16_mf8u13__SVFloat16_tu13__SVMfloat8_tS0_m(
141+
// CHECK-CXX-SAME: <vscale x 8 x half> [[ZDA:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
142+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
143+
// CHECK-CXX-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
144+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.fp8.fdot.lane.nxv8f16(<vscale x 8 x half> [[ZDA]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]], i32 7)
145+
// CHECK-CXX-NEXT: ret <vscale x 8 x half> [[TMP0]]
146+
//
147+
svfloat16_t test_svdot_lane_f16_mf8(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) STREAMING {
148+
return SVE_ACLE_FUNC(svdot_lane,_f16_mf8,_fpm)(zda, zn, zm, 7, fpm);
149+
}

clang/test/Sema/aarch64-sve2-intrinsics/acle_sve2_fp8.c

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include <arm_sve.h>
66

7-
void test_features(svmfloat8_t zn, fpm_t fpm) {
7+
void test_features(svmfloat8_t zn, svmfloat8_t zm, mfloat8_t x, fpm_t fpm) {
88
svcvt1_bf16_mf8_fpm(zn, fpm);
99
// expected-error@-1 {{'svcvt1_bf16_mf8_fpm' needs target feature (sve,sve2,fp8)|(sme,sme2,fp8)}}
1010
svcvt2_bf16_mf8_fpm(zn, fpm);
@@ -30,4 +30,25 @@ void test_features(svmfloat8_t zn, fpm_t fpm) {
3030
// expected-error@-1 {{'svcvtnb_mf8_f32_x2_fpm' needs target feature (sve,sve2,fp8)|(sme,sme2,fp8)}}
3131
svcvtnt_mf8_f32_x2_fpm(zn, svcreate2(svundef_f32(), svundef_f32()), fpm);
3232
// expected-error@-1 {{'svcvtnt_mf8_f32_x2_fpm' needs target feature (sve,sve2,fp8)|(sme,sme2,fp8)}}
33+
34+
svdot_f32_mf8_fpm(svundef_f32(), zn, zm, fpm);
35+
// expected-error@-1 {{'svdot_f32_mf8_fpm' needs target feature (sve,sve2,fp8dot4)|(sme,ssve-fp8dot4)}}
36+
svdot_n_f32_mf8_fpm(svundef_f32(), zn, x, fpm);
37+
// expected-error@-1 {{'svdot_n_f32_mf8_fpm' needs target feature (sve,sve2,fp8dot4)|(sme,ssve-fp8dot4)}}
38+
svdot_f16_mf8_fpm(svundef_f16(), zn, zm, fpm);
39+
// expected-error@-1 {{'svdot_f16_mf8_fpm' needs target feature (sve,sve2,fp8dot2)|(sme,ssve-fp8dot2)}}
40+
svdot_n_f16_mf8_fpm(svundef_f16(), zn, x, fpm);
41+
// expected-error@-1 {{'svdot_n_f16_mf8_fpm' needs target feature (sve,sve2,fp8dot2)|(sme,ssve-fp8dot2)}}
42+
svdot_lane_f32_mf8_fpm(svundef_f32(), zn, zm, 3, fpm);
43+
// expected-error@-1 {{'svdot_lane_f32_mf8_fpm' needs target feature (sve,sve2,fp8dot4)|(sme,ssve-fp8dot4)}}
44+
svdot_lane_f16_mf8_fpm(svundef_f16(), zn, zm, 7, fpm);
45+
// expected-error@-1 {{'svdot_lane_f16_mf8_fpm' needs target feature (sve,sve2,fp8dot2)|(sme,ssve-fp8dot2)}}
3346
}
47+
48+
49+
void test_imm_range(svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) {
50+
svdot_lane_f32_mf8_fpm(svundef_f32(), zn, zm, -1, fpm);
51+
// expected-error@-1 {{argument value 18446744073709551615 is outside the valid range [0, 3]}}
52+
svdot_lane_f16_mf8_fpm(svundef_f16(), zn, zm, -1, fpm);
53+
// expected-error@-1 {{argument value 18446744073709551615 is outside the valid range [0, 7]}}
54+
}

clang/utils/TableGen/SveEmitter.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ class Intrinsic {
253253
/// Return true if the intrinsic takes a splat operand, i.e. if its prototype
/// string contains at least one of the splat modifier characters listed below.
254254
bool hasSplat() const {
255255
// The splat prototype modifiers (e.g. 'j', 'K', 'L' and '!') are described
// in arm_sve_sme_incl.td. '!' denotes an mfloat8_t scalar that is splat to
// svmfloat8_t.
256-
return Proto.find_first_of("ajfrKLR@") != std::string::npos;
256+
return Proto.find_first_of("ajfrKLR@!") != std::string::npos;
257257
}
258258

259259
/// Return the parameter index of the splat operand.
@@ -262,7 +262,7 @@ class Intrinsic {
262262
for (; I < Proto.size(); ++I, ++Param) {
263263
if (Proto[I] == 'a' || Proto[I] == 'j' || Proto[I] == 'f' ||
264264
Proto[I] == 'r' || Proto[I] == 'K' || Proto[I] == 'L' ||
265-
Proto[I] == 'R' || Proto[I] == '@')
265+
Proto[I] == 'R' || Proto[I] == '@' || Proto[I] == '!')
266266
break;
267267

268268
// Multivector modifier can be skipped
@@ -910,6 +910,11 @@ void SVEType::applyModifier(char Mod) {
910910
Kind = MFloat8;
911911
ElementBitwidth = 8;
912912
break;
913+
case '!':
914+
Kind = MFloat8;
915+
Bitwidth = ElementBitwidth = 8;
916+
NumVectors = 0;
917+
break;
913918
case '.':
914919
llvm_unreachable(". is never a type in itself");
915920
break;

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3911,6 +3911,22 @@ let TargetPrefix = "aarch64" in {
39113911
[llvm_nxv16i8_ty, llvm_anyvector_ty, LLVMMatchType<0>],
39123912
[IntrReadMem, IntrInaccessibleMemOnly]>;
39133913

3914+
// Dot product
//
// Operand layout shared by both classes:
//   result     overloaded vector type (llvm_anyvector_ty)
//   operand 0  same type as the result (LLVMMatchType<0>)
//   operand 1  nxv16i8
//   operand 2  nxv16i8
// The codegen tests for this change pass the accumulator (zda) as operand 0
// and the two svmfloat8_t inputs, lowered to <vscale x 16 x i8>, as operands
// 1 and 2; the overloads exercised there are nxv8f16 and nxv4f32.
// NOTE(review): IntrReadMem + IntrInaccessibleMemOnly presumably models the
// dependency on the FPMR value written by llvm.aarch64.set.fpmr -- confirm.
3915+
class SVE2_FP8_FMLA_FDOT
3916+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
3917+
[LLVMMatchType<0>,
3918+
llvm_nxv16i8_ty, llvm_nxv16i8_ty],
3919+
[IntrReadMem, IntrInaccessibleMemOnly]>;
3920+
3921+
// Indexed variant: adds operand 3, an i32 that must be an immediate
// (ImmArg<ArgIndex<3>>); the tests pass the lane index there.
class SVE2_FP8_FMLA_FDOT_Lane
3922+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
3923+
[LLVMMatchType<0>,
3924+
llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_i32_ty],
3925+
[IntrReadMem, IntrInaccessibleMemOnly, ImmArg<ArgIndex<3>>]>;
3926+
3927+
// Intrinsic definitions; the tests reference them as
// llvm.aarch64.sve.fp8.fdot.* and llvm.aarch64.sve.fp8.fdot.lane.*.
def int_aarch64_sve_fp8_fdot : SVE2_FP8_FMLA_FDOT;
3928+
def int_aarch64_sve_fp8_fdot_lane : SVE2_FP8_FMLA_FDOT_Lane;
3929+
39143930
class SME2_FP8_CVT_X2_Single_Intrinsic
39153931
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
39163932
[llvm_nxv16i8_ty],

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4423,18 +4423,17 @@ let Predicates = [HasSVE2, HasF8F16MM] in {
44234423

44244424
let Predicates = [HasSSVE_FP8DOT2] in {
44254425
// FP8 Widening Dot-Product - Indexed Group
4426-
defm FDOT_ZZZI_BtoH : sve2_fp8_dot_indexed_h<"fdot">;
4426+
defm FDOT_ZZZI_BtoH : sve2_fp8_dot_indexed_h<"fdot", int_aarch64_sve_fp8_fdot_lane>;
44274427
// FP8 Widening Dot-Product - Group
4428-
// TODO: Replace nxv16i8 by nxv16f8
4429-
defm FDOT_ZZZ_BtoH : sve_fp8_dot<0b0, ZPR16, "fdot">;
4428+
defm FDOT_ZZZ_BtoH : sve_fp8_dot<0b0, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fp8_fdot>;
44304429
}
44314430

44324431
// TODO: Replace nxv16i8 by nxv16f8
44334432
let Predicates = [HasSSVE_FP8DOT4] in {
44344433
// FP8 Widening Dot-Product - Indexed Group
4435-
defm FDOT_ZZZI_BtoS : sve2_fp8_dot_indexed_s<"fdot">;
4434+
defm FDOT_ZZZI_BtoS : sve2_fp8_dot_indexed_s<"fdot", int_aarch64_sve_fp8_fdot_lane>;
44364435
// FP8 Widening Dot-Product - Group
4437-
defm FDOT_ZZZ_BtoS : sve_fp8_dot<0b1, ZPR32, "fdot">;
4436+
defm FDOT_ZZZ_BtoS : sve_fp8_dot<0b1, ZPR32, "fdot", nxv4f32, int_aarch64_sve_fp8_fdot>;
44384437
}
44394438

44404439
let Predicates = [HasSVE2orSME2, HasLUT] in {

0 commit comments

Comments
 (0)