[AArch64] Implement intrinsics for FP8 FCVT/FCVTN/BFCVT #118025


Merged: 2 commits, merged on Dec 11, 2024
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/arm_sve.td
@@ -2436,6 +2436,12 @@ let SVETargetGuard = InvalidMode, SMETargetGuard = "sme2,fp8" in {
// Convert from FP8 to deinterleaved half-precision/BFloat16 multi-vector
def SVF1CVTL_X2 : Inst<"svcvtl1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl1_x2", [IsStreaming, SetsFPMR], []>;
def SVF2CVTL_X2 : Inst<"svcvtl2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl2_x2", [IsStreaming, SetsFPMR], []>;

// Convert from single/half/bfloat multivector to FP8
def SVFCVT_X2 : Inst<"svcvt_mf8[_{d}_x2]_fpm", "~2>", "bh", MergeNone, "aarch64_sve_fp8_cvt_x2", [IsStreaming, SetsFPMR], []>;
def SVFCVT_X4 : Inst<"svcvt_mf8[_{d}_x4]_fpm", "~4>", "f", MergeNone, "aarch64_sve_fp8_cvt_x4", [IsOverloadNone, IsStreaming, SetsFPMR], []>;
// interleaved
def SVFCVTN_X4 : Inst<"svcvtn_mf8[_{d}_x4]_fpm", "~4>", "f", MergeNone, "aarch64_sve_fp8_cvtn_x4", [IsOverloadNone, IsStreaming, SetsFPMR], []>;
}

let SVETargetGuard = "sve2p1", SMETargetGuard = "sme2" in {
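For reference, the C declarations these definitions expand to look roughly as follows — a sketch inferred from the codegen tests below, not an authoritative listing. The trailing fpm_t argument carries the FP8 mode register (FPMR) value:

// Inferred ACLE declarations; see the codegen tests below for the exact forms.
svmfloat8_t svcvt_mf8_f16_x2_fpm(svfloat16x2_t zn, fpm_t fpm) __arm_streaming;   // FCVT (x2)
svmfloat8_t svcvt_mf8_bf16_x2_fpm(svbfloat16x2_t zn, fpm_t fpm) __arm_streaming; // BFCVT (x2)
svmfloat8_t svcvt_mf8_f32_x4_fpm(svfloat32x4_t zn, fpm_t fpm) __arm_streaming;   // FCVT (x4)
svmfloat8_t svcvtn_mf8_f32_x4_fpm(svfloat32x4_t zn, fpm_t fpm) __arm_streaming;  // FCVTN (x4, interleaved)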
64 changes: 64 additions & 0 deletions clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c
@@ -16,6 +16,70 @@
#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
#endif

// CHECK-LABEL: @test_cvt_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8f16(<vscale x 8 x half> [[ZN_COERCE0:%.*]], <vscale x 8 x half> [[ZN_COERCE1:%.*]])
// CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z15test_cvt_f16_x213svfloat16x2_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8f16(<vscale x 8 x half> [[ZN_COERCE0:%.*]], <vscale x 8 x half> [[ZN_COERCE1:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
svmfloat8_t test_cvt_f16_x2(svfloat16x2_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt_mf8,_f16_x2,_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt_f32_x4(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> [[ZN_COERCE0:%.*]], <vscale x 4 x float> [[ZN_COERCE1:%.*]], <vscale x 4 x float> [[ZN_COERCE2:%.*]], <vscale x 4 x float> [[ZN_COERCE3:%.*]])
// CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z15test_cvt_f32_x413svfloat32x4_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> [[ZN_COERCE0:%.*]], <vscale x 4 x float> [[ZN_COERCE1:%.*]], <vscale x 4 x float> [[ZN_COERCE2:%.*]], <vscale x 4 x float> [[ZN_COERCE3:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
svmfloat8_t test_cvt_f32_x4(svfloat32x4_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt_mf8,_f32_x4,_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvtn_f32_x4(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvtn.x4(<vscale x 4 x float> [[ZN_COERCE0:%.*]], <vscale x 4 x float> [[ZN_COERCE1:%.*]], <vscale x 4 x float> [[ZN_COERCE2:%.*]], <vscale x 4 x float> [[ZN_COERCE3:%.*]])
// CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z16test_cvtn_f32_x413svfloat32x4_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvtn.x4(<vscale x 4 x float> [[ZN_COERCE0:%.*]], <vscale x 4 x float> [[ZN_COERCE1:%.*]], <vscale x 4 x float> [[ZN_COERCE2:%.*]], <vscale x 4 x float> [[ZN_COERCE3:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
svmfloat8_t test_cvtn_f32_x4(svfloat32x4_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvtn_mf8,_f32_x4,_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt_bf16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8bf16(<vscale x 8 x bfloat> [[ZN_COERCE0:%.*]], <vscale x 8 x bfloat> [[ZN_COERCE1:%.*]])
// CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z16test_cvt_bf16_x214svbfloat16x2_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8bf16(<vscale x 8 x bfloat> [[ZN_COERCE0:%.*]], <vscale x 8 x bfloat> [[ZN_COERCE1:%.*]])
// CPP-CHECK-NEXT: ret <vscale x 16 x i8> [[TMP0]]
//
svmfloat8_t test_cvt_bf16_x2(svbfloat16x2_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt_mf8,_bf16_x2,_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt1_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
12 changes: 11 additions & 1 deletion clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c
@@ -5,7 +5,8 @@
#include <arm_sve.h>


void test_features_sme2_fp8(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
void test_features_sme2_fp8(svmfloat8_t zn, svfloat16x2_t znf16, svbfloat16x2_t znbf16,
svfloat32x4_t znf32, fpm_t fpmr) __arm_streaming {
// expected-error@+1 {{'svcvtl1_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvtl1_f16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvtl2_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
@@ -23,4 +24,13 @@ void test_features_sme2_fp8(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
svcvt1_bf16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvt2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt2_bf16_mf8_x2_fpm(zn, fpmr);

// expected-error@+1 {{'svcvt_mf8_f16_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt_mf8_f16_x2_fpm(znf16, fpmr);
// expected-error@+1 {{'svcvt_mf8_bf16_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt_mf8_bf16_x2_fpm(znbf16, fpmr);
// expected-error@+1 {{'svcvt_mf8_f32_x4_fpm' needs target feature sme,sme2,fp8}}
svcvt_mf8_f32_x4_fpm(znf32, fpmr);
// expected-error@+1 {{'svcvtn_mf8_f32_x4_fpm' needs target feature sme,sme2,fp8}}
svcvtn_mf8_f32_x4_fpm(znf32, fpmr);
}
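To illustrate how the pieces fit together, here is a minimal usage sketch. The function name is hypothetical, and the code assumes a streaming-mode caller with the sme2 and fp8 target features enabled:

#include <arm_sve.h>

// Narrow four single-precision vectors into one FP8 vector with interleaving
// (lowers to FCVTN). svcreate4_f32 packs the inputs into an svfloat32x4_t tuple.
svmfloat8_t narrow_to_fp8(svfloat32_t a, svfloat32_t b, svfloat32_t c,
                          svfloat32_t d, fpm_t fpm) __arm_streaming {
  svfloat32x4_t zn = svcreate4_f32(a, b, c, d);
  return svcvtn_mf8_f32_x4_fpm(zn, fpm);
}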
17 changes: 17 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3812,6 +3812,7 @@ let TargetPrefix = "aarch64" in {
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>,
LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>],
[IntrNoMem]>;

}

// SVE2.1 - ZIPQ1, ZIPQ2, UZPQ1, UZPQ2
@@ -3876,6 +3877,11 @@ let TargetPrefix = "aarch64" in {
[llvm_nxv16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;

class SME2_FP8_CVT_Single_X4_Intrinsic
: DefaultAttrsIntrinsic<[llvm_nxv16i8_ty],
[llvm_nxv4f32_ty, llvm_nxv4f32_ty, llvm_nxv4f32_ty, llvm_nxv4f32_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;

class SME_FP8_OuterProduct_Intrinsic
: DefaultAttrsIntrinsic<[],
[llvm_i32_ty,
@@ -3894,6 +3900,17 @@ let TargetPrefix = "aarch64" in {
def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;

//
// CVT to FP8 from half-precision/BFloat16/single-precision multi-vector
//
def int_aarch64_sve_fp8_cvt_x2
: DefaultAttrsIntrinsic<[llvm_nxv16i8_ty],
[llvm_anyvector_ty, LLVMMatchType<0>],
[IntrReadMem, IntrInaccessibleMemOnly]>;

def int_aarch64_sve_fp8_cvt_x4 : SME2_FP8_CVT_Single_X4_Intrinsic;
def int_aarch64_sve_fp8_cvtn_x4 : SME2_FP8_CVT_Single_X4_Intrinsic;

// FP8 outer product
def int_aarch64_sme_fp8_fmopa_za16 : SME_FP8_OuterProduct_Intrinsic;
def int_aarch64_sme_fp8_fmopa_za32 : SME_FP8_OuterProduct_Intrinsic;
8 changes: 4 additions & 4 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -954,10 +954,10 @@ defm F2CVTL_2ZZ_BtoH : sme2p1_fp8_cvt_vector_vg2_single<"f2cvtl", 0b10, 0b1>;
defm BF2CVT_2ZZ_BtoH : sme2p1_fp8_cvt_vector_vg2_single<"bf2cvt", 0b11, 0b0>;
defm BF2CVTL_2ZZ_BtoH : sme2p1_fp8_cvt_vector_vg2_single<"bf2cvtl", 0b11, 0b1>;

defm FCVT_Z2Z_HtoB : sme2_fp8_cvt_vg2_single<"fcvt", 0b0>;
defm BFCVT_Z2Z_HtoB : sme2_fp8_cvt_vg2_single<"bfcvt", 0b1>;
defm FCVT_Z4Z_StoB : sme2_fp8_cvt_vg4_single<"fcvt", 0b0>;
defm FCVTN_Z4Z_StoB : sme2_fp8_cvt_vg4_single<"fcvtn", 0b1>;
defm FCVT_Z2Z_HtoB : sme2_fp8_cvt_vg2_single<"fcvt", 0b0, nxv8f16, int_aarch64_sve_fp8_cvt_x2>;
defm BFCVT_Z2Z_HtoB : sme2_fp8_cvt_vg2_single<"bfcvt", 0b1, nxv8bf16, int_aarch64_sve_fp8_cvt_x2>;
defm FCVT_Z4Z_StoB : sme2_fp8_cvt_vg4_single<"fcvt", 0b0, int_aarch64_sve_fp8_cvt_x4>;
defm FCVTN_Z4Z_StoB : sme2_fp8_cvt_vg4_single<"fcvtn", 0b1, int_aarch64_sve_fp8_cvtn_x4>;

defm FSCALE_2ZZ : sme2_fp_sve_destructive_vector_vg2_single<"fscale", 0b0011000>;
defm FSCALE_4ZZ : sme2_fp_sve_destructive_vector_vg4_single<"fscale", 0b0011000>;
15 changes: 12 additions & 3 deletions llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -2398,10 +2398,14 @@ multiclass sme2_cvt_vg2_single<string mnemonic, bits<5> op, ValueType out_vt,
}

// SME2 multi-vec FP8 down convert two registers
multiclass sme2_fp8_cvt_vg2_single<string mnemonic, bit op> {
multiclass sme2_fp8_cvt_vg2_single<string mnemonic, bit op, ValueType in_vt, SDPatternOperator intrinsic> {
def NAME : sme2_cvt_vg2_single<mnemonic, {op, 0b1000}, ZPR8, ZZ_h_mul_r>{
let mayLoad = 1;
let mayStore = 0;
let Uses = [FPMR, FPCR];
}
def : Pat<(nxv16i8 (intrinsic in_vt:$Zn1, in_vt:$Zn2)),
(!cast<Instruction>(NAME) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1))>;
}

class sme2_cvt_unpk_vector_vg2<bits<2>sz, bits<3> op, bit u, RegisterOperand first_ty,
@@ -2467,8 +2471,13 @@ multiclass sme2_int_cvt_vg4_single<string mnemonic, bits<3> op, SDPatternOperator
}

// SME2 multi-vec FP8 down convert four registers
multiclass sme2_fp8_cvt_vg4_single<string mnemonic, bit N> {
def _NAME : sme2_cvt_vg4_single<0b0, {0b00, N}, 0b0100, ZPR8, ZZZZ_s_mul_r, mnemonic>;
multiclass sme2_fp8_cvt_vg4_single<string mnemonic, bit N, SDPatternOperator intrinsic> {
def NAME : sme2_cvt_vg4_single<0b0, {0b00, N}, 0b0100, ZPR8, ZZZZ_s_mul_r, mnemonic> {
let mayLoad = 1;
let mayStore = 0;
let Uses = [FPMR, FPCR];
}
def : SME2_Cvt_VG4_Pat<NAME, intrinsic, nxv16i8, nxv4f32>;
}

class sme2_unpk_vector_vg4<bits<2>sz, bit u, RegisterOperand first_ty,
52 changes: 52 additions & 0 deletions llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
@@ -1,6 +1,58 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s

; FCVT / FCVTN / BFCVT

define <vscale x 16 x i8> @fcvt_x2(<vscale x 8 x half> %zn0, <vscale x 8 x half> %zn1) {
; CHECK-LABEL: fcvt_x2:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $z1 killed $z1 killed $z0_z1 def $z0_z1
; CHECK-NEXT: // kill: def $z0 killed $z0 killed $z0_z1 def $z0_z1
; CHECK-NEXT: fcvt z0.b, { z0.h, z1.h }
; CHECK-NEXT: ret
%res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8f16(<vscale x 8 x half> %zn0, <vscale x 8 x half> %zn1)
ret <vscale x 16 x i8> %res
}

define <vscale x 16 x i8> @fcvt_x4(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1, <vscale x 4 x float> %zn2, <vscale x 4 x float> %zn3) {
; CHECK-LABEL: fcvt_x4:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $z3 killed $z3 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: // kill: def $z2 killed $z2 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: // kill: def $z1 killed $z1 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: // kill: def $z0 killed $z0 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: fcvt z0.b, { z0.s - z3.s }
; CHECK-NEXT: ret
%res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1,
<vscale x 4 x float> %zn2, <vscale x 4 x float> %zn3)
ret <vscale x 16 x i8> %res
}

define <vscale x 16 x i8> @fcvtn(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1, <vscale x 4 x float> %zn2, <vscale x 4 x float> %zn3) {
; CHECK-LABEL: fcvtn:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $z3 killed $z3 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: // kill: def $z2 killed $z2 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: // kill: def $z1 killed $z1 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: // kill: def $z0 killed $z0 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
; CHECK-NEXT: fcvtn z0.b, { z0.s - z3.s }
; CHECK-NEXT: ret
%res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvtn.x4(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1,
<vscale x 4 x float> %zn2, <vscale x 4 x float> %zn3)
ret <vscale x 16 x i8> %res
}

define <vscale x 16 x i8> @bfcvt(<vscale x 8 x bfloat> %zn0, <vscale x 8 x bfloat> %zn1) {
; CHECK-LABEL: bfcvt:
; CHECK: // %bb.0:
; CHECK-NEXT: // kill: def $z1 killed $z1 killed $z0_z1 def $z0_z1
; CHECK-NEXT: // kill: def $z0 killed $z0 killed $z0_z1 def $z0_z1
; CHECK-NEXT: bfcvt z0.b, { z0.h, z1.h }
; CHECK-NEXT: ret
%res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8bf16(<vscale x 8 x bfloat> %zn0, <vscale x 8 x bfloat> %zn1)
ret <vscale x 16 x i8> %res
}

; F1CVT / F2CVT

define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvt(<vscale x 16 x i8> %zm) {