Skip to content

[mlir][spirv] Remove enableFastMathMode flag from SPIR-V conversion #86578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 26, 2024

Conversation

Hardcode84
Copy link
Contributor

Most of arith/math ops support fastmath attribute, use it instead of global flag.

Most of arith/math ops support fastmath attribute, use it instead of global flag.
@llvmbot
Copy link
Member

llvmbot commented Mar 25, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Most of arith/math ops support fastmath attribute, use it instead of global flag.


Full diff: https://github.com/llvm/llvm-project/pull/86578.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (-4)
  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (-5)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+3-4)
  • (modified) mlir/test/Conversion/ArithToSPIRV/fast-math.mlir (+7-7)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7e7ee3a2f780f6..d094ee3b36ab95 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -172,10 +172,6 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by "
            "the target">,
-    Option<"enableFastMath", "enable-fast-math",
-           "bool", /*default=*/"false",
-           "Enable fast math mode (assuming no NaN and infinity for floating "
-           "point values) when performing conversion">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 933d62e35fce8c..09eecafc0c8a51 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -55,11 +55,6 @@ struct SPIRVConversionOptions {
   /// values will be packed into one 32-bit value to be memory efficient.
   bool emulateLT32BitScalarTypes{true};
 
-  /// Whether to enable fast math mode during conversion. If true, various
-  /// patterns would assume no NaN/infinity numbers as inputs, and thus there
-  /// will be no special guards emitted to check and handle such cases.
-  bool enableFastMathMode{false};
-
   /// Use 64-bit integers when converting index types.
   bool use64bitIndex{false};
 };
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..eb338c2da4e887 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -995,7 +995,7 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
     auto *converter = getTypeConverter<SPIRVTypeConverter>();
 
     Value replace;
-    if (converter->getOptions().enableFastMathMode) {
+    if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
       if (op.getPredicate() == arith::CmpFPredicate::ORD) {
         // Ordered comparsion checks if neither operand is NaN.
         replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
@@ -1122,7 +1122,7 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
     Value spirvOp =
         rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
 
-    if (converter->getOptions().enableFastMathMode) {
+    if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
       rewriter.replaceOp(op, spirvOp);
       return success();
     }
@@ -1177,7 +1177,7 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
         rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
 
     if (!shouldInsertNanGuards<SPIRVOp>() ||
-        converter->getOptions().enableFastMathMode) {
+        bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
       rewriter.replaceOp(op, spirvOp);
       return success();
     }
@@ -1286,7 +1286,6 @@ struct ConvertArithToSPIRVPass
 
     SPIRVConversionOptions options;
     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
-    options.enableFastMathMode = this->enableFastMath;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index dbf0361c2ab35b..9bbe28fb127a78 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-arith-to-spirv=enable-fast-math -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s | FileCheck %s
 
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
@@ -8,7 +8,7 @@ module attributes {
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
 func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 {
   // CHECK: %[[T:.+]] = spirv.Constant true
-  %0 = arith.cmpf ord, %arg0, %arg1 : f32
+  %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
   // CHECK: return %[[T]]
   return %0: i1
 }
@@ -17,7 +17,7 @@ func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 {
 // CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
 func.func @cmpf_unordered(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xi1> {
   // CHECK: %[[F:.+]] = spirv.Constant dense<false>
-  %0 = arith.cmpf uno, %arg0, %arg1 : vector<4xf32>
+  %0 = arith.cmpf uno, %arg0, %arg1 fastmath<nnan> : vector<4xf32>
   // CHECK: return %[[F]]
   return %0: vector<4xi1>
 }
@@ -34,7 +34,7 @@ module attributes {
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
 func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
-  %0 = arith.minimumf %arg0, %arg1 : f32
+  %0 = arith.minimumf %arg0, %arg1 fastmath<fast> : f32
   // CHECK: return %[[F]]
   return %0: f32
 }
@@ -43,7 +43,7 @@ func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
 // CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
 func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
   // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
-  %0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
+  %0 = arith.maximumf %arg0, %arg1 fastmath<fast> : vector<4xf32>
   // CHECK: return %[[F]]
   return %0: vector<4xf32>
 }
@@ -52,7 +52,7 @@ func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf3
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
 func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
-  %0 = arith.minnumf %arg0, %arg1 : f32
+  %0 = arith.minnumf %arg0, %arg1 fastmath<fast> : f32
   // CHECK: return %[[F]]
   return %0: f32
 }
@@ -61,7 +61,7 @@ func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
 // CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
 func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
   // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
-  %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
+  %0 = arith.maxnumf %arg0, %arg1 fastmath<fast> : vector<4xf32>
   // CHECK: return %[[F]]
   return %0: vector<4xf32>
 }

Copy link

✅ With the latest revision this PR passed the Python code formatter.

Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me

@Hardcode84 Hardcode84 merged commit f050a09 into llvm:main Mar 26, 2024
@Hardcode84 Hardcode84 deleted the spirv-fastmath branch March 26, 2024 17:06
%0 = arith.cmpf ord, %arg0, %arg1 : f32
%0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a pass in upstream MLIR to add fastmath attributes to arith dialect ops?

We were using the spirv fast math mode downstream here: https://github.com/openxla/iree/blob/aacdd33eb57f79711ecae088dbb37c9bd17d7031/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp#L148-L159 . To adapt to this change, I think we'd want to add these attributes to all arith ops prior to running arith to spirv patterns here: https://github.com/openxla/iree/blob/aacdd33eb57f79711ecae088dbb37c9bd17d7031/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp#L655-L660.

So in our case, we start with modules that don't generally use fastmath, but then partway through our lowering process we specialize for one target that requires it. For that, we want this same "just use fast math" behavior, without needing to enumerate all possible ops that need an attribute ourselves.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fastify-math pass sounds like a good idea

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a pass in upstream MLIR to add fastmath attributes to arith dialect ops?

I don't think upstream have such pass, but it should be straightforward to implement such pass without enumerating specific ops. All arith ops which support fastmath are implementing ArithFastMathInterface, so pass may look something like this:

void runOnOperation() {
    getOperation()->walk([](ArithFastMathInterface iface) {
        ... set fastmath attribute
    });
}

kuhar added a commit to kuhar/llvm-project that referenced this pull request Mar 28, 2024
These non-finite math ops are supported by SPIR-V but not by WGSL.
Assume finite floating point values and expand these ops into `false`.

Previously, this worked by adding fast math flags during conversion from
arith to spirv, but this got removed in
llvm#86578.
kuhar added a commit that referenced this pull request Mar 28, 2024
These non-finite math ops are supported by SPIR-V but not by WGSL.
Assume finite floating point values and expand these ops into `false`.

Previously, this worked by adding fast math flags during conversion from
arith to spirv, but this got removed in
#86578.

Also do some misc cleanups in the surrounding code.
bjacob added a commit to iree-org/iree that referenced this pull request Apr 3, 2024
IREE-side changes to adapt to MLIR changes:
1. `initializeOptions` changes to adapt to
llvm/llvm-project#87289
2. `enableFastMathMode` removal:
llvm/llvm-project#86578.
3. Bazel changes to adapt to
llvm/llvm-project#86819

IREE-side fixes for preexisting bugs revealed by a MLIR change:
1. `mlp_tosa` test fix: the shapes were inconsistent, used to
accidentally work, until MLIR started catching it since
llvm/llvm-project#85798. See diagnostic in
[87396](llvm/llvm-project#87396 (comment)).
FYI @MaheshRavishankar.

IREE-side fixes accidentally lumped into this:
1. The `iree_copts.cmake` change: It just happens that my bleeding-edge
Clang was updated and started diagnosing some code relying on C++20
semantics. Filed #16946 as TODO.

---------

Co-authored-by: Scott Todd <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants