Skip to content

Commit f363cfa

Browse files
[mlir][tosa][tosa-to-linalg] Ignore Int NaN Mode (#129041)
For non floating point operations NaN propagation mode has no meaning and can be safely ignored. For non floating point types (e.g. integers) skip the compare and select materialization for NaN propagation even in "IGNORE" mode. This fixes a bug where an unchecked `cast<FloatType>()` was called in the "IGNORE" case even when the operation was acting on integers. Update the lit tests for the NaN propagation lowering to check that the propagation logic is not materialized in the case of a non floating point type, e.g. i8. Signed-off-by: Jack Frankland <[email protected]>
1 parent 36f0838 commit f363cfa

File tree

4 files changed

+128
-2
lines changed

4 files changed

+128
-2
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ using namespace mlir::tosa;
4949
// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
5050
// code:
5151
//
52+
// In the case that the op is operating on non floating point types we ignore
53+
// the attribute completely; this is consistent with the TOSA spec, which has
54+
// the following wording: "This attribute is ignored by non floating-point
55+
// types."
56+
//
5257
// binary<op>(lhs, rhs):
5358
// result = op(lhs, rhs)
5459
// if lhs == NaN return rhs
@@ -58,6 +63,10 @@ template <typename OpTy>
5863
static Value
5964
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
6065
Value lhs, Value rhs, Value result) {
66+
// NaN propagation has no meaning for non floating point types.
67+
if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
68+
return result;
69+
6170
auto nanMode = op.getNanMode();
6271
if (nanMode == "PROPAGATE")
6372
return result;
@@ -449,6 +458,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
449458

450459
auto clampOp = llvm::cast<tosa::ClampOp>(op);
451460
const auto nanMode = clampOp.getNanMode();
461+
462+
// NaN propagation has no meaning for non floating point types.
463+
if (!isa<FloatType>(elementTy))
464+
return result;
465+
452466
// In the case of "PROPAGATE" semantics no compare and selection is
453467
// required.
454468
if (nanMode == "PROPAGATE")
@@ -1192,7 +1206,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11921206
bool isNanIgnoreMode = false;
11931207
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
11941208
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1195-
if (op.getNanMode() == "IGNORE") {
1209+
// NaN propagation has no meaning for non floating point types.
1210+
if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
11961211
isNanIgnoreMode = true;
11971212
// Because the TOSA spec requires the result be NaN iff all elements in
11981213
// the reduction are NaN we can't simply perform a compare and select.
@@ -2282,7 +2297,8 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
22822297
// In the case "IGNORE" we check if the current argument is NaN and
22832298
// select the old index and value otherwise take the updated index and
22842299
// value.
2285-
if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") {
2300+
if (const auto nanMode = argmaxOp.getNanMode();
2301+
isa<FloatType>(inElementTy) && nanMode == "IGNORE") {
22862302
// Unordered comparison of NaN against itself will always return
22872303
// true.
22882304
Value isNaN = rewriter.create<arith::CmpFOp>(

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,11 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
748748
dilationAttr);
749749

750750
rewriter.replaceOp(op, resultOp);
751+
752+
// NaN propagation has no meaning for non floating point types.
753+
if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
754+
return success();
755+
751756
// "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
752757
// compare and select materialization is required.
753758
//

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,16 @@ func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4
940940

941941
// -----
942942

943+
// CHECK-LABEL: @max_pool2d_nan_ignore_int
944+
func.func @max_pool2d_nan_ignore_int(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x4x32x62xi8>) {
945+
// CHECK: linalg.pooling_nhwc_max
946+
// CHECK-NOT: linalg.generic
947+
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xi8>) -> tensor<1x4x32x62xi8>
948+
return %0: tensor<1x4x32x62xi8>
949+
}
950+
951+
// -----
952+
943953
// CHECK-LABEL: @max_pool2d_nan_ignore
944954
func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
945955
// CHECK-NOT: linalg.pooling_nhwc_max

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,6 +2033,44 @@ func.func @reduce_max_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf3
20332033

20342034
// -----
20352035

2036+
// CHECK-LABEL: @reduce_min_nan_ignore_int
2037+
func.func @reduce_min_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
2038+
// CHECK: linalg.reduce
2039+
// CHECK: arith.minsi
2040+
// CHECK-NOT: arith.cmpf uno
2041+
// CHECK-NOT: arith.select
2042+
// CHECK: linalg.yield
2043+
// CHECK-NOT: arith.constant 0x7FC00000
2044+
// CHECK-NOT: tensor.empty()
2045+
// CHECK-NOT: linalg.fill
2046+
// CHECK-NOT: tensor.empty()
2047+
// CHECK-NOT: select
2048+
// CHECK: return
2049+
%5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<1x4xi8>
2050+
return
2051+
}
2052+
2053+
// -----
2054+
2055+
// CHECK-LABEL: @reduce_max_nan_ignore_int
2056+
func.func @reduce_max_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
2057+
// CHECK: linalg.reduce
2058+
// CHECK: arith.maxsi
2059+
// CHECK-NOT: arith.cmpf uno
2060+
// CHECK-NOT: arith.select
2061+
// CHECK: linalg.yield
2062+
// CHECK-NOT: arith.constant 0x7FC00000
2063+
// CHECK-NOT: tensor.empty()
2064+
// CHECK-NOT: linalg.fill
2065+
// CHECK-NOT: tensor.empty()
2066+
// CHECK-NOT: select
2067+
// CHECK: return
2068+
%6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<1x4xi8>
2069+
return
2070+
}
2071+
2072+
// -----
2073+
20362074
// CHECK-LABEL: @reduce_min_nan_ignore
20372075
func.func @reduce_min_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
20382076
// CHECK: linalg.reduce
@@ -2095,6 +2133,32 @@ func.func @maximum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
20952133

20962134
// -----
20972135

2136+
// CHECK-LABEL: @minimum_nan_ignore_int
2137+
func.func @minimum_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
2138+
// CHECK: linalg.generic
2139+
// CHECK: arith.minsi
2140+
// CHECK-NOT: arith.cmpf uno
2141+
// CHECK-NOT: arith.select
2142+
// CHECK: linalg.yield
2143+
%9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
2144+
return
2145+
}
2146+
2147+
// -----
2148+
2149+
// CHECK-LABEL: @maximum_nan_ignore_int
2150+
func.func @maximum_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
2151+
// CHECK: linalg.generic
2152+
// CHECK: arith.maxsi
2153+
// CHECK-NOT: arith.cmpf uno
2154+
// CHECK-NOT: arith.select
2155+
// CHECK: linalg.yield
2156+
%10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
2157+
return
2158+
}
2159+
2160+
// -----
2161+
20982162
// CHECK-LABEL: @minimum_nan_ignore
20992163
func.func @minimum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
21002164
// CHECK: linalg.generic
@@ -2142,6 +2206,23 @@ func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
21422206

21432207
// -----
21442208

2209+
// CHECK-LABEL: @argmax_nan_ignore_int
2210+
func.func @argmax_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
2211+
// CHECK: linalg.generic
2212+
// CHECK: arith.cmpi sgt
2213+
// CHECK: arith.select
2214+
// CHECK: arith.select
2215+
// CHECK-NOT: arith.cmpf uno
2216+
// CHECK-NOT: arith.cmpf uno
2217+
// CHECK-NOT: arith.select
2218+
// CHECK-NOT: arith.select
2219+
// CHECK: linalg.yield
2220+
%12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<4xi32>
2221+
return
2222+
}
2223+
2224+
// -----
2225+
21452226
// CHECK-LABEL: @argmax_nan_ignore
21462227
func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
21472228
// CHECK: linalg.generic
@@ -2172,6 +2253,20 @@ func.func @clamp_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -
21722253

21732254
// -----
21742255

2256+
// CHECK-LABEL: @clamp_nan_ignore_int
2257+
func.func @clamp_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -> () {
2258+
// CHECK: linalg.generic
2259+
// CHECK: arith.maxsi
2260+
// CHECK: arith.minsi
2261+
// CHECK-NOT: arith.cmpf uno
2262+
// CHECK-NOT: arith.select
2263+
// CHECK: linalg.yield
2264+
%14 = tosa.clamp %arg0 {min_val = 1 : i8, max_val = 5 : i8, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<5x4xi8>
2265+
return
2266+
}
2267+
2268+
// -----
2269+
21752270
// CHECK-LABEL: @clamp_nan_ignore
21762271
func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
21772272
// CHECK: linalg.generic

0 commit comments

Comments
 (0)