Skip to content

Commit f050a09

Browse files
authored
[mlir][spirv] Remove enableFastMathMode flag from SPIR-V conversion (#86578)
Most of arith/math ops support fastmath attribute, use it instead of global flag.
1 parent 4c4ea24 commit f050a09

File tree

4 files changed

+10
-20
lines changed

4 files changed

+10
-20
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,6 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
172172
"bool", /*default=*/"true",
173173
"Emulate narrower scalar types with 32-bit ones if not supported by "
174174
"the target">,
175-
Option<"enableFastMath", "enable-fast-math",
176-
"bool", /*default=*/"false",
177-
"Enable fast math mode (assuming no NaN and infinity for floating "
178-
"point values) when performing conversion">
179175
];
180176
}
181177

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ struct SPIRVConversionOptions {
5555
/// values will be packed into one 32-bit value to be memory efficient.
5656
bool emulateLT32BitScalarTypes{true};
5757

58-
/// Whether to enable fast math mode during conversion. If true, various
59-
/// patterns would assume no NaN/infinity numbers as inputs, and thus there
60-
/// will be no special guards emitted to check and handle such cases.
61-
bool enableFastMathMode{false};
62-
6358
/// Use 64-bit integers when converting index types.
6459
bool use64bitIndex{false};
6560
};

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
995995
auto *converter = getTypeConverter<SPIRVTypeConverter>();
996996

997997
Value replace;
998-
if (converter->getOptions().enableFastMathMode) {
998+
if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
999999
if (op.getPredicate() == arith::CmpFPredicate::ORD) {
10001000
// Ordered comparsion checks if neither operand is NaN.
10011001
replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
@@ -1122,7 +1122,7 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
11221122
Value spirvOp =
11231123
rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
11241124

1125-
if (converter->getOptions().enableFastMathMode) {
1125+
if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
11261126
rewriter.replaceOp(op, spirvOp);
11271127
return success();
11281128
}
@@ -1177,7 +1177,7 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
11771177
rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
11781178

11791179
if (!shouldInsertNanGuards<SPIRVOp>() ||
1180-
converter->getOptions().enableFastMathMode) {
1180+
bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
11811181
rewriter.replaceOp(op, spirvOp);
11821182
return success();
11831183
}
@@ -1286,7 +1286,6 @@ struct ConvertArithToSPIRVPass
12861286

12871287
SPIRVConversionOptions options;
12881288
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1289-
options.enableFastMathMode = this->enableFastMath;
12901289
SPIRVTypeConverter typeConverter(targetAttr, options);
12911290

12921291
// Use UnrealizedConversionCast as the bridge so that we don't need to pull

mlir/test/Conversion/ArithToSPIRV/fast-math.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -split-input-file -convert-arith-to-spirv=enable-fast-math -verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s | FileCheck %s
22

33
module attributes {
44
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
@@ -8,7 +8,7 @@ module attributes {
88
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
99
func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 {
1010
// CHECK: %[[T:.+]] = spirv.Constant true
11-
%0 = arith.cmpf ord, %arg0, %arg1 : f32
11+
%0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
1212
// CHECK: return %[[T]]
1313
return %0: i1
1414
}
@@ -17,7 +17,7 @@ func.func @cmpf_ordered(%arg0 : f32, %arg1 : f32) -> i1 {
1717
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
1818
func.func @cmpf_unordered(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xi1> {
1919
// CHECK: %[[F:.+]] = spirv.Constant dense<false>
20-
%0 = arith.cmpf uno, %arg0, %arg1 : vector<4xf32>
20+
%0 = arith.cmpf uno, %arg0, %arg1 fastmath<nnan> : vector<4xf32>
2121
// CHECK: return %[[F]]
2222
return %0: vector<4xi1>
2323
}
@@ -34,7 +34,7 @@ module attributes {
3434
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
3535
func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
3636
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
37-
%0 = arith.minimumf %arg0, %arg1 : f32
37+
%0 = arith.minimumf %arg0, %arg1 fastmath<fast> : f32
3838
// CHECK: return %[[F]]
3939
return %0: f32
4040
}
@@ -43,7 +43,7 @@ func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
4343
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
4444
func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
4545
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
46-
%0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
46+
%0 = arith.maximumf %arg0, %arg1 fastmath<fast> : vector<4xf32>
4747
// CHECK: return %[[F]]
4848
return %0: vector<4xf32>
4949
}
@@ -52,7 +52,7 @@ func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf3
5252
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
5353
func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
5454
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
55-
%0 = arith.minnumf %arg0, %arg1 : f32
55+
%0 = arith.minnumf %arg0, %arg1 fastmath<fast> : f32
5656
// CHECK: return %[[F]]
5757
return %0: f32
5858
}
@@ -61,7 +61,7 @@ func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
6161
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
6262
func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
6363
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
64-
%0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
64+
%0 = arith.maxnumf %arg0, %arg1 fastmath<fast> : vector<4xf32>
6565
// CHECK: return %[[F]]
6666
return %0: vector<4xf32>
6767
}

0 commit comments

Comments
 (0)