Skip to content

[MLIR][NVVM] Add Float to TF32 conversion Op #123199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,77 @@ def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.share
}];
}

//===----------------------------------------------------------------------===//
// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
//===----------------------------------------------------------------------===//

// Attributes for the floating point rounding modes supported by PTX.
// Each case binds a C++ enumerator name, its integer value and the keyword
// used in the textual form of the attribute.
// NOTE(review): the keywords look like the PTX rounding modifiers
// (.rn/.rm/.rp/.rz/.rna) - confirm the exact semantics against the PTX ISA.
def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
def FPRoundingModeRN : I32EnumAttrCase<"RN", 1, "rn">;
def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">;
def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">;
def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">;
def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">;

// The enum itself, placed in the ::mlir::NVVM C++ namespace.
// `genSpecializedAttr = 0` disables the standalone attribute class for the
// enum; the dialect attribute is provided by FPRoundingModeAttr below.
def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
[FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
// Dialect attribute wrapping FPRoundingMode, with mnemonic `fp_rnd_mode`.
// It is written in IR as e.g. `#nvvm.fp_rnd_mode<rna>`.
def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
let assemblyFormat = "`<` $value `>`";
}

// Saturation modes selectable on conversion Ops: either no saturation
// modifier (`none`) or `satfinite`.
// NOTE(review): `satfinite` presumably mirrors the PTX `.satfinite`
// modifier - confirm its exact semantics against the PTX ISA.
def SaturationModeNone : I32EnumAttrCase<"NONE", 0, "none">;
def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;

// The enum itself, placed in the ::mlir::NVVM C++ namespace.
// `genSpecializedAttr = 0` disables the standalone attribute class for the
// enum; the dialect attribute is provided by SaturationModeAttr below.
def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
[SaturationModeNone, SaturationModeFinite]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
// Dialect attribute wrapping SaturationMode, with mnemonic `sat_mode`.
// It is written in IR as e.g. `#nvvm.sat_mode<satfinite>`.
def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
let assemblyFormat = "`<` $value `>`";
}

// f32 -> tf32 conversion Op. The result is carried as an i32 and the Op is
// translated to one of the llvm.nvvm.f2tf32.* intrinsics, selected by
// CvtFloatToTF32Op::getIntrinsicID() from the `rnd`, `sat` and `relu`
// attributes.
def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
  let summary = "Convert the given float input to TF32";
  let description = [{
    This Op converts the given f32 input to tf32.
    The result `res` is represented as an i32 type.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction. The `rnd` and `sat` attributes specify the
    rounding and saturation modes respectively.
    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  // The verifier restricts which attribute combinations are accepted:
  // only rn, rz and rna rounding modes pass, relu is rejected with rna, and
  // a saturation mode is rejected with rn/rz.
  let hasVerifier = 1;
  let results = (outs I32:$res);
  // All attributes are optional in the syntax, but the default rounding mode
  // (NONE) is rejected by the verifier, so `rnd` must effectively always be
  // spelled out in `attr-dict`.
  let arguments = (ins
    F32:$src,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
    DefaultValuedAttr<BoolAttr, "false">:$relu);

  // Textual form: `nvvm.cvt.float.to.tf32 %src {rnd = ..., sat = ..., relu = ...}`.
  let assemblyFormat = "$src attr-dict";

  // Static helper that picks the LLVM intrinsic for a given combination of
  // rounding mode, saturation mode and relu flag.
  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode,
                                              NVVM::SaturationMode,
                                              bool hasRelu);
  }];

  // LLVM IR translation: a single intrinsic call taking `src` as its only
  // operand; the attributes are encoded in the choice of intrinsic.
  string llvmBuilder = [{
    auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
    $res = createIntrinsicCall(builder, intId, {$src});
  }];
}

//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
/// Helpers to instantiate different version of wmma intrinsics.
/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
/// combinations of the intrinsics.
Expand Down
40 changes: 40 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,26 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
getLoc());
}

/// Checks that the attribute combination on this Op is one that can be
/// lowered:
///   - rna      : `relu` must not be set;
///   - rn / rz  : the saturation mode must be NONE;
///   - any other rounding mode (including the default NONE) is rejected.
LogicalResult CvtFloatToTF32Op::verify() {
  const NVVM::FPRoundingMode rounding = getRnd();

  // rna is the only mode that may carry a saturation mode, but it has no
  // relu variant.
  if (rounding == NVVM::FPRoundingMode::RNA) {
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    return success();
  }

  const bool isRnOrRz = rounding == NVVM::FPRoundingMode::RN ||
                        rounding == NVVM::FPRoundingMode::RZ;
  if (!isRnOrRz)
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");

  // rn / rz accept relu but no saturation mode.
  if (getSat() != NVVM::SaturationMode::NONE)
    return emitError(
        "Saturation mode not supported with rn/rz rounding modes.");

  return success();
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
Expand Down Expand Up @@ -1163,6 +1183,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

/// Returns the llvm.nvvm.f2tf32.* intrinsic matching the given rounding mode,
/// saturation mode and relu flag.
///
/// For rn / rz only `hasRelu` is consulted (`sat` is ignored); for rna only
/// `sat` is consulted (`hasRelu` is ignored). Any other rounding mode is
/// treated as unreachable.
llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  if (rnd == NVVM::FPRoundingMode::RN) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f2tf32_rn_relu;
    return llvm::Intrinsic::nvvm_f2tf32_rn;
  }

  if (rnd == NVVM::FPRoundingMode::RZ) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f2tf32_rz_relu;
    return llvm::Intrinsic::nvvm_f2tf32_rz;
  }

  if (rnd == NVVM::FPRoundingMode::RNA) {
    const bool saturate = sat == NVVM::SaturationMode::SATFINITE;
    if (saturate)
      return llvm::Intrinsic::nvvm_f2tf32_rna_satfinite;
    return llvm::Intrinsic::nvvm_f2tf32_rna;
  }

  llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s

// Translation tests for nvvm.cvt.float.to.tf32. Each function below uses one
// accepted combination of the rnd / sat / relu attributes and expects a call
// to the corresponding llvm.nvvm.f2tf32.* intrinsic in the emitted LLVM IR.

// rna rounding, no saturation mode.
// CHECK-LABEL: @convert_float_to_tf32_rna
llvm.func @convert_float_to_tf32_rna(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna(float %{{.*}})
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>}
llvm.return %res : i32
}

// rna rounding with the satfinite saturation mode.
// CHECK-LABEL: @convert_float_to_tf32_rna_sf
llvm.func @convert_float_to_tf32_rna_sf(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %{{.*}})
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, sat = #nvvm.sat_mode<satfinite>}
llvm.return %res : i32
}

// rn rounding, relu not set.
// CHECK-LABEL: @convert_float_to_tf32_rn
llvm.func @convert_float_to_tf32_rn(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn(float %{{.*}})
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>}
llvm.return %res : i32
}

// rn rounding with relu set.
// CHECK-LABEL: @convert_float_to_tf32_rn_relu
llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu(float %{{.*}})
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true}
llvm.return %res : i32
}

// rz rounding, relu not set.
// CHECK-LABEL: @convert_float_to_tf32_rz
llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>}
llvm.return %res : i32
}

// rz rounding with relu set.
// CHECK-LABEL: @convert_float_to_tf32_rz_relu
llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu(float %{{.*}})
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
llvm.return %res : i32
}
32 changes: 32 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,35 @@ llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0
nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
llvm.return
}

// -----

// Negative test: relu combined with rna rounding is rejected by the verifier.
llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
// expected-error @below {{Relu not supported with rna rounding mode.}}
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, relu=true}
llvm.return %res : i32
}

// -----

// Negative test: a saturation mode combined with rn rounding is rejected.
llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
// expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
llvm.return %res : i32
}

// -----

// Negative test: a saturation mode combined with rz rounding is rejected.
llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
// expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
llvm.return %res : i32
}

// -----

// Negative test: omitting `rnd` leaves the default rounding mode (none),
// which the verifier does not accept.
llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
// expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
%res = nvvm.cvt.float.to.tf32 %src
llvm.return %res : i32
}
Loading