Skip to content

Commit 6dcb2a0

Browse files
authored
[MLIR][NVVM] Add Float to TF32 conversion Op (#123199)
PR #121507 added 'cvt' intrinsics to convert float to tf32, with the valid set of rounding and saturation modes. This PR adds an NVVM Dialect Op for the same. * lit tests are added to verify the lowering to intrinsics. * Negative tests are also added to check the error-handling of invalid combinations. Signed-off-by: Durgadoss R <[email protected]>
1 parent f597d34 commit 6dcb2a0

File tree

4 files changed

+186
-0
lines changed

4 files changed

+186
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,77 @@ def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.share
970970
}];
971971
}
972972

973+
//===----------------------------------------------------------------------===//
974+
// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
975+
//===----------------------------------------------------------------------===//
976+
977+
// Floating point rounding-mode qualifiers supported by PTX, modelled as an
// I32 enum. Each case carries the C++ symbol, its integer value and the
// keyword used in the textual assembly (e.g. `#nvvm.fp_rnd_mode<rna>`).
def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
def FPRoundingModeRN   : I32EnumAttrCase<"RN",   1, "rn">;
def FPRoundingModeRM   : I32EnumAttrCase<"RM",   2, "rm">;
def FPRoundingModeRP   : I32EnumAttrCase<"RP",   3, "rp">;
def FPRoundingModeRZ   : I32EnumAttrCase<"RZ",   4, "rz">;
def FPRoundingModeRNA  : I32EnumAttrCase<"RNA",  5, "rna">;

// The enum itself. The specialized attribute class is not generated here
// (genSpecializedAttr = 0); the dialect attribute is declared separately
// right below via EnumAttr.
def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind", [
    FPRoundingModeNone,
    FPRoundingModeRN,
    FPRoundingModeRM,
    FPRoundingModeRP,
    FPRoundingModeRZ,
    FPRoundingModeRNA
  ]> {
  let cppNamespace = "::mlir::NVVM";
  let genSpecializedAttr = 0;
}

// NVVM dialect attribute wrapping FPRoundingMode; printed and parsed as
// `#nvvm.fp_rnd_mode<...>`.
def FPRoundingModeAttr
    : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
  let assemblyFormat = "`<` $value `>`";
}
994+
995+
// Saturation-mode qualifiers, modelled as an I32 enum. Each case carries the
// C++ symbol, its integer value and the keyword used in the textual assembly
// (e.g. `#nvvm.sat_mode<satfinite>`).
def SaturationModeNone   : I32EnumAttrCase<"NONE",      0, "none">;
def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;

// The enum itself. The specialized attribute class is not generated here
// (genSpecializedAttr = 0); the dialect attribute is declared separately
// right below via EnumAttr.
def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind", [
    SaturationModeNone,
    SaturationModeFinite
  ]> {
  let cppNamespace = "::mlir::NVVM";
  let genSpecializedAttr = 0;
}

// NVVM dialect attribute wrapping SaturationMode; printed and parsed as
// `#nvvm.sat_mode<...>`.
def SaturationModeAttr
    : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
  let assemblyFormat = "`<` $value `>`";
}
1006+
1007+
// Op converting an f32 value to tf32 (returned as i32). It is translated to
// LLVM IR as a single intrinsic call whose ID is chosen from the `rnd`, `sat`
// and `relu` attributes by CvtFloatToTF32Op::getIntrinsicID.
def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
  let summary = "Convert the given float input to TF32";
  let description = [{
    This Op converts the given f32 input to tf32.
    The result `res` is represented as an i32 type.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction. The `rnd` and `sat` attributes specify the
    rounding and saturation modes respectively.
    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  // The C++ verifier rejects attribute combinations that have no matching
  // intrinsic (see CvtFloatToTF32Op::verify).
  let hasVerifier = 1;
  let results = (outs I32:$res);
  // Every attribute has a default, so all of them may be omitted in the
  // assembly. Note that the default rounding mode (NONE) is not accepted by
  // the verifier, so in practice `rnd` has to be spelled out.
  let arguments = (ins
    F32:$src,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
    DefaultValuedAttr<BoolAttr, "false">:$relu);

  // Operand and result types are fixed (f32 -> i32), so none are printed.
  let assemblyFormat = "$src attr-dict";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode,
                                              NVVM::SaturationMode,
                                              bool hasRelu);
  }];

  // Translation to LLVM IR: one intrinsic call taking `src` as its operand.
  string llvmBuilder = [{
    auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
    $res = createIntrinsicCall(builder, intId, {$src});
  }];
}
1040+
1041+
//===----------------------------------------------------------------------===//
1042+
// NVVM MMA Ops
1043+
//===----------------------------------------------------------------------===//
9731044
/// Helpers to instantiate different version of wmma intrinsics.
9741045
/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
9751046
/// combinations of the intrinsics.

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,26 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
138138
getLoc());
139139
}
140140

141+
/// Verifies that the (rounding mode, saturation mode, relu) combination on
/// this op is one of the supported ones:
///   - rna     : relu must not be set (the saturation mode is not checked).
///   - rn / rz : the saturation mode must be NONE (relu is not checked).
///   - others  : always rejected, including the default NONE rounding mode.
/// Returns success() for a supported combination and an emitted error
/// otherwise.
LogicalResult CvtFloatToTF32Op::verify() {
  const NVVM::FPRoundingMode rounding = getRnd();

  if (rounding == NVVM::FPRoundingMode::RNA) {
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    return success();
  }

  const bool isRnOrRz = rounding == NVVM::FPRoundingMode::RN ||
                        rounding == NVVM::FPRoundingMode::RZ;
  if (!isRnOrRz)
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");

  if (getSat() != NVVM::SaturationMode::NONE)
    return emitError(
        "Saturation mode not supported with rn/rz rounding modes.");

  return success();
}
160+
141161
// Given the element type of an operand and whether or not it is an accumulator,
142162
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
143163
// operand's element type.
@@ -1163,6 +1183,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
11631183
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
11641184
}
11651185

1186+
/// Maps the op's attributes to the matching `nvvm_f2tf32_*` intrinsic.
///
/// \param rnd     Rounding mode; only RN, RZ and RNA are handled.
/// \param sat     Saturation mode; only consulted for the RNA rounding mode.
/// \param hasRelu Selects the `_relu` variant; only consulted for RN and RZ.
/// \returns the ID of the selected intrinsic. Any other rounding mode hits
///          llvm_unreachable.
llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  if (rnd == NVVM::FPRoundingMode::RN) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f2tf32_rn_relu;
    return llvm::Intrinsic::nvvm_f2tf32_rn;
  }

  if (rnd == NVVM::FPRoundingMode::RZ) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f2tf32_rz_relu;
    return llvm::Intrinsic::nvvm_f2tf32_rz;
  }

  if (rnd == NVVM::FPRoundingMode::RNA) {
    const bool saturates = sat == NVVM::SaturationMode::SATFINITE;
    if (saturates)
      return llvm::Intrinsic::nvvm_f2tf32_rna_satfinite;
    return llvm::Intrinsic::nvvm_f2tf32_rna;
  }

  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
}
1205+
11661206
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
11671207
/// have ConstantRangeAttr.
11681208
static void nvvmInferResultRanges(Operation *op, Value result,
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s

// Each function below exercises one supported attribute combination of
// nvvm.cvt.float.to.tf32 and checks the intrinsic call it is translated to.

// rna rounding, no saturation.
// CHECK-LABEL: @convert_float_to_tf32_rna
llvm.func @convert_float_to_tf32_rna(%src : f32) -> i32 {
  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna(float %{{.*}})
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>}
  llvm.return %tf32 : i32
}

// rna rounding with satfinite.
// CHECK-LABEL: @convert_float_to_tf32_rna_sf
llvm.func @convert_float_to_tf32_rna_sf(%src : f32) -> i32 {
  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %{{.*}})
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, sat = #nvvm.sat_mode<satfinite>}
  llvm.return %tf32 : i32
}

// rn rounding.
// CHECK-LABEL: @convert_float_to_tf32_rn
llvm.func @convert_float_to_tf32_rn(%src : f32) -> i32 {
  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn(float %{{.*}})
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>}
  llvm.return %tf32 : i32
}

// rn rounding with relu.
// CHECK-LABEL: @convert_float_to_tf32_rn_relu
llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu(float %{{.*}})
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu = true}
  llvm.return %tf32 : i32
}

// rz rounding.
// CHECK-LABEL: @convert_float_to_tf32_rz
llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>}
  llvm.return %tf32 : i32
}

// rz rounding with relu.
// CHECK-LABEL: @convert_float_to_tf32_rz_relu
llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu(float %{{.*}})
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu = true}
  llvm.return %tf32 : i32
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,35 @@ llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0
106106
nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
107107
llvm.return
108108
}
109+
110+
// -----

// relu cannot be combined with the rna rounding mode.
llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
  // expected-error @below {{Relu not supported with rna rounding mode.}}
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, relu = true}
  llvm.return %tf32 : i32
}

// -----

// A saturation mode cannot be combined with the rn rounding mode.
llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
  llvm.return %tf32 : i32
}

// -----

// A saturation mode cannot be combined with the rz rounding mode.
llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
  %tf32 = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
  llvm.return %tf32 : i32
}

// -----

// Omitting `rnd` leaves the default (none) rounding mode, which is rejected.
llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
  // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
  %tf32 = nvvm.cvt.float.to.tf32 %src
  llvm.return %tf32 : i32
}

0 commit comments

Comments
 (0)