Skip to content

[NVPTX] Add conversion intrinsics from/to fp8 types (e4m3, e5m2) #102969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.def
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,21 @@ TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn, "sff", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn_relu, "sff", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn, "sff", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn_relu, "sff", "", AND(SM_89,PTX81))

TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn, "sV2h", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn, "sV2h", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))

TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))

// Bitcast

BUILTIN(__nvvm_bitcast_f2i, "if", "")
Expand Down
36 changes: 36 additions & 0 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \
// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP64 %s
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 \
// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s

#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
Expand Down Expand Up @@ -968,6 +971,39 @@ __device__ void nvvm_cvt_sm80() {
// CHECK: ret void
}

// CHECK-LABEL: nvvm_cvt_sm89
__device__ void nvvm_cvt_sm89() {
#if __CUDA_ARCH__ >= 890
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff_to_e4m3x2_rn(1.0f, 1.0f);
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff_to_e4m3x2_rn_relu(1.0f, 1.0f);
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff_to_e5m2x2_rn(1.0f, 1.0f);
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff_to_e5m2x2_rn_relu(1.0f, 1.0f);

// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
__nvvm_f16x2_to_e4m3x2_rn({1.0f16, 1.0f16});
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
__nvvm_f16x2_to_e4m3x2_rn_relu({1.0f16, 1.0f16});
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
__nvvm_f16x2_to_e5m2x2_rn({1.0f16, 1.0f16});
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
__nvvm_f16x2_to_e5m2x2_rn_relu({1.0f16, 1.0f16});

// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 18504)
__nvvm_e4m3x2_to_f16x2_rn(0x4848);
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 18504)
__nvvm_e4m3x2_to_f16x2_rn_relu(0x4848);
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 19532)
__nvvm_e5m2x2_to_f16x2_rn(0x4c4c);
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 19532)
__nvvm_e5m2x2_to_f16x2_rn_relu(0x4c4c);
#endif
// CHECK: ret void
}

#define NAN32 0x7FBFFFFF
#define NAN16 (__bf16)0x7FBF
#define BF16 (__bf16)0.1f
Expand Down
27 changes: 27 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,33 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;

def int_nvvm_ff_to_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn_relu">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff_to_e5m2x2_rn : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn_relu">,
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;

def int_nvvm_f16x2_to_e4m3x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn">,
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f16x2_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn_relu">,
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f16x2_to_e5m2x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn">,
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f16x2_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn_relu">,
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;

def int_nvvm_e4m3x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn">,
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_e4m3x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn_relu">,
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_e5m2x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn">,
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;

//
// Bitcast
//
Expand Down
29 changes: 29 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,35 @@ let hasSideEffects = false in {

defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Int32Regs>;
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;

// FP8 conversions.
multiclass CVT_TO_F8X2<string F8Name> {
def _f32 :
NVPTXInst<(outs Int16Regs:$dst),
(ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
!strconcat("cvt${mode:base}.satfinite${mode:relu}.",
F8Name, "x2.f32 \t$dst, $src1, $src2;"), []>,
Requires<[hasPTX<81>, hasSM<89>]>;
def _f16x2 :
NVPTXInst<(outs Int16Regs:$dst),
(ins Int32Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}.satfinite${mode:relu}.",
F8Name, "x2.f16x2 \t$dst, $src;"), []>,
Requires<[hasPTX<81>, hasSM<89>]>;
}

defm CVT_e4m3x2 : CVT_TO_F8X2<"e4m3">;
defm CVT_e5m2x2 : CVT_TO_F8X2<"e5m2">;

class CVT_f16x2_fp8<string F8Name> :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:relu}.f16x2.",
F8Name, "x2 \t$dst, $src;"), []>,
Requires<[hasPTX<81>, hasSM<89>]>;

def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
}

//-----------------------------------
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,33 @@ def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a),
def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;

def : Pat<(int_nvvm_ff_to_e4m3x2_rn Float32Regs:$a, Float32Regs:$b),
(CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu Float32Regs:$a, Float32Regs:$b),
(CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff_to_e5m2x2_rn Float32Regs:$a, Float32Regs:$b),
(CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu Float32Regs:$a, Float32Regs:$b),
(CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;

def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn Int32Regs:$a),
(CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN)>;
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu Int32Regs:$a),
(CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn Int32Regs:$a),
(CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN)>;
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu Int32Regs:$a),
(CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;

def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn Int16Regs:$a),
(CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN)>;
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu Int16Regs:$a),
(CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN_RELU)>;
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
(CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN)>;
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
(CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN_RELU)>;

//
// Bitcast
//
Expand Down
86 changes: 86 additions & 0 deletions llvm/test/CodeGen/NVPTX/convert-sm89.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
; RUN: llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | FileCheck %s
; RUN: %if ptxas-12.1 %{ llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | %ptxas-verify -arch=sm_89 %}

; CHECK-LABEL: cvt_rn_e4m3x2_f32
define i16 @cvt_rn_e4m3x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.satfinite.e4m3x2.f32
%val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %f1, float %f2);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_relu_e4m3x2_f32
define i16 @cvt_rn_relu_e4m3x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.satfinite.relu.e4m3x2.f32
%val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %f1, float %f2);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_e5m2x2_f32
define i16 @cvt_rn_e5m2x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.satfinite.e5m2x2.f32
%val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %f1, float %f2);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_relu_e5m2x2_f32
define i16 @cvt_rn_relu_e5m2x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.satfinite.relu.e5m2x2.f32
%val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %f1, float %f2);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_e4m3x2_f16x2
define i16 @cvt_rn_e4m3x2_f16x2(<2 x half> %in) {
; CHECK: cvt.rn.satfinite.e4m3x2.f16x2
%val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %in);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_relu_e4m3x2_f16x2
define i16 @cvt_rn_relu_e4m3x2_f16x2(<2 x half> %in) {
; CHECK: cvt.rn.satfinite.relu.e4m3x2.f16x2
%val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %in);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_e5m2x2_f16x2
define i16 @cvt_rn_e5m2x2_f16x2(<2 x half> %in) {
; CHECK: cvt.rn.satfinite.e5m2x2.f16x2
%val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %in);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_relu_e5m2x2_f16x2
define i16 @cvt_rn_relu_e5m2x2_f16x2(<2 x half> %in) {
; CHECK: cvt.rn.satfinite.relu.e5m2x2.f16x2
%val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %in);
ret i16 %val
}

; CHECK-LABEL: cvt_rn_f16x2_e4m3x2
define <2 x half> @cvt_rn_f16x2_e4m3x2(i16 %in) {
; CHECK: cvt.rn.f16x2.e4m3x2
%val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %in);
ret <2 x half> %val
}

; CHECK-LABEL: cvt_rn_relu_f16x2_e4m3x2
define <2 x half> @cvt_rn_relu_f16x2_e4m3x2(i16 %in) {
; CHECK: cvt.rn.relu.f16x2.e4m3x2
%val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %in);
ret <2 x half> %val
}

; CHECK-LABEL: cvt_rn_f16x2_e5m2x2
define <2 x half> @cvt_rn_f16x2_e5m2x2(i16 %in) {
; CHECK: cvt.rn.f16x2.e5m2x2
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %in);
ret <2 x half> %val
}

; CHECK-LABEL: cvt_rn_relu_f16x2_e5m2x2
define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
; CHECK: cvt.rn.relu.f16x2.e5m2x2
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
ret <2 x half> %val
}
Loading