[AArch64] Implement intrinsics for SME FP8 F1CVT/F2CVT and BF1CVT/BF2CVT #118027

Merged 2 commits on Dec 8, 2024

SpencerAbson

This patch implements the following intrinsics:

8-bit floating-point convert to half-precision or BFloat16 (in-order).

  // Variant is also available for: _bf16[_mf8]_x2
  svfloat16x2_t svcvt1_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming;
  svfloat16x2_t svcvt2_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming;

In accordance with ARM-software/acle#323.

Co-authored-by: Marin Lukac
Co-authored-by: Caroline Concatto
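
For readers unfamiliar with the "in-order" wording above, here is a minimal scalar sketch of how I read the element placement, contrasted with the deinterleaved placement of the existing svcvtl1/svcvtl2 intrinsics. It is an illustration under stated assumptions, not the instruction's definition: the helper names (pow2i, fp8_e4m3_to_float, cvt_in_order, cvt_deinterleaved) are hypothetical, the element count n stands in for the scalable vector length, results are stored as float rather than half/bfloat, the source format is fixed to OCP E4M3, and the format selection and scaling that the real intrinsics take from the fpm argument are not modelled.

```c
#include <assert.h>
#include <math.h>    /* only for the NAN macro; no libm functions are called */
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: 2^e for small integer e, avoiding a libm dependency. */
static float pow2i(int e) {
    float r = 1.0f;
    for (; e > 0; --e) r *= 2.0f;
    for (; e < 0; ++e) r *= 0.5f;
    return r;
}

/* Hypothetical helper: decode one OCP FP8 E4M3 byte.
   Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
   E4M3 has no infinities; only S.1111.111 encodes NaN. */
static float fp8_e4m3_to_float(uint8_t b) {
    int sign = (b >> 7) & 1;
    int exp  = (b >> 3) & 0xF;
    int man  = b & 0x7;
    float v;
    if (exp == 0xF && man == 0x7)
        v = NAN;
    else if (exp == 0)
        v = (float)man * pow2i(-9);                      /* subnormal: (man/8) * 2^-6 */
    else
        v = (1.0f + (float)man / 8.0f) * pow2i(exp - 7); /* normal */
    return sign ? -v : v;
}

/* Assumed in-order placement: reading lo then hi gives the source order. */
void cvt_in_order(const uint8_t *zn, size_t n, float *lo, float *hi) {
    assert(n % 2 == 0);
    size_t half = n / 2;
    for (size_t i = 0; i < half; ++i) {
        lo[i] = fp8_e4m3_to_float(zn[i]);
        hi[i] = fp8_e4m3_to_float(zn[half + i]);
    }
}

/* Assumed deinterleaved placement: even source elements go to lo, odd to hi. */
void cvt_deinterleaved(const uint8_t *zn, size_t n, float *lo, float *hi) {
    assert(n % 2 == 0);
    for (size_t i = 0; i < n / 2; ++i) {
        lo[i] = fp8_e4m3_to_float(zn[2 * i]);
        hi[i] = fp8_e4m3_to_float(zn[2 * i + 1]);
    }
}
```

With a four-element input encoding 1.0, 2.0, 1.5, -2.0, the in-order model yields lo = {1.0, 2.0} and hi = {1.5, -2.0}, while the deinterleaved model yields lo = {1.0, 1.5} and hi = {2.0, -2.0}.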

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:AArch64 clang:frontend Language frontend issues, e.g. anything involving "Sema" llvm:ir labels Nov 28, 2024
llvmbot commented Nov 28, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-aarch64

Author: None (SpencerAbson)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/118027.diff

6 Files Affected:

  • (modified) clang/include/clang/Basic/arm_sve.td (+4)
  • (modified) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c (+64)
  • (modified) clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c (+9)
  • (modified) llvm/include/llvm/IR/IntrinsicsAArch64.td (+7)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp (+12)
  • (modified) llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll (+40)
diff --git a/clang/include/clang/Basic/arm_sve.td b/clang/include/clang/Basic/arm_sve.td
index b36e592042da0b..3be5f08b07124a 100644
--- a/clang/include/clang/Basic/arm_sve.td
+++ b/clang/include/clang/Basic/arm_sve.td
@@ -2429,6 +2429,10 @@ let SVETargetGuard = InvalidMode, SMETargetGuard = "sme2,fp8" in {
   def FSCALE_X2 : Inst<"svscale[_{d}_x2]", "222.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x2", [IsStreaming],[]>;
   def FSCALE_X4 : Inst<"svscale[_{d}_x4]", "444.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x4", [IsStreaming],[]>;
 
+  // Convert from FP8 to half-precision/BFloat16 multi-vector
+  def SVF1CVT : Inst<"svcvt1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt1_x2", [IsStreaming, SetsFPMR], []>;
+  def SVF2CVT : Inst<"svcvt2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt2_x2", [IsStreaming, SetsFPMR], []>;
+
   // Convert from FP8 to deinterleaved half-precision/BFloat16 multi-vector
   def SVF1CVTL : Inst<"svcvtl1_{d}[_mf8]_x2_fpm",  "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl1_x2",  [IsStreaming, SetsFPMR], []>;
   def SVF2CVTL : Inst<"svcvtl2_{d}[_mf8]_x2_fpm",  "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl2_x2",  [IsStreaming, SetsFPMR], []>;
diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c
index 5ba76671ff5d5b..13609f034da336 100644
--- a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c
@@ -16,6 +16,70 @@
 #define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
 #endif
 
+// CHECK-LABEL: @test_cvt1_f16_x2(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
+// CHECK-NEXT:    ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
+//
+// CPP-CHECK-LABEL: @_Z16test_cvt1_f16_x2u13__SVMfloat8_tm(
+// CPP-CHECK-NEXT:  entry:
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CPP-CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
+// CPP-CHECK-NEXT:    ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
+//
+svfloat16x2_t test_cvt1_f16_x2(svmfloat8_t zn, fpm_t fpmr)  __arm_streaming {
+  return SVE_ACLE_FUNC(svcvt1_f16,_mf8,_x2_fpm)(zn, fpmr);
+}
+
+// CHECK-LABEL: @test_cvt2_f16_x2(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
+// CHECK-NEXT:    ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
+//
+// CPP-CHECK-LABEL: @_Z16test_cvt2_f16_x2u13__SVMfloat8_tm(
+// CPP-CHECK-NEXT:  entry:
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CPP-CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
+// CPP-CHECK-NEXT:    ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
+//
+svfloat16x2_t test_cvt2_f16_x2(svmfloat8_t zn, fpm_t fpmr)  __arm_streaming {
+  return SVE_ACLE_FUNC(svcvt2_f16,_mf8,_x2_fpm)(zn, fpmr);
+}
+
+// CHECK-LABEL: @test_cvt1_bf16_x2(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
+// CHECK-NEXT:    ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
+//
+// CPP-CHECK-LABEL: @_Z17test_cvt1_bf16_x2u13__SVMfloat8_tm(
+// CPP-CHECK-NEXT:  entry:
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CPP-CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
+// CPP-CHECK-NEXT:    ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
+//
+svbfloat16x2_t test_cvt1_bf16_x2(svmfloat8_t zn, fpm_t fpmr)  __arm_streaming {
+  return SVE_ACLE_FUNC(svcvt1_bf16,_mf8,_x2_fpm)(zn, fpmr);
+}
+
+// CHECK-LABEL: @test_cvt2_bf16_x2(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
+// CHECK-NEXT:    ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
+//
+// CPP-CHECK-LABEL: @_Z17test_cvt2_bf16_x2u13__SVMfloat8_tm(
+// CPP-CHECK-NEXT:  entry:
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
+// CPP-CHECK-NEXT:    [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
+// CPP-CHECK-NEXT:    ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
+//
+svbfloat16x2_t test_cvt2_bf16_x2(svmfloat8_t zn, fpm_t fpmr)  __arm_streaming {
+  return SVE_ACLE_FUNC(svcvt2_bf16,_mf8,_x2_fpm)(zn, fpmr);
+}
+
 // CHECK-LABEL: @test_cvtl1_f16_x2(
 // CHECK-NEXT:  entry:
 // CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c
index 09a80c9dff03ea..af1ef46ea69722 100644
--- a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c
@@ -14,4 +14,13 @@ void test_features_sme2_fp8(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
     svcvtl1_bf16_mf8_x2_fpm(zn, fpmr);
     // expected-error@+1 {{'svcvtl2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
     svcvtl2_bf16_mf8_x2_fpm(zn, fpmr);
+
+    // expected-error@+1 {{'svcvt1_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
+    svcvt1_f16_mf8_x2_fpm(zn, fpmr);
+    // expected-error@+1 {{'svcvt2_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
+    svcvt2_f16_mf8_x2_fpm(zn, fpmr);
+    // expected-error@+1 {{'svcvt1_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
+    svcvt1_bf16_mf8_x2_fpm(zn, fpmr);
+    // expected-error@+1 {{'svcvt2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
+    svcvt2_bf16_mf8_x2_fpm(zn, fpmr);
 }
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index a91616b9556828..d880f434b0a944 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3817,6 +3817,13 @@ let TargetPrefix = "aarch64" in {
     : DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
                             [llvm_nxv16i8_ty],
                             [IntrReadMem, IntrInaccessibleMemOnly]>;
+
+  //
+  // CVT from FP8 to half-precision/BFloat16 multi-vector
+  //
+  def int_aarch64_sve_fp8_cvt1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
+  def int_aarch64_sve_fp8_cvt2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
+
   //
   // CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector
   //
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 5f0c3d2c21f791..5df61b37220373 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -5581,6 +5581,18 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
               {AArch64::BF2CVTL_2ZZ_BtoH, AArch64::F2CVTL_2ZZ_BtoH}))
         SelectCVTIntrinsicFP8(Node, 2, Opc);
       return;
+    case Intrinsic::aarch64_sve_fp8_cvt1_x2:
+      if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
+              Node->getValueType(0),
+              {AArch64::BF1CVT_2ZZ_BtoH, AArch64::F1CVT_2ZZ_BtoH}))
+        SelectCVTIntrinsicFP8(Node, 2, Opc);
+      return;
+    case Intrinsic::aarch64_sve_fp8_cvt2_x2:
+      if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
+              Node->getValueType(0),
+              {AArch64::BF2CVT_2ZZ_BtoH, AArch64::F2CVT_2ZZ_BtoH}))
+        SelectCVTIntrinsicFP8(Node, 2, Opc);
+      return;
     }
   } break;
   case ISD::INTRINSIC_WO_CHAIN: {
diff --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
index 076a3ad34eac3c..3d3fcb05f6cf07 100644
--- a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
@@ -1,6 +1,46 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s
 
+; F1CVT / F2CVT
+
+define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvt(<vscale x 16 x i8> %zm) {
+; CHECK-LABEL: f1cvt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    f1cvt { z0.h, z1.h }, z0.b
+; CHECK-NEXT:    ret
+  %res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> %zm)
+  ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
+}
+
+define { <vscale x 8 x half>, <vscale x 8 x half> } @f2cvt(<vscale x 16 x i8> %zm) {
+; CHECK-LABEL: f2cvt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    f2cvt { z0.h, z1.h }, z0.b
+; CHECK-NEXT:    ret
+  %res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> %zm)
+  ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
+}
+
+; BF1CVT / BF2CVT
+
+define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf1cvt(<vscale x 16 x i8> %zm) {
+; CHECK-LABEL: bf1cvt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    bf1cvt { z0.h, z1.h }, z0.b
+; CHECK-NEXT:    ret
+  %res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> %zm)
+  ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
+}
+
+define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf2cvt(<vscale x 16 x i8> %zm) {
+; CHECK-LABEL: bf2cvt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    bf2cvt { z0.h, z1.h }, z0.b
+; CHECK-NEXT:    ret
+  %res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> %zm)
+  ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
+}
+
 ; F1CVTL / F2CVTL
 
 define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvtl(<vscale x 16 x i8> %zm) {

@jthackray left a comment

LGTM

@CarolineConcatto left a comment

LGTM

@SpencerAbson SpencerAbson merged commit b0f0676 into llvm:main Dec 8, 2024
8 checks passed