Skip to content

Commit 7db2478

Browse files
committed
[NVPTX] Add conversion intrinsics from/to fp8 types (e4m3, e5m2)
1 parent 290f7ea commit 7db2478

File tree

6 files changed

+222
-0
lines changed

6 files changed

+222
-0
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,21 @@ TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))
584584

585585
TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
586586

587+
// FP8 conversion builtins (e4m3x2 / e5m2x2 forms). Every entry below is gated
// on AND(SM_89,PTX81).
// Two float arguments in, one 16-bit result out (i16 in the expected IR of the
// accompanying CodeGen test).
TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn, "sff", "", AND(SM_89,PTX81))
588+
TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn_relu, "sff", "", AND(SM_89,PTX81))
589+
TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn, "sff", "", AND(SM_89,PTX81))
590+
TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn_relu, "sff", "", AND(SM_89,PTX81))
591+
592+
// One <2 x half> argument in, one 16-bit result out.
TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn, "sV2h", "", AND(SM_89,PTX81))
593+
TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
594+
TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn, "sV2h", "", AND(SM_89,PTX81))
595+
TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn_relu, "sV2h", "", AND(SM_89,PTX81))
596+
597+
// One 16-bit argument in, one <2 x half> result out.
TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
598+
TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
599+
TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81))
600+
TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81))
601+
587602
// Bitcast
588603

589604
BUILTIN(__nvvm_bitcast_f2i, "if", "")

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \
2323
// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
2424
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP64 %s
25+
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 \
26+
// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \
27+
// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s
2528

2629
#define __device__ __attribute__((device))
2730
#define __global__ __attribute__((global))
@@ -968,6 +971,39 @@ __device__ void nvvm_cvt_sm80() {
968971
// CHECK: ret void
969972
}
970973

974+
// CHECK-LABEL: nvvm_cvt_sm89
975+
// Calls each fp8 conversion builtin once with constant operands. The calls are
// compiled only when __CUDA_ARCH__ >= 890; the expected IR for each call is
// matched under the extra check prefix passed by the sm_89 / ptx81 RUN line.
__device__ void nvvm_cvt_sm89() {
976+
#if __CUDA_ARCH__ >= 890
977+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float 1.000000e+00, float 1.000000e+00)
978+
__nvvm_ff_to_e4m3x2_rn(1.0f, 1.0f);
979+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
980+
__nvvm_ff_to_e4m3x2_rn_relu(1.0f, 1.0f);
981+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float 1.000000e+00, float 1.000000e+00)
982+
__nvvm_ff_to_e5m2x2_rn(1.0f, 1.0f);
983+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
984+
__nvvm_ff_to_e5m2x2_rn_relu(1.0f, 1.0f);
985+
986+
// Half-vector inputs: the expected IR shows each 1.0f16 element as half 0xH3C00.
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
987+
__nvvm_f16x2_to_e4m3x2_rn({1.0f16, 1.0f16});
988+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
989+
__nvvm_f16x2_to_e4m3x2_rn_relu({1.0f16, 1.0f16});
990+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>)
991+
__nvvm_f16x2_to_e5m2x2_rn({1.0f16, 1.0f16});
992+
// CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>)
993+
__nvvm_f16x2_to_e5m2x2_rn_relu({1.0f16, 1.0f16});
994+
995+
// The i16 immediates in the expected IR are the decimal forms of the hex
// arguments: 0x4848 == 18504 and 0x4c4c == 19532.
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 18504)
996+
__nvvm_e4m3x2_to_f16x2_rn(0x4848);
997+
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 18504)
998+
__nvvm_e4m3x2_to_f16x2_rn_relu(0x4848);
999+
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 19532)
1000+
__nvvm_e5m2x2_to_f16x2_rn(0x4c4c);
1001+
// CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 19532)
1002+
__nvvm_e5m2x2_to_f16x2_rn_relu(0x4c4c);
1003+
#endif
1004+
// CHECK: ret void
1005+
}
1006+
9711007
#define NAN32 0x7FBFFFFF
9721008
#define NAN16 (__bf16)0x7FBF
9731009
#define BF16 (__bf16)0.1f

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,33 @@ let TargetPrefix = "nvvm" in {
12961296
def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
12971297
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
12981298

1299+
// FP8 conversion intrinsics. Each one is tied to the Clang builtin of the same
// name through ClangBuiltin and carries IntrNoMem and IntrNoCallback.
// (float, float) -> i16.
def int_nvvm_ff_to_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn">,
1300+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1301+
def int_nvvm_ff_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn_relu">,
1302+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1303+
def int_nvvm_ff_to_e5m2x2_rn : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn">,
1304+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1305+
def int_nvvm_ff_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn_relu">,
1306+
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1307+
1308+
// v2f16 -> i16.
def int_nvvm_f16x2_to_e4m3x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn">,
1309+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1310+
def int_nvvm_f16x2_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn_relu">,
1311+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1312+
def int_nvvm_f16x2_to_e5m2x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn">,
1313+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1314+
def int_nvvm_f16x2_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn_relu">,
1315+
Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>;
1316+
1317+
// i16 -> v2f16.
def int_nvvm_e4m3x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn">,
1318+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1319+
def int_nvvm_e4m3x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn_relu">,
1320+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1321+
def int_nvvm_e5m2x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn">,
1322+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1323+
def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
1324+
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1325+
12991326
//
13001327
// Bitcast
13011328
//

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,37 @@ let hasSideEffects = false in {
722722

723723
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Int32Regs>;
724724
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
725+
726+
// FP8 conversions.
727+
// Defines the two cvt instructions that write an <F8Name>x2 result into an
// Int16Regs destination:
//   _f32   - two Float32Regs sources.
//   _f16x2 - one Int32Regs source.
// The assembly string always contains ".satfinite"; the text before and after
// it is printed from the CvtMode operand (its "base" and "relu" modifiers).
// Both instructions require hasPTX<81> and hasSM<89>.
multiclass CVT_TO_F8X2<string F8Name> {
728+
def _f32 :
729+
NVPTXInst<(outs Int16Regs:$dst),
730+
(ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
731+
!strconcat("cvt${mode:base}.satfinite${mode:relu}.",
732+
F8Name, "x2.f32 \t$dst, $src1, $src2;"), []>,
733+
Requires<[hasPTX<81>, hasSM<89>]>;
734+
def _f16x2 :
735+
NVPTXInst<(outs Int16Regs:$dst),
736+
(ins Int32Regs:$src, CvtMode:$mode),
737+
!strconcat("cvt${mode:base}.satfinite${mode:relu}.",
738+
F8Name, "x2.f16x2 \t$dst, $src;"), []>,
739+
Requires<[hasPTX<81>, hasSM<89>]>;
740+
}
741+
742+
// The defm prefix is joined with the inner def names (_f32, _f16x2), giving
// the CVT_e4m3x2_f32 / CVT_e4m3x2_f16x2 and CVT_e5m2x2_f32 / CVT_e5m2x2_f16x2
// records that the intrinsic selection patterns reference.
defm CVT_e4m3x2 : CVT_TO_F8X2<"e4m3">;
743+
defm CVT_e5m2x2 : CVT_TO_F8X2<"e5m2">;
744+
745+
// Defines the cvt instruction that reads an <F8Name>x2 value from an Int16Regs
// source and writes an f16x2 result into an Int32Regs destination. Unlike the
// to-fp8 direction, the assembly string has no ".satfinite"; the CvtMode
// operand supplies the "base" and "relu" parts. Requires hasPTX<81> and
// hasSM<89>.
// The inner def is named "x2" with no leading underscore, so it is appended
// directly to whatever prefix the defm supplies.
multiclass CVT_FROM_F8X2<string F8Name> {
746+
def x2 :
747+
NVPTXInst<(outs Int32Regs:$dst),
748+
(ins Int16Regs:$src, CvtMode:$mode),
749+
!strconcat("cvt${mode:base}${mode:relu}.f16x2.",
750+
F8Name, "x2 \t$dst, $src;"), []>,
751+
Requires<[hasPTX<81>, hasSM<89>]>;
752+
}
753+
754+
// The multiclass's inner def is named "x2", so these prefixes expand to the
// records CVT_f16x2_e4m3x2 and CVT_f16x2_e5m2x2 - the names the intrinsic
// selection patterns reference.
defm CVT_f16x2_e4m3 : CVT_FROM_F8X2<"e4m3">;
755+
defm CVT_f16x2_e5m2 : CVT_FROM_F8X2<"e5m2">;
725756
}
726757

727758
//-----------------------------------

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,33 @@ def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a),
15241524
def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
15251525
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;
15261526

1527+
// Selection patterns for the fp8 conversion intrinsics. Each intrinsic maps to
// one CVT_* instruction; the plain "_rn" form passes CvtRN and the "_rn_relu"
// form passes CvtRN_RELU as the mode operand.
// Two f32 sources -> e4m3x2 / e5m2x2.
def : Pat<(int_nvvm_ff_to_e4m3x2_rn Float32Regs:$a, Float32Regs:$b),
1528+
(CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
1529+
def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu Float32Regs:$a, Float32Regs:$b),
1530+
(CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
1531+
def : Pat<(int_nvvm_ff_to_e5m2x2_rn Float32Regs:$a, Float32Regs:$b),
1532+
(CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
1533+
def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu Float32Regs:$a, Float32Regs:$b),
1534+
(CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
1535+
1536+
// f16x2 source (matched as Int32Regs) -> e4m3x2 / e5m2x2.
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn Int32Regs:$a),
1537+
(CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN)>;
1538+
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu Int32Regs:$a),
1539+
(CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
1540+
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn Int32Regs:$a),
1541+
(CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN)>;
1542+
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu Int32Regs:$a),
1543+
(CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN_RELU)>;
1544+
1545+
// e4m3x2 / e5m2x2 source (matched as Int16Regs) -> f16x2.
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn Int16Regs:$a),
1546+
(CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN)>;
1547+
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu Int16Regs:$a),
1548+
(CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN_RELU)>;
1549+
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
1550+
(CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN)>;
1551+
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
1552+
(CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN_RELU)>;
1553+
15271554
//
15281555
// Bitcast
15291556
//
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | FileCheck %s
2+
; RUN: %if ptxas-12.1 %{ llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | %ptxas-verify -arch=sm_89 %}
3+
4+
; CHECK-LABEL: cvt_rn_e4m3x2_f32
5+
; Two float arguments through llvm.nvvm.ff.to.e4m3x2.rn; the llc output must
; contain the satfinite e4m3x2-from-f32 cvt without a relu modifier.
define i16 @cvt_rn_e4m3x2_f32(float %f1, float %f2) {
6+
; CHECK: cvt.rn.satfinite.e4m3x2.f32
7+
%val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %f1, float %f2);
8+
ret i16 %val
9+
}
10+
11+
; CHECK-LABEL: cvt_rn_relu_e4m3x2_f32
12+
; Relu variant of the f32 -> e4m3x2 conversion; the expected instruction
; carries ".relu" after ".satfinite".
define i16 @cvt_rn_relu_e4m3x2_f32(float %f1, float %f2) {
13+
; CHECK: cvt.rn.satfinite.relu.e4m3x2.f32
14+
%val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %f1, float %f2);
15+
ret i16 %val
16+
}
17+
18+
; CHECK-LABEL: cvt_rn_e5m2x2_f32
19+
; Two float arguments through llvm.nvvm.ff.to.e5m2x2.rn; the llc output must
; contain the satfinite e5m2x2-from-f32 cvt without a relu modifier.
define i16 @cvt_rn_e5m2x2_f32(float %f1, float %f2) {
20+
; CHECK: cvt.rn.satfinite.e5m2x2.f32
21+
%val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %f1, float %f2);
22+
ret i16 %val
23+
}
24+
25+
; CHECK-LABEL: cvt_rn_relu_e5m2x2_f32
26+
; Relu variant of the f32 -> e5m2x2 conversion; the expected instruction
; carries ".relu" after ".satfinite".
define i16 @cvt_rn_relu_e5m2x2_f32(float %f1, float %f2) {
27+
; CHECK: cvt.rn.satfinite.relu.e5m2x2.f32
28+
%val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %f1, float %f2);
29+
ret i16 %val
30+
}
31+
32+
; CHECK-LABEL: cvt_rn_e4m3x2_f16x2
33+
; One <2 x half> argument through llvm.nvvm.f16x2.to.e4m3x2.rn; the llc output
; must contain the satfinite e4m3x2-from-f16x2 cvt without a relu modifier.
define i16 @cvt_rn_e4m3x2_f16x2(<2 x half> %in) {
34+
; CHECK: cvt.rn.satfinite.e4m3x2.f16x2
35+
%val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %in);
36+
ret i16 %val
37+
}
38+
39+
; CHECK-LABEL: cvt_rn_relu_e4m3x2_f16x2
40+
; Relu variant of the f16x2 -> e4m3x2 conversion; the expected instruction
; carries ".relu" after ".satfinite".
define i16 @cvt_rn_relu_e4m3x2_f16x2(<2 x half> %in) {
41+
; CHECK: cvt.rn.satfinite.relu.e4m3x2.f16x2
42+
%val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %in);
43+
ret i16 %val
44+
}
45+
46+
; CHECK-LABEL: cvt_rn_e5m2x2_f16x2
47+
; One <2 x half> argument through llvm.nvvm.f16x2.to.e5m2x2.rn; the llc output
; must contain the satfinite e5m2x2-from-f16x2 cvt without a relu modifier.
define i16 @cvt_rn_e5m2x2_f16x2(<2 x half> %in) {
48+
; CHECK: cvt.rn.satfinite.e5m2x2.f16x2
49+
%val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %in);
50+
ret i16 %val
51+
}
52+
53+
; CHECK-LABEL: cvt_rn_relu_e5m2x2_f16x2
54+
; Relu variant of the f16x2 -> e5m2x2 conversion; the expected instruction
; carries ".relu" after ".satfinite".
define i16 @cvt_rn_relu_e5m2x2_f16x2(<2 x half> %in) {
55+
; CHECK: cvt.rn.satfinite.relu.e5m2x2.f16x2
56+
%val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %in);
57+
ret i16 %val
58+
}
59+
60+
; CHECK-LABEL: cvt_rn_f16x2_e4m3x2
61+
; One i16 argument through llvm.nvvm.e4m3x2.to.f16x2.rn; the expected
; instruction has no ".satfinite" and no relu modifier.
define <2 x half> @cvt_rn_f16x2_e4m3x2(i16 %in) {
62+
; CHECK: cvt.rn.f16x2.e4m3x2
63+
%val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %in);
64+
ret <2 x half> %val
65+
}
66+
67+
; CHECK-LABEL: cvt_rn_relu_f16x2_e4m3x2
68+
; Relu variant of the e4m3x2 -> f16x2 conversion; the expected instruction
; carries ".relu" directly after the rounding mode.
define <2 x half> @cvt_rn_relu_f16x2_e4m3x2(i16 %in) {
69+
; CHECK: cvt.rn.relu.f16x2.e4m3x2
70+
%val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %in);
71+
ret <2 x half> %val
72+
}
73+
74+
; CHECK-LABEL: cvt_rn_f16x2_e5m2x2
75+
; One i16 argument through llvm.nvvm.e5m2x2.to.f16x2.rn; the expected
; instruction has no ".satfinite" and no relu modifier.
define <2 x half> @cvt_rn_f16x2_e5m2x2(i16 %in) {
76+
; CHECK: cvt.rn.f16x2.e5m2x2
77+
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %in);
78+
ret <2 x half> %val
79+
}
80+
81+
; CHECK-LABEL: cvt_rn_relu_f16x2_e5m2x2
82+
; Relu variant of the e5m2x2 -> f16x2 conversion; the expected instruction
; carries ".relu" directly after the rounding mode.
define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
83+
; CHECK: cvt.rn.relu.f16x2.e5m2x2
84+
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
85+
ret <2 x half> %val
86+
}

0 commit comments

Comments
 (0)