Skip to content

Commit bf34c2e

Browse files
committed
[NVPTX] Add convert float to tf32 intrinsics
This patch adds the missing variants of float to tf32 conversion intrinsics. Lit tests are added for all the intrinsics. PTX Spec link: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt Signed-off-by: Durgadoss R <[email protected]>
1 parent 81ae668 commit bf34c2e

File tree

7 files changed

+119
-5
lines changed

7 files changed

+119
-5
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,16 @@ let TargetPrefix = "nvvm" in {
14381438

14391439
// Float-to-tf32 conversion intrinsics. Every variant takes one float operand,
// returns its result in an i32, and is marked IntrNoMem + IntrNoCallback.
// The name suffix (rna / rn / rz, optionally followed by satfinite or relu)
// identifies the variant and is repeated in the ClangBuiltin name.
def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
14401440
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1441+
def int_nvvm_f2tf32_rna_satfinite : ClangBuiltin<"__nvvm_f2tf32_rna_satfinite">,
1442+
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1443+
def int_nvvm_f2tf32_rn : ClangBuiltin<"__nvvm_f2tf32_rn">,
1444+
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1445+
def int_nvvm_f2tf32_rn_relu : ClangBuiltin<"__nvvm_f2tf32_rn_relu">,
1446+
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1447+
def int_nvvm_f2tf32_rz : ClangBuiltin<"__nvvm_f2tf32_rz">,
1448+
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1449+
def int_nvvm_f2tf32_rz_relu : ClangBuiltin<"__nvvm_f2tf32_rz_relu">,
1450+
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
14411451

14421452
def int_nvvm_ff_to_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn">,
14431453
Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
110110
if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
111111
O << ".sat";
112112
return;
113+
} else if (Modifier == "satfinite") {
114+
// SATFINITE flag
115+
if (Imm & NVPTX::PTXCvtMode::SATFINITE_FLAG)
116+
O << ".satfinite";
117+
return;
113118
} else if (Modifier == "relu") {
114119
// RELU flag
115120
if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ enum CvtMode {
182182
BASE_MASK = 0x0F,
183183
FTZ_FLAG = 0x10,
184184
SAT_FLAG = 0x20,
185-
RELU_FLAG = 0x40
185+
RELU_FLAG = 0x40,
186+
SATFINITE_FLAG = 0x80
186187
};
187188
}
188189

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def CvtNONE_RELU : PatLeaf<(i32 0x40)>;
6464
def CvtRN_RELU : PatLeaf<(i32 0x45)>;
6565
def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
6666

67+
// Conversion-mode immediates with the satfinite bit (0x80) set; the low
// nibble holds the base mode, so 0x80 is "no base mode" and 0x89 is base 0x9.
// NOTE(review): these literals presumably have to stay in sync with
// NVPTX::PTXCvtMode::SATFINITE_FLAG / BASE_MASK in NVPTX.h, which
// printCvtMode tests against the same immediate -- confirm when editing.
def CvtNONE_SATFINITE : PatLeaf<(i32 0x80)>;
68+
def CvtRNA_SATFINITE : PatLeaf<(i32 0x89)>;
69+
6770
def CvtMode : Operand<i32> {
6871
let PrintMethod = "printCvtMode";
6972
}
@@ -725,6 +728,12 @@ let hasSideEffects = false in {
725728

726729
def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
727730
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
731+
732+
// Float to TF32 conversions.
// A single instruction definition serves every variant: the $mode immediate
// is printed three times, through the "base", "relu" and "satfinite"
// modifiers, to build the opcode suffix ahead of ".tf32.f32".
// The pattern list is empty, so this def attaches no selection pattern itself.
733+
def CVT_tf32_f32 : NVPTXInst<(outs Int32Regs:$dst),
734+
(ins Float32Regs:$src, CvtMode:$mode),
735+
!strconcat("cvt${mode:base}${mode:relu}${mode:satfinite}.",
736+
"tf32.f32 \t$dst, $src;"), []>;
728737
}
729738

730739
def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,10 +1660,24 @@ def : Pat<(int_nvvm_f2bf16_rz f32:$a),
16601660
def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a),
16611661
(CVT_bf16_f32 $a, CvtRZ_RELU)>;
16621662

1663-
def CVT_tf32_f32 :
1664-
NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
1665-
"cvt.rna.tf32.f32 \t$dest, $a;",
1666-
[(set i32:$dest, (int_nvvm_f2tf32_rna f32:$a))]>;
1663+
// Select each f2tf32 intrinsic to CVT_tf32_f32, passing the Cvt* mode leaf
// that corresponds to the intrinsic's suffix. Each pattern carries its own
// Requires<> predicate list:
//   rna            -> hasPTX<70>, hasSM<80>
//   rna.satfinite  -> hasPTX<81>, hasSM<89>
//   rn / rz (with or without relu) -> hasPTX<78>, hasSM<90>
def : Pat<(int_nvvm_f2tf32_rna f32:$a),
1664+
(CVT_tf32_f32 $a, CvtRNA)>,
1665+
Requires<[hasPTX<70>, hasSM<80>]>;
1666+
def : Pat<(int_nvvm_f2tf32_rna_satfinite f32:$a),
1667+
(CVT_tf32_f32 $a, CvtRNA_SATFINITE)>,
1668+
Requires<[hasPTX<81>, hasSM<89>]>;
1669+
def : Pat<(int_nvvm_f2tf32_rn f32:$a),
1670+
(CVT_tf32_f32 $a, CvtRN)>,
1671+
Requires<[hasPTX<78>, hasSM<90>]>;
1672+
def : Pat<(int_nvvm_f2tf32_rn_relu f32:$a),
1673+
(CVT_tf32_f32 $a, CvtRN_RELU)>,
1674+
Requires<[hasPTX<78>, hasSM<90>]>;
1675+
def : Pat<(int_nvvm_f2tf32_rz f32:$a),
1676+
(CVT_tf32_f32 $a, CvtRZ)>,
1677+
Requires<[hasPTX<78>, hasSM<90>]>;
1678+
def : Pat<(int_nvvm_f2tf32_rz_relu f32:$a),
1679+
(CVT_tf32_f32 $a, CvtRZ_RELU)>,
1680+
Requires<[hasPTX<78>, hasSM<90>]>;
16671681

16681682
def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
16691683
Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;

llvm/test/CodeGen/NVPTX/convert-sm89.ll

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,10 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
8484
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
8585
ret <2 x half> %val
8686
}
87+
88+
; A call to llvm.nvvm.f2tf32.rna.satfinite is expected to produce a
; cvt.rna.satfinite.tf32.f32 instruction in the emitted PTX.
; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
89+
define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
90+
; CHECK: cvt.rna.satfinite.tf32.f32
91+
%val = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %f1)
92+
ret i32 %val
93+
}
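(File name missing from this capture; this hunk is a newly added test file, presumably llvm/test/CodeGen/NVPTX/convert-sm90.ll.)

Lines changed: 68 additions & 0 deletions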
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| FileCheck --check-prefixes=CHECK %s
3+
; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| %ptxas-verify -arch=sm_90 %}
4+
5+
; Declarations of the f2tf32 intrinsics called by the test functions in this
; file; each takes a float and returns an i32.
declare i32 @llvm.nvvm.f2tf32.rn(float %f1)
6+
declare i32 @llvm.nvvm.f2tf32.rn.relu(float %f1)
7+
declare i32 @llvm.nvvm.f2tf32.rz(float %f1)
8+
declare i32 @llvm.nvvm.f2tf32.rz.relu(float %f1)
9+
10+
; llvm.nvvm.f2tf32.rn: the expected PTX body is the parameter load, one
; cvt.rn.tf32.f32, the return-value store, and ret.
define i32 @cvt_rn_tf32_f32(float %f1) {
11+
; CHECK-LABEL: cvt_rn_tf32_f32(
12+
; CHECK: {
13+
; CHECK-NEXT: .reg .b32 %r<2>;
14+
; CHECK-NEXT: .reg .f32 %f<2>;
15+
; CHECK-EMPTY:
16+
; CHECK-NEXT: // %bb.0:
17+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
18+
; CHECK-NEXT: cvt.rn.tf32.f32 %r1, %f1;
19+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
20+
; CHECK-NEXT: ret;
21+
%val = call i32 @llvm.nvvm.f2tf32.rn(float %f1)
22+
ret i32 %val
23+
}
24+
25+
; llvm.nvvm.f2tf32.rn.relu: the expected PTX body is the parameter load, one
; cvt.rn.relu.tf32.f32, the return-value store, and ret.
define i32 @cvt_rn_relu_tf32_f32(float %f1) {
26+
; CHECK-LABEL: cvt_rn_relu_tf32_f32(
27+
; CHECK: {
28+
; CHECK-NEXT: .reg .b32 %r<2>;
29+
; CHECK-NEXT: .reg .f32 %f<2>;
30+
; CHECK-EMPTY:
31+
; CHECK-NEXT: // %bb.0:
32+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
33+
; CHECK-NEXT: cvt.rn.relu.tf32.f32 %r1, %f1;
34+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
35+
; CHECK-NEXT: ret;
36+
%val = call i32 @llvm.nvvm.f2tf32.rn.relu(float %f1)
37+
ret i32 %val
38+
}
39+
40+
; llvm.nvvm.f2tf32.rz: the expected PTX body is the parameter load, one
; cvt.rz.tf32.f32, the return-value store, and ret.
define i32 @cvt_rz_tf32_f32(float %f1) {
41+
; CHECK-LABEL: cvt_rz_tf32_f32(
42+
; CHECK: {
43+
; CHECK-NEXT: .reg .b32 %r<2>;
44+
; CHECK-NEXT: .reg .f32 %f<2>;
45+
; CHECK-EMPTY:
46+
; CHECK-NEXT: // %bb.0:
47+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
48+
; CHECK-NEXT: cvt.rz.tf32.f32 %r1, %f1;
49+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
50+
; CHECK-NEXT: ret;
51+
%val = call i32 @llvm.nvvm.f2tf32.rz(float %f1)
52+
ret i32 %val
53+
}
54+
55+
; llvm.nvvm.f2tf32.rz.relu: the expected PTX body is the parameter load, one
; cvt.rz.relu.tf32.f32, the return-value store, and ret.
define i32 @cvt_rz_relu_tf32_f32(float %f1) {
56+
; CHECK-LABEL: cvt_rz_relu_tf32_f32(
57+
; CHECK: {
58+
; CHECK-NEXT: .reg .b32 %r<2>;
59+
; CHECK-NEXT: .reg .f32 %f<2>;
60+
; CHECK-EMPTY:
61+
; CHECK-NEXT: // %bb.0:
62+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
63+
; CHECK-NEXT: cvt.rz.relu.tf32.f32 %r1, %f1;
64+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
65+
; CHECK-NEXT: ret;
66+
%val = call i32 @llvm.nvvm.f2tf32.rz.relu(float %f1)
67+
ret i32 %val
68+
}

0 commit comments

Comments
 (0)