-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][NVVM] Update Float to TF32 conversion Op #125048
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Update Float to TF32 conversion Op #125048
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) ChangesThis change updates the Float to TF32 conversion MLIR Op to include lowering to the new intrinsics introduced in sm_100 through ptx8.6:
PTX Spec Reference: Full diff: https://github.com/llvm/llvm-project/pull/125048.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3e0a6987bd85b0..1ad20bb35273ea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -147,9 +147,6 @@ LogicalResult CvtFloatToTF32Op::verify() {
break;
case RndMode::RN:
case RndMode::RZ:
- if (getSat() != NVVM::SaturationMode::NONE)
- return emitError(
- "Saturation mode not supported with rn/rz rounding modes.");
break;
default:
return emitError(
@@ -1225,17 +1222,29 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu) {
using RndMode = NVVM::FPRoundingMode;
+ bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+ bool hasReluAndSatFinite = hasRelu && hasSatFinite;
switch (rnd) {
case RndMode::RN:
- return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
- : llvm::Intrinsic::nvvm_f2tf32_rn;
+ if(hasReluAndSatFinite)
+ return llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite;
+ if(hasRelu)
+ return llvm::Intrinsic::nvvm_f2tf32_rn_relu;
+ if(hasSatFinite)
+ return llvm::Intrinsic::nvvm_f2tf32_rn_satfinite;
+ return llvm::Intrinsic::nvvm_f2tf32_rn;
case RndMode::RZ:
- return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
- : llvm::Intrinsic::nvvm_f2tf32_rz;
+ if(hasReluAndSatFinite)
+ return llvm::Intrinsic::nvvm_f2tf32_rz_relu_satfinite;
+ if(hasRelu)
+ return llvm::Intrinsic::nvvm_f2tf32_rz_relu;
+ if(hasSatFinite)
+ return llvm::Intrinsic::nvvm_f2tf32_rz_satfinite;
+ return llvm::Intrinsic::nvvm_f2tf32_rz;
case RndMode::RNA:
- return (sat == NVVM::SaturationMode::SATFINITE)
- ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
- : llvm::Intrinsic::nvvm_f2tf32_rna;
+ return hasSatFinite
+ ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
+ : llvm::Intrinsic::nvvm_f2tf32_rna;
default:
llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
index 90a232e4baac6f..ff7bad0149d4cf 100644
--- a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -28,6 +28,20 @@ llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
llvm.return %res : i32
}
+// CHECK-LABEL: @convert_float_to_tf32_rn_sf
+llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu_sf
+llvm.func @convert_float_to_tf32_rn_relu_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
// CHECK-LABEL: @convert_float_to_tf32_rz
llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
@@ -41,3 +55,17 @@ llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
llvm.return %res : i32
}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_sf
+llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu_sf
+llvm.func @convert_float_to_tf32_rz_relu_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index cb08064590bc30..8957377607dad6 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -117,22 +117,6 @@ llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
// -----
-llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
- // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
- %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
- llvm.return %res : i32
-}
-
-// -----
-
-llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
- // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
- %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
- llvm.return %res : i32
-}
-
-// -----
-
llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
// expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
%res = nvvm.cvt.float.to.tf32 %src
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
1227749
to
18f2268
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This change updates the Float to TF32 conversion MLIR Op to include lowering to the new intrinsics introduced in sm_100 through ptx8.6: - nvvm_f2tf32_rn_satfinite - nvvm_f2tf32_rn_relu_satfinite - nvvm_f2tf32_rz_satfinite - nvvm_f2tf32_rz_relu_satfinite PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
18f2268
to
4b7ceb3
Compare
This change updates the Float to TF32 conversion MLIR Op to include lowering to the new intrinsics introduced in sm_100 through ptx8.6:
nvvm_f2tf32_rn_satfinite
nvvm_f2tf32_rn_relu_satfinite
nvvm_f2tf32_rz_satfinite
nvvm_f2tf32_rz_relu_satfinite
PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt