Skip to content

Commit 0cafe35

Browse files
committed
[AArch64] Implement intrinsics for FP8 SME FMLAL/FMLALL (multi)
1 parent 75b2d78 commit 0cafe35

File tree

7 files changed

+215
-17
lines changed

7 files changed

+215
-17
lines changed

clang/include/clang/Basic/arm_sme.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,11 @@ let SMETargetGuard = "sme-f8f32" in {
873873
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
874874
def SVMLA_FP8_SINGLE_ZA32_VG4x4 : Inst<"svmla[_single]_za32[_mf8]_vg4x4_fpm", "vm4d>", "m", MergeNone, "aarch64_sme_fp8_fmlall_single_za32_vg4x4",
875875
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
876+
// FMLALL (mutliple)
877+
def SVMLA_FP8_MULTI_ZA32_VG4x2 : Inst<"svmla_za32[_mf8]_vg4x2_fpm", "vm22>", "m", MergeNone, "aarch64_sme_fp8_fmlall_multi_za32_vg4x2",
878+
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
879+
def SVMLA_FP8_MULTI_ZA32_VG4x4 : Inst<"svmla_za32[_mf8]_vg4x4_fpm", "vm44>", "m", MergeNone, "aarch64_sme_fp8_fmlall_multi_za32_vg4x4",
880+
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
876881
}
877882

878883
let SMETargetGuard = "sme-f8f16" in {
@@ -892,6 +897,11 @@ let SMETargetGuard = "sme-f8f16" in {
892897
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
893898
def SVMLA_FP8_SINGLE_ZA16_VG2x4 : Inst<"svmla[_single]_za16[_mf8]_vg2x4_fpm", "vm4d>", "m", MergeNone, "aarch64_sme_fp8_fmlal_single_za16_vg2x4",
894899
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
900+
// FMLAL (mutliple)
901+
def SVMLA_FP8_MULTI_ZA16_VG2x2 : Inst<"svmla_za16[_mf8]_vg2x2_fpm", "vm22>", "m", MergeNone, "aarch64_sme_fp8_fmlal_multi_za16_vg2x2",
902+
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
903+
def SVMLA_FP8_MULTI_ZA16_VG2x4 : Inst<"svmla_za16[_mf8]_vg2x4_fpm", "vm44>", "m", MergeNone, "aarch64_sme_fp8_fmlal_multi_za16_vg2x4",
904+
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], []>;
895905
}
896906

897907
} // let SVETargetGuard = InvalidMode

clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_mla.c

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
32
// REQUIRES: aarch64-registered-target
43

@@ -239,3 +238,79 @@ void test_svmla_single_za32_vg4x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8_t
239238
void test_svmla_single_za32_vg4x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
240239
SME_ACLE_FUNC(svmla,_single,_za32,_mf8,_vg4x4_fpm)(slice, zn, zm, fpm);
241240
}
241+
242+
// FMLAL (multi)
243+
244+
// CHECK-LABEL: define dso_local void @test_svmla_multi_za16_vg2x2(
245+
// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
246+
// CHECK-NEXT: [[ENTRY:.*:]]
247+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
248+
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
249+
// CHECK-NEXT: ret void
250+
//
251+
// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za16_vg2x2j13svmfloat8x2_tS_m(
252+
// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0:[0-9]+]] {
253+
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
254+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
255+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
256+
// CPP-CHECK-NEXT: ret void
257+
//
258+
void test_svmla_multi_za16_vg2x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
259+
SME_ACLE_FUNC(svmla_za16,_mf8,_vg2x2_fpm,,)(slice, zn, zm, fpm);
260+
}
261+
262+
// CHECK-LABEL: define dso_local void @test_svmla_multi_za16_vg2x4(
263+
// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
264+
// CHECK-NEXT: [[ENTRY:.*:]]
265+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
266+
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
267+
// CHECK-NEXT: ret void
268+
//
269+
// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za16_vg2x4j13svmfloat8x4_tS_m(
270+
// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
271+
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
272+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
273+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlal.multi.za16.vg2x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
274+
// CPP-CHECK-NEXT: ret void
275+
//
276+
void test_svmla_multi_za16_vg2x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
277+
SME_ACLE_FUNC(svmla_za16,_mf8,_vg2x4_fpm,,)(slice, zn, zm, fpm);
278+
}
279+
280+
// FMLALL (multi)
281+
282+
// CHECK-LABEL: define dso_local void @test_svmla_multi_za32_vg4x2(
283+
// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
284+
// CHECK-NEXT: [[ENTRY:.*:]]
285+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
286+
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
287+
// CHECK-NEXT: ret void
288+
//
289+
// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za32_vg4x2j13svmfloat8x2_tS_m(
290+
// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
291+
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
292+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
293+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x2(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]])
294+
// CPP-CHECK-NEXT: ret void
295+
//
296+
void test_svmla_multi_za32_vg4x2(uint32_t slice, svmfloat8x2_t zn, svmfloat8x2_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
297+
SME_ACLE_FUNC(svmla_za32,_mf8,_vg4x2_fpm,,)(slice, zn, zm, fpm);
298+
}
299+
300+
// CHECK-LABEL: define dso_local void @test_svmla_multi_za32_vg4x4(
301+
// CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
302+
// CHECK-NEXT: [[ENTRY:.*:]]
303+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
304+
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
305+
// CHECK-NEXT: ret void
306+
//
307+
// CPP-CHECK-LABEL: define dso_local void @_Z27test_svmla_multi_za32_vg4x4j13svmfloat8x4_tS_m(
308+
// CPP-CHECK-SAME: i32 noundef [[SLICE:%.*]], <vscale x 16 x i8> [[ZN_COERCE0:%.*]], <vscale x 16 x i8> [[ZN_COERCE1:%.*]], <vscale x 16 x i8> [[ZN_COERCE2:%.*]], <vscale x 16 x i8> [[ZN_COERCE3:%.*]], <vscale x 16 x i8> [[ZM_COERCE0:%.*]], <vscale x 16 x i8> [[ZM_COERCE1:%.*]], <vscale x 16 x i8> [[ZM_COERCE2:%.*]], <vscale x 16 x i8> [[ZM_COERCE3:%.*]], i64 noundef [[FPM:%.*]]) #[[ATTR0]] {
309+
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
310+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPM]])
311+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmlall.multi.za32.vg4x4(i32 [[SLICE]], <vscale x 16 x i8> [[ZN_COERCE0]], <vscale x 16 x i8> [[ZN_COERCE1]], <vscale x 16 x i8> [[ZN_COERCE2]], <vscale x 16 x i8> [[ZN_COERCE3]], <vscale x 16 x i8> [[ZM_COERCE0]], <vscale x 16 x i8> [[ZM_COERCE1]], <vscale x 16 x i8> [[ZM_COERCE2]], <vscale x 16 x i8> [[ZM_COERCE3]])
312+
// CPP-CHECK-NEXT: ret void
313+
//
314+
void test_svmla_multi_za32_vg4x4(uint32_t slice, svmfloat8x4_t zn, svmfloat8x4_t zm, fpm_t fpm) __arm_streaming __arm_inout("za") {
315+
SME_ACLE_FUNC(svmla_za32,_mf8,_vg4x4_fpm,,)(slice, zn, zm, fpm);
316+
}

clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mla.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,16 @@ void test_svmla(uint32_t slice, svmfloat8_t zn, svmfloat8x2_t znx2, svmfloat8x4_
4141

4242
// expected-error@+1 {{'svmla_single_za32_mf8_vg4x4_fpm' needs target feature sme,sme-f8f32}}
4343
svmla_single_za32_mf8_vg4x4_fpm(slice, znx4, zn, fpmr);
44+
45+
// expected-error@+1 {{'svmla_za16_mf8_vg2x2_fpm' needs target feature sme,sme-f8f16}}
46+
svmla_za16_mf8_vg2x2_fpm(slice, znx2, znx2, fpmr);
47+
48+
// expected-error@+1 {{'svmla_za16_mf8_vg2x4_fpm' needs target feature sme,sme-f8f16}}
49+
svmla_za16_mf8_vg2x4_fpm(slice, znx4, znx4, fpmr);
50+
51+
// expected-error@+1 {{'svmla_za32_mf8_vg4x2_fpm' needs target feature sme,sme-f8f32}}
52+
svmla_za32_mf8_vg4x2_fpm(slice, znx2, znx2, fpmr);
53+
54+
// expected-error@+1 {{'svmla_za32_mf8_vg4x4_fpm' needs target feature sme,sme-f8f32}}
55+
svmla_za32_mf8_vg4x4_fpm(slice, znx4, znx4, fpmr);
4456
}

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3997,6 +3997,18 @@ let TargetPrefix = "aarch64" in {
39973997
: DefaultAttrsIntrinsic<[], [llvm_i32_ty,
39983998
llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty,
39993999
llvm_nxv16i8_ty],
4000+
[IntrInaccessibleMemOnly, IntrHasSideEffects]>;
4001+
4002+
class SME_FP8_ZA_MULTI_VGx2_Intrinsic
4003+
: DefaultAttrsIntrinsic<[], [llvm_i32_ty,
4004+
llvm_nxv16i8_ty, llvm_nxv16i8_ty,
4005+
llvm_nxv16i8_ty, llvm_nxv16i8_ty],
4006+
[IntrInaccessibleMemOnly, IntrHasSideEffects]>;
4007+
4008+
class SME_FP8_ZA_MULTI_VGx4_Intrinsic
4009+
: DefaultAttrsIntrinsic<[], [llvm_i32_ty,
4010+
llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty,
4011+
llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty, llvm_nxv16i8_ty],
40004012
[IntrInaccessibleMemOnly, IntrHasSideEffects]>;
40014013
//
40024014
// CVT from FP8 to half-precision/BFloat16 multi-vector
@@ -4036,6 +4048,9 @@ let TargetPrefix = "aarch64" in {
40364048
def int_aarch64_sme_fp8_fmlal_single_za16_vg2x1 : SME_FP8_ZA_SINGLE_VGx1_Intrinsic;
40374049
def int_aarch64_sme_fp8_fmlal_single_za16_vg2x2 : SME_FP8_ZA_SINGLE_VGx2_Intrinsic;
40384050
def int_aarch64_sme_fp8_fmlal_single_za16_vg2x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic;
4051+
// Multi
4052+
def int_aarch64_sme_fp8_fmlal_multi_za16_vg2x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic;
4053+
def int_aarch64_sme_fp8_fmlal_multi_za16_vg2x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic;
40394054

40404055
// Quad-vector groups (F8F32)
40414056
def int_aarch64_sme_fp8_fmlall_lane_za32_vg4x1 : SME_FP8_ZA_LANE_VGx1_Intrinsic;
@@ -4045,6 +4060,9 @@ let TargetPrefix = "aarch64" in {
40454060
def int_aarch64_sme_fp8_fmlall_single_za32_vg4x1 : SME_FP8_ZA_SINGLE_VGx1_Intrinsic;
40464061
def int_aarch64_sme_fp8_fmlall_single_za32_vg4x2 : SME_FP8_ZA_SINGLE_VGx2_Intrinsic;
40474062
def int_aarch64_sme_fp8_fmlall_single_za32_vg4x4 : SME_FP8_ZA_SINGLE_VGx4_Intrinsic;
4063+
// Multi
4064+
def int_aarch64_sme_fp8_fmlall_multi_za32_vg4x2 : SME_FP8_ZA_MULTI_VGx2_Intrinsic;
4065+
def int_aarch64_sme_fp8_fmlall_multi_za32_vg4x4 : SME_FP8_ZA_MULTI_VGx4_Intrinsic;
40484066

40494067
//
40504068
// FP8 FDOT intrinsics

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,8 +1003,9 @@ defm FMLAL_VG2_MZZ_BtoH : sme2_fp8_fmlal_single_za16<"fmlal", int_aarch64_sme_f
10031003
defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x2, [FPMR, FPCR]>;
10041004
defm FMLAL_VG4_M4ZZ_BtoH : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlal_single_za16_vg2x4, [FPMR, FPCR]>;
10051005

1006-
defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
1007-
defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
1006+
// FP8 FMLALL (multi)
1007+
defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlal_multi_za16_vg2x2, [FPMR, FPCR]>;
1008+
defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlal_multi_za16_vg2x4, [FPMR, FPCR]>;
10081009

10091010
defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>;
10101011
} //[HasSMEF8F16]
@@ -1030,8 +1031,9 @@ defm FMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixO
10301031
defm FMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg2_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x2, [FPMR, FPCR]>;
10311032
defm FMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg4_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8, nxv16i8, int_aarch64_sme_fp8_fmlall_single_za32_vg4x4, [FPMR, FPCR]>;
10321033

1033-
defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
1034-
defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
1034+
// FP8 FMLALL (multi)
1035+
defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlall_multi_za32_vg4x2, [FPMR, FPCR]>;
1036+
defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, int_aarch64_sme_fp8_fmlall_multi_za32_vg4x4, [FPMR, FPCR]>;
10351037

10361038
defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>;
10371039
} //[HasSMEF8F32]

llvm/lib/Target/AArch64/SMEInstrFormats.td

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,11 +2262,12 @@ class sme2_mla_long_array_vg2_multi<string mnemonic, bits<2> op0, bits<3> op,
22622262
}
22632263

22642264
multiclass sme2_fp_mla_long_array_vg2_multi<string mnemonic, bits<3> op, MatrixOperand matrix_ty,
2265-
RegisterOperand multi_vector_ty,
2266-
ValueType zpr_ty, SDPatternOperator intrinsic> {
2267-
2265+
RegisterOperand multi_vector_ty, ValueType zpr_ty,
2266+
SDPatternOperator intrinsic, list<Register> uses=[]> {
22682267
def NAME : sme2_mla_long_array_vg2_multi<mnemonic, 0b10, op, matrix_ty, multi_vector_ty>,
2269-
SMEPseudo2Instr<NAME, 1>;
2268+
SMEPseudo2Instr<NAME, 1> {
2269+
let Uses = uses;
2270+
}
22702271

22712272
def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm2s2range, multi_vector_ty, SMEMatrixArray>;
22722273

@@ -2309,9 +2310,11 @@ class sme2_mla_long_array_vg4_multi<string mnemonic, bits<2> op0, bits<3> op,
23092310

23102311
multiclass sme2_fp_mla_long_array_vg4_multi<string mnemonic, bits<3> op, MatrixOperand matrix_ty,
23112312
RegisterOperand multi_vector_ty, ValueType zpr_ty,
2312-
SDPatternOperator intrinsic> {
2313+
SDPatternOperator intrinsic, list<Register> uses=[]> {
23132314
def NAME : sme2_mla_long_array_vg4_multi<mnemonic, 0b10, op, matrix_ty, multi_vector_ty>,
2314-
SMEPseudo2Instr<NAME, 1>;
2315+
SMEPseudo2Instr<NAME, 1> {
2316+
let Uses = uses;
2317+
}
23152318

23162319
def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm2s2range, multi_vector_ty, SMEMatrixArray>;
23172320

@@ -3341,9 +3344,11 @@ class sme2_mla_ll_array_vg2_multi<bits<5> op, MatrixOperand matrix_ty,
33413344

33423345
multiclass sme2_mla_ll_array_vg2_multi<string mnemonic, bits<5> op,
33433346
MatrixOperand matrix_ty,
3344-
RegisterOperand vector_ty,
3345-
ValueType vt, SDPatternOperator intrinsic> {
3346-
def NAME : sme2_mla_ll_array_vg2_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1>;
3347+
RegisterOperand vector_ty, ValueType vt,
3348+
SDPatternOperator intrinsic, list<Register> uses=[]> {
3349+
def NAME : sme2_mla_ll_array_vg2_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1> {
3350+
let Uses = uses;
3351+
}
33473352

33483353
def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm1s4range, vector_ty, SMEMatrixArray>;
33493354

@@ -3385,9 +3390,11 @@ class sme2_mla_ll_array_vg4_multi<bits<5> op,MatrixOperand matrix_ty,
33853390

33863391
multiclass sme2_mla_ll_array_vg4_multi<string mnemonic, bits<5> op,
33873392
MatrixOperand matrix_ty,
3388-
RegisterOperand vector_ty,
3389-
ValueType vt, SDPatternOperator intrinsic> {
3390-
def NAME : sme2_mla_ll_array_vg4_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1>;
3393+
RegisterOperand vector_ty, ValueType vt,
3394+
SDPatternOperator intrinsic, list<Register> uses=[]> {
3395+
def NAME : sme2_mla_ll_array_vg4_multi<op, matrix_ty, vector_ty, mnemonic>, SMEPseudo2Instr<NAME, 1> {
3396+
let Uses = uses;
3397+
}
33913398

33923399
def _PSEUDO : sme2_za_array_2op_multi_multi_pseudo<NAME, uimm1s4range, vector_ty, SMEMatrixArray>;
33933400

0 commit comments

Comments
 (0)