[MLIR][NVVM] Add Float to TF32 conversion Op #123199

Merged: 1 commit, Jan 17, 2025

durga4github
Contributor

PR #121507 added 'cvt' intrinsics to convert
float to tf32, with the valid set of rounding and
saturation modes. This PR adds an NVVM Dialect Op
for the same.

  • lit tests are added to verify the lowering to intrinsics.
  • Negative tests are also added to check the error-handling of invalid combinations.
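
For readers unfamiliar with the conversion itself, the following is a minimal host-side sketch of what the three supported rounding modes compute at the bit level. It assumes the usual TF32 layout (1 sign bit, 8 exponent bits, 10 mantissa bits held in the upper 19 bits of a 32-bit container) and assumes the low 13 bits of the result are zero. It covers finite inputs only and ignores NaN handling, `.relu`, and `.satfinite`. The helper names (`bitsOf`, `toTF32_rz`, `toTF32_rna`, `toTF32_rn`) are hypothetical and are not part of this PR or of LLVM.

```cpp
#include <cstdint>
#include <cstring>

// Reinterpret an f32 as its raw IEEE-754 bit pattern.
inline uint32_t bitsOf(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  return u;
}

// Assumption: TF32 keeps the sign, 8 exponent bits and the top 10 mantissa
// bits of an f32, so the low 13 mantissa bits are discarded.
constexpr uint32_t kTF32Mask = 0xFFFFE000u;

// rz: round toward zero, i.e. simply drop the low 13 bits.
inline uint32_t toTF32_rz(uint32_t bits) { return bits & kTF32Mask; }

// rna: round to nearest, ties away from zero. Adding half of the dropped
// range to the magnitude bits and truncating does this for finite inputs.
inline uint32_t toTF32_rna(uint32_t bits) {
  return (bits + 0x1000u) & kTF32Mask;
}

// rn: round to nearest, ties to even. The bias depends on the lowest
// mantissa bit that survives (bit 13).
inline uint32_t toTF32_rn(uint32_t bits) {
  uint32_t lsb = (bits >> 13) & 1u;
  return (bits + 0x0FFFu + lsb) & kTF32Mask;
}
```

The three modes differ only on inputs whose low 13 bits are non-zero. For the bit pattern `0x3F801000` (exactly halfway between two TF32 values), `rz` and `rn` keep `0x3F800000`, while `rna` moves up to `0x3F802000`.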

Signed-off-by: Durgadoss R <[email protected]>
@llvmbot
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

Full diff: https://github.com/llvm/llvm-project/pull/123199.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+71)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+40)
  • (added) mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir (+43)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 04042903e343ed..bf3131932a56bc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -970,6 +970,77 @@ def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.share
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
+//===----------------------------------------------------------------------===//
+
+// Attributes for the floating point rounding modes supported by PTX
+def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def FPRoundingModeRN   : I32EnumAttrCase<"RN",   1, "rn">;
+def FPRoundingModeRM   : I32EnumAttrCase<"RM",   2, "rm">;
+def FPRoundingModeRP   : I32EnumAttrCase<"RP",   3, "rp">;
+def FPRoundingModeRZ   : I32EnumAttrCase<"RZ",   4, "rz">;
+def FPRoundingModeRNA  : I32EnumAttrCase<"RNA",  5, "rna">;
+
+def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
+  [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
+    FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def SaturationModeNone   : I32EnumAttrCase<"NONE", 0, "none">;
+def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
+
+def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
+  [SaturationModeNone, SaturationModeFinite]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
+  let summary = "Convert the given float input to TF32";
+  let description = [{
+    This Op converts the given f32 input to tf32.
+    The result `res` is represented as an i32 type.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction. The `rnd` and `sat` attributes specify the
+    the rounding and saturation modes respectively.
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+
+  let hasVerifier = 1;
+  let results = (outs I32:$res);
+  let arguments = (ins
+    F32:$src,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+
+  let assemblyFormat = "$src attr-dict";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode,
+                                              NVVM::SaturationMode,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
+    $res = createIntrinsicCall(builder, intId, {$src});
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM MMA Ops
+//===----------------------------------------------------------------------===//
 /// Helpers to instantiate different version of wmma intrinsics.
 /// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
 /// combinations of the intrinsics.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d8fde3e765ac49..ccb5ad05f0bf72 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -138,6 +138,26 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
                                          getLoc());
 }
 
+LogicalResult CvtFloatToTF32Op::verify() {
+  using RndMode = NVVM::FPRoundingMode;
+  switch (getRnd()) {
+  case RndMode::RNA:
+    if (getRelu())
+      return emitError("Relu not supported with rna rounding mode.");
+    break;
+  case RndMode::RN:
+  case RndMode::RZ:
+    if (getSat() != NVVM::SaturationMode::NONE)
+      return emitError(
+          "Saturation mode not supported with rn/rz rounding modes.");
+    break;
+  default:
+    return emitError(
+        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
+  }
+  return success();
+}
+
 // Given the element type of an operand and whether or not it is an accumulator,
 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
 // operand's element type.
@@ -1163,6 +1183,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
   llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
 }
 
+llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
+                                                     NVVM::SaturationMode sat,
+                                                     bool hasRelu) {
+  using RndMode = NVVM::FPRoundingMode;
+  switch (rnd) {
+  case RndMode::RN:
+    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
+                   : llvm::Intrinsic::nvvm_f2tf32_rn;
+  case RndMode::RZ:
+    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
+                   : llvm::Intrinsic::nvvm_f2tf32_rz;
+  case RndMode::RNA:
+    return (sat == NVVM::SaturationMode::SATFINITE)
+               ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
+               : llvm::Intrinsic::nvvm_f2tf32_rna;
+  default:
+    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
+  }
+}
+
 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
 /// have ConstantRangeAttr.
 static void nvvmInferResultRanges(Operation *op, Value result,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
new file mode 100644
index 00000000000000..90a232e4baac6f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate -mlir-to-llvmir %s  -split-input-file --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_tf32_rna
+llvm.func @convert_float_to_tf32_rna(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rna_sf
+llvm.func @convert_float_to_tf32_rna_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn
+llvm.func @convert_float_to_tf32_rn(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu
+llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz
+llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu
+llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
+  llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 44c7126255dc4f..cb08064590bc30 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -106,3 +106,35 @@ llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0
   nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
   llvm.return
 }
+
+// -----
+
+llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
+  // expected-error @below {{Relu not supported with rna rounding mode.}}
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, relu=true}
+  llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
+  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
+  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
+  // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
+  %res = nvvm.cvt.float.to.tf32 %src
+  llvm.return %res : i32
+}
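
Taken together, the verifier and `getIntrinsicID` in the diff above accept five attribute combinations and reject everything else. The sketch below restates that logic as one standalone function so the valid set is visible at a glance. It is an illustration only: the enums mirror the names in the diff but are not the generated MLIR classes, and `selectF2TF32` is a hypothetical helper that returns the intrinsic name as a string instead of an `llvm::Intrinsic::ID`.

```cpp
#include <optional>
#include <string>

// Standalone stand-ins for the attribute enums defined in NVVMOps.td.
enum class FPRoundingMode { NONE, RN, RM, RP, RZ, RNA };
enum class SaturationMode { NONE, SATFINITE };

// Hypothetical helper: returns the LLVM intrinsic name for a combination the
// verifier accepts, or std::nullopt where the verifier would emit an error.
std::optional<std::string> selectF2TF32(FPRoundingMode rnd,
                                        SaturationMode sat, bool relu) {
  switch (rnd) {
  case FPRoundingMode::RNA:
    // rna may be combined with satfinite, but never with relu.
    if (relu)
      return std::nullopt;
    return std::string(sat == SaturationMode::SATFINITE
                           ? "llvm.nvvm.f2tf32.rna.satfinite"
                           : "llvm.nvvm.f2tf32.rna");
  case FPRoundingMode::RN:
  case FPRoundingMode::RZ: {
    // rn and rz may be combined with relu, but never with a saturation mode.
    if (sat != SaturationMode::NONE)
      return std::nullopt;
    std::string name = rnd == FPRoundingMode::RN ? "llvm.nvvm.f2tf32.rn"
                                                 : "llvm.nvvm.f2tf32.rz";
    return relu ? name + ".relu" : name;
  }
  default:
    // NONE (the attribute default), RM and RP are rejected, which is why an
    // op written without an explicit rnd attribute fails verification.
    return std::nullopt;
  }
}
```

For example, `selectF2TF32(FPRoundingMode::RZ, SaturationMode::NONE, true)` yields `llvm.nvvm.f2tf32.rz.relu`, matching the `convert_float_to_tf32_rz_relu` test, while the default-constructed combination (`NONE`, `NONE`, `false`) yields no intrinsic, matching the `convert_float_to_tf32_no_rnd_mode` negative test.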

@durga4github
Contributor Author

@grypp, please help with the review.

Member

@grypp grypp left a comment

We should have an NVVM OP for this conversion since we have PTX. LGTM!

Orthogonal topic: We also need a generic conversion operation that can be used with vectors. This operation should convert anything to anything. We should place that operation in the NVGPU dialect, as it's a bridge dialect. We can revisit this problem when we need it. I'm just thinking out loud here.

@durga4github
Contributor Author

> We should have an NVVM OP for this conversion since we have PTX. LGTM!
>
> Orthogonal topic: We also need a generic conversion operation that can be used with vectors. This operation should convert anything to anything. We should place that operation in the NVGPU dialect, as it's a bridge dialect. We can revisit this problem when we need it. I'm just thinking out loud here.

Yes, totally agree. I have a few more CVT ops to add matching the available intrinsics/PTX.
Once we have a set of them, I believe, we will have a better picture (a clear pattern) to build NVGPU Op(s).

@durga4github durga4github merged commit 6dcb2a0 into llvm:main Jan 17, 2025
11 checks passed
@durga4github durga4github deleted the durgadossr/mlir_cvt_tf32 branch January 17, 2025 13:32