Skip to content

[MLIR][NVVM] Update Float to TF32 conversion Op #125048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

Wolfram70
Copy link
Contributor

This change updates the Float to TF32 conversion MLIR Op so that it also lowers to the new intrinsics introduced for sm_100 in PTX 8.6:

  • nvvm_f2tf32_rn_satfinite
  • nvvm_f2tf32_rn_relu_satfinite
  • nvvm_f2tf32_rz_satfinite
  • nvvm_f2tf32_rz_relu_satfinite

PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

@llvmbot
Copy link
Member

llvmbot commented Jan 30, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

Changes

This change updates the Float to TF32 conversion MLIR Op so that it also lowers to the new intrinsics introduced for sm_100 in PTX 8.6:

  • nvvm_f2tf32_rn_satfinite
  • nvvm_f2tf32_rn_relu_satfinite
  • nvvm_f2tf32_rz_satfinite
  • nvvm_f2tf32_rz_relu_satfinite

PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt


Full diff: https://github.com/llvm/llvm-project/pull/125048.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+19-10)
  • (modified) mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir (+28)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (-16)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3e0a6987bd85b0..1ad20bb35273ea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -147,9 +147,6 @@ LogicalResult CvtFloatToTF32Op::verify() {
     break;
   case RndMode::RN:
   case RndMode::RZ:
-    if (getSat() != NVVM::SaturationMode::NONE)
-      return emitError(
-          "Saturation mode not supported with rn/rz rounding modes.");
     break;
   default:
     return emitError(
@@ -1225,17 +1222,29 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                      NVVM::SaturationMode sat,
                                                      bool hasRelu) {
   using RndMode = NVVM::FPRoundingMode;
+  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+  bool hasReluAndSatFinite = hasRelu && hasSatFinite;
   switch (rnd) {
   case RndMode::RN:
-    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
-                   : llvm::Intrinsic::nvvm_f2tf32_rn;
+    if(hasReluAndSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite;
+    if(hasRelu)
+      return llvm::Intrinsic::nvvm_f2tf32_rn_relu;
+    if(hasSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rn_satfinite;
+    return llvm::Intrinsic::nvvm_f2tf32_rn;
   case RndMode::RZ:
-    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
-                   : llvm::Intrinsic::nvvm_f2tf32_rz;
+    if(hasReluAndSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rz_relu_satfinite;
+    if(hasRelu)
+      return llvm::Intrinsic::nvvm_f2tf32_rz_relu;
+    if(hasSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rz_satfinite;
+    return llvm::Intrinsic::nvvm_f2tf32_rz;
   case RndMode::RNA:
-    return (sat == NVVM::SaturationMode::SATFINITE)
-               ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
-               : llvm::Intrinsic::nvvm_f2tf32_rna;
+    return hasSatFinite
+            ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
+            : llvm::Intrinsic::nvvm_f2tf32_rna;
   default:
     llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
   }
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
index 90a232e4baac6f..ff7bad0149d4cf 100644
--- a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -28,6 +28,20 @@ llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
   llvm.return %res : i32
 }
 
+// CHECK-LABEL: @convert_float_to_tf32_rn_sf
+llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu_sf
+llvm.func @convert_float_to_tf32_rn_relu_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
 // CHECK-LABEL: @convert_float_to_tf32_rz
 llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
   // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
@@ -41,3 +55,17 @@ llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
   %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
   llvm.return %res : i32
 }
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_sf
+llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu_sf
+llvm.func @convert_float_to_tf32_rz_relu_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index cb08064590bc30..8957377607dad6 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -117,22 +117,6 @@ llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
 
 // -----
 
-llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
-  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
-  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
-  llvm.return %res : i32
-}
-
-// -----
-
-llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
-  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
-  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
-  llvm.return %res : i32
-}
-
-// -----
-
 llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
   // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
   %res = nvvm.cvt.float.to.tf32 %src

Copy link

github-actions bot commented Jan 30, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/nvvm-mlir-f2tf32-op-update branch from 1227749 to 18f2268 Compare January 31, 2025 06:41
Copy link
Contributor

@durga4github durga4github left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

This change updates the Float to TF32 conversion MLIR Op so that it also
lowers to the new intrinsics introduced for sm_100 in PTX 8.6:

- nvvm_f2tf32_rn_satfinite
- nvvm_f2tf32_rn_relu_satfinite
- nvvm_f2tf32_rz_satfinite
- nvvm_f2tf32_rz_relu_satfinite

PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/nvvm-mlir-f2tf32-op-update branch from 18f2268 to 4b7ceb3 Compare January 31, 2025 11:15
@durga4github durga4github merged commit 83cad68 into llvm:main Feb 1, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants