Skip to content

[AArch64] Implement intrinsics for SME FP8 F1CVT/F2CVT and BF1CVT/BF2CVT #118027

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/arm_sve.td
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,10 @@ let SVETargetGuard = InvalidMode, SMETargetGuard = "sme2,fp8" in {
def FSCALE_X2 : Inst<"svscale[_{d}_x2]", "222.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x2", [IsStreaming],[]>;
def FSCALE_X4 : Inst<"svscale[_{d}_x4]", "444.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x4", [IsStreaming],[]>;

// Convert from FP8 to half-precision/BFloat16 multi-vector
def SVF1CVT : Inst<"svcvt1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt1_x2", [IsStreaming, SetsFPMR], []>;
def SVF2CVT : Inst<"svcvt2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt2_x2", [IsStreaming, SetsFPMR], []>;

// Convert from FP8 to deinterleaved half-precision/BFloat16 multi-vector
def SVF1CVTL : Inst<"svcvtl1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl1_x2", [IsStreaming, SetsFPMR], []>;
def SVF2CVTL : Inst<"svcvtl2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl2_x2", [IsStreaming, SetsFPMR], []>;
Expand Down
64 changes: 64 additions & 0 deletions clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,70 @@
#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
#endif

// CHECK-LABEL: @test_cvt1_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z16test_cvt1_f16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
svfloat16x2_t test_cvt1_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt1_f16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt2_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z16test_cvt2_f16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
svfloat16x2_t test_cvt2_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt2_f16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt1_bf16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z17test_cvt1_bf16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
svbfloat16x2_t test_cvt1_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt1_bf16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt2_bf16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z17test_cvt2_bf16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
svbfloat16x2_t test_cvt2_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt2_bf16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvtl1_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
Expand Down
9 changes: 9 additions & 0 deletions clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,13 @@ void test_features_sme2_fp8(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
svcvtl1_bf16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvtl2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvtl2_bf16_mf8_x2_fpm(zn, fpmr);

// expected-error@+1 {{'svcvt1_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt1_f16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvt2_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt2_f16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvt1_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt1_bf16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvt2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt2_bf16_mf8_x2_fpm(zn, fpmr);
}
32 changes: 22 additions & 10 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -3812,16 +3812,6 @@ let TargetPrefix = "aarch64" in {
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>,
LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>],
[IntrNoMem]>;

class SME2_FP8_CVT_X2_Single_Intrinsic
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
[llvm_nxv16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;
//
// CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector
//
def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
}

// SVE2.1 - ZIPQ1, ZIPQ2, UZPQ1, UZPQ2
Expand Down Expand Up @@ -3864,3 +3854,25 @@ def int_aarch64_sve_famin_u : AdvSIMD_Pred2VectorArg_Intrinsic;
// Neon absolute maximum and minimum
def int_aarch64_neon_famax : AdvSIMD_2VectorArg_Intrinsic;
def int_aarch64_neon_famin : AdvSIMD_2VectorArg_Intrinsic;

//
// FP8 Intrinsics
//
let TargetPrefix = "aarch64" in {

class SME2_FP8_CVT_X2_Single_Intrinsic
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
[llvm_nxv16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;
//
// CVT from FP8 to half-precision/BFloat16 multi-vector
//
def int_aarch64_sve_fp8_cvt1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sve_fp8_cvt2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;

//
// CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector
//
def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
}
12 changes: 12 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5581,6 +5581,18 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
{AArch64::BF2CVTL_2ZZ_BtoH, AArch64::F2CVTL_2ZZ_BtoH}))
SelectCVTIntrinsicFP8(Node, 2, Opc);
return;
case Intrinsic::aarch64_sve_fp8_cvt1_x2:
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
Node->getValueType(0),
{AArch64::BF1CVT_2ZZ_BtoH, AArch64::F1CVT_2ZZ_BtoH}))
SelectCVTIntrinsicFP8(Node, 2, Opc);
return;
case Intrinsic::aarch64_sve_fp8_cvt2_x2:
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
Node->getValueType(0),
{AArch64::BF2CVT_2ZZ_BtoH, AArch64::F2CVT_2ZZ_BtoH}))
SelectCVTIntrinsicFP8(Node, 2, Opc);
return;
}
} break;
case ISD::INTRINSIC_WO_CHAIN: {
Expand Down
40 changes: 40 additions & 0 deletions llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s

; F1CVT / F2CVT

define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: f1cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: f1cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
}

define { <vscale x 8 x half>, <vscale x 8 x half> } @f2cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: f2cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: f2cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
}

; BF1CVT / BF2CVT

define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf1cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: bf1cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: bf1cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
}

define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf2cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: bf2cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: bf2cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
}

; F1CVTL / F2CVTL

define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvtl(<vscale x 16 x i8> %zm) {
Expand Down
Loading