Skip to content

Commit 98b22fd

Browse files
committed
[mlir][Arith] Add FTZ (Flush-to-Zero) fast-math flag
The Flush to Zero (FTZ) modifier is used in floating-point arithmetic to set  very small numbers, known as denormal or subnormal numbers, to zero. FTZ is done to improve performance, as handling these small numbers can slow down computations. Note that this attribute does not specify if the rounding happens toward positive or negative zero since it is architecture (or vendor)-dependent.
1 parent cf046c8 commit 98b22fd

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,22 @@ def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>;
108108
def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>;
109109
def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>;
110110
def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>;
111+
def FASTMATH_FTZ : I32BitEnumAttrCaseBit<"ftz", 7>;
111112
def FASTMATH_FAST : I32BitEnumAttrCaseGroup<
112113
"fast",
113114
[
114115
FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS,
115116
FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT,
116-
FASTMATH_APPROX_FUNC]>;
117+
FASTMATH_APPROX_FUNC, FASTMATH_FTZ]>;
117118

118119
def FastMathFlags : I32BitEnumAttr<
119120
"FastMathFlags",
120121
"Floating point fast math flags",
121122
[
122123
FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS,
123124
FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP,
124-
FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> {
125+
FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FTZ,
126+
FASTMATH_FAST]> {
125127
let separator = ",";
126128
let cppNamespace = "::mlir::arith";
127129
let genSpecializedAttr = 0;

mlir/test/Dialect/Arith/ops.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
11271127
// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
11281128
%7 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
11291129
// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
1130-
%8 = arith.mulf %arg0, %arg1 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
1130+
%8 = arith.mulf %arg0, %arg1 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn,ftz> : f32
11311131
// CHECK: {{.*}} = arith.cmpf oeq, %arg0, %arg1 fastmath<fast> : f32
11321132
%9 = arith.cmpf oeq, %arg0, %arg1 fastmath<fast> : f32
11331133

@@ -1161,3 +1161,17 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
11611161
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
11621162
return
11631163
}
1164+
1165+
// CHECK-LABEL: flush_to_zero
1166+
// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32
1167+
func.func @flush_to_zero(%arg0: f32, %arg1: f32) {
1168+
// CHECK: %{{.+}} = arith.addf %[[ARG0]], %[[ARG1]] fastmath<ftz> : f32
1169+
// CHECK-NEXT: %{{.+}} = arith.subf %[[ARG0]], %[[ARG1]] fastmath<ftz> : f32
1170+
// CHECK-NEXT: %{{.+}} = arith.mulf %[[ARG0]], %[[ARG1]] fastmath<ftz> : f32
1171+
// CHECK-NEXT: %{{.+}} = arith.divf %[[ARG0]], %[[ARG1]] fastmath<ftz> : f32
1172+
%0 = arith.addf %arg0, %arg1 fastmath<ftz> : f32
1173+
%1 = arith.subf %arg0, %arg1 fastmath<ftz> : f32
1174+
%2 = arith.mulf %arg0, %arg1 fastmath<ftz> : f32
1175+
%3 = arith.divf %arg0, %arg1 fastmath<ftz> : f32
1176+
return
1177+
}

mlir/test/Dialect/Math/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func.func @fastmath(%f: f32, %i: i32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>)
289289
// CHECK: math.trunc %[[F]] fastmath<fast> : f32
290290
%0 = math.trunc %f fastmath<fast> : f32
291291
// CHECK: math.powf %[[V]], %[[V]] fastmath<fast> : vector<4xf32>
292-
%1 = math.powf %v, %v fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : vector<4xf32>
292+
%1 = math.powf %v, %v fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn,ftz> : vector<4xf32>
293293
// CHECK: math.fma %[[T]], %[[T]], %[[T]] : tensor<4x4x?xf32>
294294
%2 = math.fma %t, %t, %t fastmath<none> : tensor<4x4x?xf32>
295295
// CHECK: math.absf %[[F]] fastmath<ninf> : f32

0 commit comments

Comments
 (0)