-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Remove enableFastMathMode
flag from SPIR-V conversion
#86578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Most of arith/math ops support fastmath attribute, use it instead of global flag.
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesMost of arith/math ops support fastmath attribute, use it instead of global flag. Full diff: https://github.com/llvm/llvm-project/pull/86578.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7e7ee3a2f780f6..d094ee3b36ab95 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -172,10 +172,6 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
- Option<"enableFastMath", "enable-fast-math",
- "bool", /*default=*/"false",
- "Enable fast math mode (assuming no NaN and infinity for floating "
- "point values) when performing conversion">
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 933d62e35fce8c..09eecafc0c8a51 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -55,11 +55,6 @@ struct SPIRVConversionOptions {
/// values will be packed into one 32-bit value to be memory efficient.
bool emulateLT32BitScalarTypes{true};
- /// Whether to enable fast math mode during conversion. If true, various
- /// patterns would assume no NaN/infinity numbers as inputs, and thus there
- /// will be no special guards emitted to check and handle such cases.
- bool enableFastMathMode{false};
-
/// Use 64-bit integers when converting index types.
bool use64bitIndex{false};
};
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..eb338c2da4e887 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -995,7 +995,7 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
auto *converter = getTypeConverter<SPIRVTypeConverter>();
Value replace;
- if (converter->getOptions().enableFastMathMode) {
+ if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
if (op.getPredicate() == arith::CmpFPredicate::ORD) {
// Ordered comparsion checks if neither operand is NaN.
replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
@@ -1122,7 +1122,7 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
Value spirvOp =
rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
- if (converter->getOptions().enableFastMathMode) {
+ if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
rewriter.replaceOp(op, spirvOp);
return success();
}
@@ -1177,7 +1177,7 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
if (!shouldInsertNanGuards<SPIRVOp>() ||
- converter->getOptions().enableFastMathMode) {
+ bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
rewriter.replaceOp(op, spirvOp);
return success();
}
@@ -1286,7 +1286,6 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
- options.enableFastMathMode = this->enableFastMath;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index dbf0361c2ab35b..9bbe28fb127a78 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-arith-to-spirv=enable-fast-math -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s | FileCheck %s
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
@@ -8,7 +8,7 @@ module attributes {
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 {
// CHECK: %[[T:.+]] = spirv.Constant true
- %0 = arith.cmpf ord, %arg0, %arg1 : f32
+ %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
// CHECK: return %[[T]]
return %0: i1
}
@@ -17,7 +17,7 @@ func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 {
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
func.func @cmpf_unordered(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xi1> {
// CHECK: %[[F:.+]] = spirv.Constant dense<false>
- %0 = arith.cmpf uno, %arg0, %arg1 : vector<4xf32>
+ %0 = arith.cmpf uno, %arg0, %arg1 fastmath<nnan> : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xi1>
}
@@ -34,7 +34,7 @@ module attributes {
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
- %0 = arith.minimumf %arg0, %arg1 : f32
+ %0 = arith.minimumf %arg0, %arg1 fastmath<fast> : f32
// CHECK: return %[[F]]
return %0: f32
}
@@ -43,7 +43,7 @@ func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
- %0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
+ %0 = arith.maximumf %arg0, %arg1 fastmath<fast> : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}
@@ -52,7 +52,7 @@ func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf3
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
- %0 = arith.minnumf %arg0, %arg1 : f32
+ %0 = arith.minnumf %arg0, %arg1 fastmath<fast> : f32
// CHECK: return %[[F]]
return %0: f32
}
@@ -61,7 +61,7 @@ func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
- %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
+ %0 = arith.maxnumf %arg0, %arg1 fastmath<fast> : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}
|
✅ With the latest revision this PR passed the Python code formatter. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me
%0 = arith.cmpf ord, %arg0, %arg1 : f32 | ||
%0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a pass in upstream MLIR to add fastmath attributes to arith
dialect ops?
We were using the spirv fast math mode downstream here: https://github.com/openxla/iree/blob/aacdd33eb57f79711ecae088dbb37c9bd17d7031/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp#L148-L159 . To adapt to this change, I think we'd want to add these attributes to all arith ops prior to running arith to spirv patterns here: https://github.com/openxla/iree/blob/aacdd33eb57f79711ecae088dbb37c9bd17d7031/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp#L655-L660.
So in our case, we start with modules that don't generally use fastmath, but then partway through our lowering process we specialize for one target that requires it. For that, we want this same "just use fast math" behavior, without needing to enumerate all possible ops that need an attribute ourselves.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fastify-math pass sounds like a good idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a pass in upstream MLIR to add fastmath attributes to arith dialect ops?
I don't think upstream have such pass, but it should be straightforward to implement such pass without enumerating specific ops. All arith ops which support fastmath are implementing ArithFastMathInterface
, so pass may look something like this:
void runOnOperation() {
getOperation()->walk([](ArithFastMathInterface iface) {
... set fastmath attribute
});
}
These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`. Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in llvm#86578.
These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`. Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in #86578. Also do some misc cleanups in the surrounding code.
IREE-side changes to adapt to MLIR changes: 1. `initializeOptions` changes to adapt to llvm/llvm-project#87289 2. `enableFastMathMode` removal: llvm/llvm-project#86578. 3. Bazel changes to adapt to llvm/llvm-project#86819 IREE-side fixes for preexisting bugs revealed by a MLIR change: 1. `mlp_tosa` test fix: the shapes were inconsistent, used to accidentally work, until MLIR started catching it since llvm/llvm-project#85798. See diagnostic in [87396](llvm/llvm-project#87396 (comment)). FYI @MaheshRavishankar. IREE-side fixes accidentally lumped into this: 1. The `iree_copts.cmake` change: It just happens that my bleeding-edge Clang was updated and started diagnosing some code relying on C++20 semantics. Filed #16946 as TODO. --------- Co-authored-by: Scott Todd <[email protected]>
Most of arith/math ops support fastmath attribute, use it instead of global flag.