Skip to content

Commit fcbecc6

Browse files
committed
[NVPTX] Add convert float to tf32 intrinsics
This patch adds an intrinsic to convert float to tf32. * This intrinsic uses flags for rounding and saturation modes as well as relu. The backend looks through these flags and lowers to the appropriate instruction. * Docs are updated to describe the usage of the flag arguments. * Lit tests are added for all the combinations. Note: We already have an intrinsic 'llvm.nvvm.f2tf32.rna' which caters only to one variant of the PTX instruction. Once this change lands, I will submit a follow-up PR to auto-upgrade it to use the generic variant. PTX Spec link: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt Signed-off-by: Durgadoss R <[email protected]>
1 parent 2b63077 commit fcbecc6

File tree

12 files changed

+281
-0
lines changed

12 files changed

+281
-0
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,56 @@ to left-shift the found bit into the most-significant bit position, otherwise
462462
the result is the shift amount needed to right-shift the found bit into the
463463
least-significant bit position. 0xffffffff is returned if no 1 bit is found.
464464

465+
Conversion Intrinsics (for cvt.* PTX instructions)
466+
--------------------------------------------------
467+
468+
'``llvm.nvvm.convert.to.tf32.f32``'
469+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
470+
471+
Syntax:
472+
"""""""
473+
474+
.. code-block:: llvm
475+
476+
declare i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata !round_mode, i8 %flag_sat_mode, i1 %flag_relu)
477+
478+
Overview:
479+
"""""""""
480+
481+
The '``@llvm.nvvm.convert.to.tf32.f32``' intrinsic lowers to
482+
the ``cvt.*.tf32.f32`` set of PTX instructions.
483+
484+
* The first argument is the input float to be converted to TF32.
485+
486+
* The second argument (denoted by ``metadata !round_mode``) specifies
487+
the floating-point rounding mode to use for the conversion.
488+
The metadata strings are the same as the ones used for constrained-fp
489+
intrinsics, documented here:
490+
`<https://llvm.org/docs/LangRef.html#constrainedfp>`_.
491+
492+
The valid rounding modes for this intrinsic are ``round.tonearest``,
493+
``round.towardzero`` and ``round.tonearestaway``.
494+
495+
* The third argument (denoted by ``i8 %flag_sat_mode``) denotes the
496+
saturation modifier for this intrinsic. As of now, it can either
497+
be None or Satfinite, according to the enumeration below:
498+
499+
========== ================
500+
Enum Value Saturation Mode
501+
========== ================
502+
``0`` NONE
503+
``1`` SATFINITE
504+
========== ================
505+
506+
* The last argument (denoted by ``i1 %flag_relu``) when set, generates
507+
the ``.relu`` variant of the instruction.
508+
509+
* Invalid values for the rounding and/or saturation modes may result in
510+
error(s) during Codegen.
511+
512+
For more information, refer to the PTX ISA
513+
`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt>`_.
514+
465515
TMA family of Intrinsics
466516
------------------------
467517

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,14 @@ let TargetPrefix = "nvvm" in {
14661466
def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
14671467
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
14681468

1469+
// Convert float to TF32
1470+
def int_nvvm_convert_to_tf32_f32 : DefaultAttrsIntrinsic<[llvm_i32_ty],
1471+
[llvm_float_ty, // Input float
1472+
llvm_metadata_ty, // Metadata for Rounding modes
1473+
llvm_i8_ty, // Flag for Saturation modes
1474+
llvm_i1_ty], // Flag for relu
1475+
[IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
1476+
14691477
// FNS
14701478

14711479
def int_nvvm_fns : ClangBuiltin<"__nvvm_fns">,

llvm/include/llvm/IR/NVVMIntrinsicFlags.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ enum class TMAReductionOp : uint8_t {
3434
XOR = 7,
3535
};
3636

37+
// Saturation modes. The numeric values are the encoding of the immediate
// saturation-mode flag that is forwarded to the instruction printer.
38+
enum class SaturationMode : uint8_t {
39+
NONE = 0, // No saturation modifier is printed.
40+
SATFINITE = 1, // Printed as the ".satfinite" modifier.
41+
};
42+
3743
} // namespace nvvm
3844
} // namespace llvm
3945
#endif // LLVM_IR_NVVMINTRINSICFLAGS_H

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,53 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
453453
llvm_unreachable(
454454
"Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
455455
}
456+
457+
void NVPTXInstPrinter::printFPRoundingMode(const MCInst *MI, int OpNum,
458+
raw_ostream &O,
459+
const char *Modifier) {
460+
const MCOperand &MO = MI->getOperand(OpNum);
461+
switch (static_cast<llvm::RoundingMode>(MO.getImm())) {
462+
case llvm::RoundingMode::NearestTiesToEven:
463+
O << ".rn";
464+
return;
465+
case llvm::RoundingMode::NearestTiesToAway:
466+
O << ".rna";
467+
return;
468+
case llvm::RoundingMode::TowardZero:
469+
O << ".rz";
470+
return;
471+
case llvm::RoundingMode::TowardPositive:
472+
O << ".rp";
473+
return;
474+
case llvm::RoundingMode::TowardNegative:
475+
O << ".rm";
476+
return;
477+
default:
478+
O << "";
479+
return;
480+
}
481+
}
482+
483+
void NVPTXInstPrinter::printSaturationMode(const MCInst *MI, int OpNum,
484+
raw_ostream &O,
485+
const char *Modifier) {
486+
const MCOperand &MO = MI->getOperand(OpNum);
487+
using Mode = nvvm::SaturationMode;
488+
489+
switch (static_cast<Mode>(MO.getImm())) {
490+
case Mode::NONE:
491+
O << "";
492+
return;
493+
case Mode::SATFINITE:
494+
O << ".satfinite";
495+
return;
496+
}
497+
llvm_unreachable("Invalid mode in printSaturationMode");
498+
}
499+
500+
void NVPTXInstPrinter::printReluModifier(const MCInst *MI, int OpNum,
501+
raw_ostream &O, const char *Modifier) {
502+
const MCOperand &MO = MI->getOperand(OpNum);
503+
if (MO.getImm())
504+
O << ".relu";
505+
}

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ class NVPTXInstPrinter : public MCInstPrinter {
5656
const char *Modifier = nullptr);
5757
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
5858
const char *Modifier = nullptr);
59+
void printFPRoundingMode(const MCInst *MI, int OpNum, raw_ostream &O,
60+
const char *Modifier = nullptr);
61+
void printSaturationMode(const MCInst *MI, int OpNum, raw_ostream &O,
62+
const char *Modifier = nullptr);
63+
void printReluModifier(const MCInst *MI, int OpNum, raw_ostream &O,
64+
const char *Modifier = nullptr);
5965
};
6066

6167
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,55 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
728728
case Intrinsic::nvvm_texsurf_handle_internal:
729729
SelectTexSurfHandle(N);
730730
return true;
731+
case Intrinsic::nvvm_convert_to_tf32_f32:
732+
SelectCvtFloatToTF32(N);
733+
return true;
734+
}
735+
}
736+
737+
/// Select the cvt_f32_to_tf32 machine instruction for a call to
/// llvm.nvvm.convert.to.tf32.f32.
///
/// Operands of \p N:
///   0 - IID
///   1 - Input float
///   2 - Rounding mode as string metadata
///   3 - Saturation mode
///   4 - Relu flag
///
/// Every unsupported subtarget, invalid flag value, or unsupported
/// combination of modifiers is diagnosed with report_fatal_error before the
/// node is replaced.
void NVPTXDAGToDAGISel::SelectCvtFloatToTF32(SDNode *N) {
  uint64_t Sat = N->getConstantOperandVal(3);
  bool IsRelu = N->getConstantOperandVal(4) == 1;

  if (!Subtarget->hasTF32Math())
    report_fatal_error("TF32 destination format requires at least sm80");

  using SatMode = nvvm::SaturationMode;
  // Reject encodings outside the enumeration here. Otherwise the raw value
  // is forwarded as an immediate and is only noticed by the instruction
  // printer, which treats it as unreachable.
  if (Sat > static_cast<uint64_t>(SatMode::SATFINITE))
    report_fatal_error("Invalid saturation mode in SelectCvtFloatToTF32");

  bool IsSatFinite = static_cast<SatMode>(Sat) == SatMode::SATFINITE;
  if (IsSatFinite && Subtarget->getPTXVersion() < 81)
    report_fatal_error("satfinite modifier requires PTX version 8.1 or higher");

  const MDNode *MD = cast<MDNodeSDNode>(N->getOperand(2))->getMD();
  auto RndString = cast<MDString>(MD->getOperand(0))->getString();
  std::optional<RoundingMode> RndVal = convertStrToRoundingMode(RndString);
  // The conversion result is optional; it must be checked before it is
  // dereferenced, since dereferencing an empty optional is undefined
  // behaviour.
  if (!RndVal)
    report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");

  switch (*RndVal) {
  case RoundingMode::NearestTiesToAway:
    if (IsRelu)
      report_fatal_error("relu not supported with rna rounding mode");
    break;
  case RoundingMode::NearestTiesToEven:
  case RoundingMode::TowardZero: {
    if (Subtarget->getSmVersion() < 90)
      report_fatal_error("rn/rz rounding modes require at least sm90");
    if (IsSatFinite)
      report_fatal_error("satfinite not supported with rn/rz rounding modes");
    break;
  }
  default:
    report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");
  }

  // Rounding mode, saturation mode and relu flag travel as i32 immediates;
  // the instruction printer turns them into the PTX modifiers.
  SDLoc DL(N);
  SDValue Ops[] = {N->getOperand(1),
                   getI32Imm(static_cast<unsigned>(*RndVal), DL),
                   getI32Imm(Sat, DL), getI32Imm(IsRelu, DL)};
  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::cvt_f32_to_tf32, DL,
                                        N->getVTList(), Ops));
}
733781

734782
void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7373
bool tryIntrinsicChain(SDNode *N);
7474
bool tryIntrinsicVoid(SDNode *N);
7575
void SelectTexSurfHandle(SDNode *N);
76+
void SelectCvtFloatToTF32(SDNode *N);
7677
bool tryLoad(SDNode *N);
7778
bool tryLoadVector(SDNode *N);
7879
bool tryLDGLDU(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,22 @@ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
18021802
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
18031803
(CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
18041804

1805+
def FPRoundingMode : Operand<i32> {
1806+
let PrintMethod = "printFPRoundingMode";
1807+
}
1808+
1809+
def SatMode : Operand<i32> {
1810+
let PrintMethod = "printSaturationMode";
1811+
}
1812+
1813+
def ReluFlag : Operand<i32> {
1814+
let PrintMethod = "printReluModifier";
1815+
}
1816+
1817+
def cvt_f32_to_tf32 : NVPTXInst<(outs Int32Regs:$dest),
1818+
(ins Float32Regs:$a, FPRoundingMode:$rnd, SatMode:$sat, ReluFlag:$relu),
1819+
"cvt${rnd:rnd}${sat:sat}${relu:relu}.tf32.f32 \t$dest, $a;", []>;
1820+
18051821
//
18061822
// FNS
18071823
//

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
8383
bool hasFP16Math() const { return SmVersion >= 53; }
8484
bool hasBF16Math() const { return SmVersion >= 80; }
8585
bool allowFP16Math() const;
86+
// TF32 support needs both PTX version >= 70 and SM version >= 80.
bool hasTF32Math() const { return PTXVersion >= 70 && SmVersion >= 80; }
8687
bool hasMaskOperator() const { return PTXVersion >= 71; }
8788
bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
8889
// Does SM & PTX support memory orderings (weak and atomic: relaxed, acquire,

llvm/test/CodeGen/NVPTX/convert-sm80.ll

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,21 @@ define <2 x half> @fold_ff2f16x2(float %lo, float %hi) {
261261
%v1 = insertelement <2 x half> %v0, half %hih, i64 1
262262
ret <2 x half> %v1
263263
}
264+
265+
declare i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata, i8, i1)
266+
267+
define i32 @cvt_rna_tf32_f32_flags(float %f1) {
268+
; CHECK-LABEL: cvt_rna_tf32_f32_flags(
269+
; CHECK: {
270+
; CHECK-NEXT: .reg .b32 %r<2>;
271+
; CHECK-NEXT: .reg .f32 %f<2>;
272+
; CHECK-EMPTY:
273+
; CHECK-NEXT: // %bb.0:
274+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rna_tf32_f32_flags_param_0];
275+
; CHECK-NEXT: cvt.rna.tf32.f32 %r1, %f1;
276+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
277+
; CHECK-NEXT: ret;
278+
%val = call i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata !0, i8 0, i1 0)
279+
ret i32 %val
280+
}
281+
!0 = !{!"round.tonearestaway"}

llvm/test/CodeGen/NVPTX/convert-sm89.ll

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,13 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
8484
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
8585
ret <2 x half> %val
8686
}
87+
88+
declare i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata, i8, i1)
89+
90+
; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
91+
define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
92+
; CHECK: cvt.rna.satfinite.tf32.f32
93+
%val = call i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata !0, i8 1, i1 0)
94+
ret i32 %val
95+
}
96+
!0 = !{!"round.tonearestaway"}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| FileCheck --check-prefixes=CHECK %s
3+
; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| %ptxas-verify -arch=sm_90 %}
4+
5+
declare i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata, i8, i1)
6+
7+
define i32 @cvt_rn_tf32_f32(float %f1) {
8+
; CHECK-LABEL: cvt_rn_tf32_f32(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .b32 %r<2>;
11+
; CHECK-NEXT: .reg .f32 %f<2>;
12+
; CHECK-EMPTY:
13+
; CHECK-NEXT: // %bb.0:
14+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
15+
; CHECK-NEXT: cvt.rn.tf32.f32 %r1, %f1;
16+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
17+
; CHECK-NEXT: ret;
18+
%val = call i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata !0, i8 0, i1 0)
19+
ret i32 %val
20+
}
21+
22+
define i32 @cvt_rn_relu_tf32_f32(float %f1) {
23+
; CHECK-LABEL: cvt_rn_relu_tf32_f32(
24+
; CHECK: {
25+
; CHECK-NEXT: .reg .b32 %r<2>;
26+
; CHECK-NEXT: .reg .f32 %f<2>;
27+
; CHECK-EMPTY:
28+
; CHECK-NEXT: // %bb.0:
29+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
30+
; CHECK-NEXT: cvt.rn.relu.tf32.f32 %r1, %f1;
31+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
32+
; CHECK-NEXT: ret;
33+
%val = call i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata !0, i8 0, i1 1)
34+
ret i32 %val
35+
}
36+
37+
define i32 @cvt_rz_tf32_f32(float %f1) {
38+
; CHECK-LABEL: cvt_rz_tf32_f32(
39+
; CHECK: {
40+
; CHECK-NEXT: .reg .b32 %r<2>;
41+
; CHECK-NEXT: .reg .f32 %f<2>;
42+
; CHECK-EMPTY:
43+
; CHECK-NEXT: // %bb.0:
44+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
45+
; CHECK-NEXT: cvt.rz.tf32.f32 %r1, %f1;
46+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
47+
; CHECK-NEXT: ret;
48+
%val = call i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata !1, i8 0, i1 0)
49+
ret i32 %val
50+
}
51+
52+
define i32 @cvt_rz_relu_tf32_f32(float %f1) {
53+
; CHECK-LABEL: cvt_rz_relu_tf32_f32(
54+
; CHECK: {
55+
; CHECK-NEXT: .reg .b32 %r<2>;
56+
; CHECK-NEXT: .reg .f32 %f<2>;
57+
; CHECK-EMPTY:
58+
; CHECK-NEXT: // %bb.0:
59+
; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
60+
; CHECK-NEXT: cvt.rz.relu.tf32.f32 %r1, %f1;
61+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
62+
; CHECK-NEXT: ret;
63+
%val = call i32 @llvm.nvvm.convert.to.tf32.f32(float %f1, metadata !1, i8 0, i1 1)
64+
ret i32 %val
65+
}
66+
!0 = !{!"round.tonearest"}
67+
!1 = !{!"round.towardzero"}

0 commit comments

Comments
 (0)