Skip to content

Commit 199ae05

Browse files
committed
[mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering
Add support for NaN propagation lowering in the `tosa-to-linalg` and `tosa-to-linalg-named` conversions by conditionally checking for NaN in the case of ignore semantics and materializing the appropriate select operations. Note that the default behaviour of "propagate" matches that of the arith dialect and so in that case we can avoid creating the checks altogether. Add appropriate lit tests including negative tests which check the various comparisons and selects are materialized as appropriate. This affects the following TOSA operators: * arg_max * max_pool_2d * clamp * reduce_max * reduce_min * maximum * minimum Signed-off-by: Jack Frankland <[email protected]>
1 parent d9af03b commit 199ae05

File tree

5 files changed

+325
-9
lines changed

5 files changed

+325
-9
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,22 @@ namespace tosa {
141141

142142
bool isa_tosa_shape_type(mlir::Type t);
143143

144+
// Helper function to materialize the semantically correct compare and select
145+
// operations for a reduction operation with a specific NaN propagation mode.
146+
//
147+
// In the case of "PROPAGATE" semantics no compare and selection is required and
148+
// this function does nothing.
149+
//
150+
// In the case of "IGNORE" semantics this function materializes a comparison of
151+
// the current operand to the reduction which will return true for a NaN
152+
// argument and then selects between the initial reduction value and the
153+
// calculated result based on whether the argument is NaN or not. In pseudo
154+
// code:
155+
//
156+
// reduce<op>(x, init):
157+
// result = op(init, x)
158+
// return init if x == NaN else result
159+
144160
} // namespace tosa
145161

146162
} // namespace mlir
@@ -267,6 +283,31 @@ extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
267283

268284
return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
269285
}
286+
287+
// Helper to determine if an operation should have the "nan_mode" string
288+
// attribute.
289+
inline bool shouldHaveNanPropagation(Operation *op) {
290+
return isa<tosa::ClampOp>(op) || isa<tosa::MaxPool2dOp>(op) ||
291+
isa<tosa::ReduceMinOp>(op) || isa<tosa::ReduceMaxOp>(op) ||
292+
isa<tosa::MaximumOp>(op) || isa<tosa::MinimumOp>(op);
293+
}
294+
295+
// Helper function to extract the NaN propagation mode from an operation.
296+
// Note that the for operations which support NaN mode propagation the attribute
297+
// is optional and its default value is "PROPAGATE".
298+
//
299+
// If the function is called with an operator that doesn't support the NaN mode
300+
// attribute it will return a std::nullopt.
301+
inline std::optional<std::string> getNanMode(Operation *op,
302+
PatternRewriter &rewriter) {
303+
if (shouldHaveNanPropagation(op))
304+
return op->hasAttr("nan_mode") ? op->getAttrOfType<StringAttr>(
305+
rewriter.getStringAttr("nan_mode"))
306+
.str()
307+
: "PROPAGATE";
308+
return std::nullopt;
309+
}
310+
270311
} // namespace tosa
271312
} // namespace mlir
272313

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,81 @@
3636
using namespace mlir;
3737
using namespace mlir::tosa;
3838

39+
// Helper function to materialize the semantically correct compare and select
40+
// operations a reduction operation with a specific NaN propagation mode.
41+
//
42+
// In the case of "PROPAGATE" semantics no compare and selection is required and
43+
// this function does nothing.
44+
//
45+
// In the case of "IGNORE" semantics this function materializes a comparison of
46+
// the current operand to the reduction which will return true for a NaN
47+
// argument and then selects between the initial reduction value and the
48+
// calculated result based on whether the argument is NaN or not. In pseudo
49+
// code:
50+
//
51+
// reduce<op>(x, init):
52+
// result = op(init, x)
53+
// return init if x == NaN else result
54+
static Value materializeReductionNanCheckIfRequired(Operation *op,
55+
PatternRewriter &rewriter,
56+
Value in, Value init,
57+
Value result) {
58+
const auto nanMode = getNanMode(op, rewriter);
59+
if (!nanMode)
60+
return {};
61+
62+
if (*nanMode == "PROPAGATE")
63+
return result;
64+
65+
assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
66+
67+
// Unordered comparison of NaN against itself will always return true.
68+
Value isNaN = rewriter.create<arith::CmpFOp>(
69+
op->getLoc(), arith::CmpFPredicate::UNO, in, in);
70+
return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, init, result);
71+
}
72+
73+
// Helper function to materialize the semantically correct compare and select
74+
// operations a binary operation with a specific NaN propagation mode.
75+
//
76+
// In the case of "PROPAGATE" semantics no compare and selection is required and
77+
// this function does nothing.
78+
//
79+
// In the case of "IGNORE" semantics this function materializes a comparison of
80+
// the current operands to the op which will return true for any NaN
81+
// argument and then selects between the non-NaN operation argument and the
82+
// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
83+
// code:
84+
//
85+
// binary<op>(lhs, rhs):
86+
// result = op(lhs, rhs)
87+
// if lhs == NaN return rhs
88+
// if rhs == NaN return lhs
89+
// return result
90+
static Value materializeBinaryNanCheckIfRequired(Operation *op,
91+
PatternRewriter &rewriter,
92+
Value lhs, Value rhs,
93+
Value result) {
94+
const auto nanMode = getNanMode(op, rewriter);
95+
if (!nanMode)
96+
return {};
97+
98+
if (*nanMode == "PROPAGATE")
99+
return result;
100+
101+
assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
102+
103+
// Unordered comparison of NaN against itself will always return true.
104+
Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
105+
op->getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
106+
Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
107+
op->getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
108+
Value rhsOrResult =
109+
rewriter.create<arith::SelectOp>(op->getLoc(), lhsIsNaN, rhs, result);
110+
return rewriter.create<arith::SelectOp>(op->getLoc(), rhsIsNaN, lhs,
111+
rhsOrResult);
112+
}
113+
39114
template <typename T>
40115
static arith::ConstantOp
41116
createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -358,7 +433,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
358433

359434
// tosa::MaximumOp
360435
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
361-
return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
436+
auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
437+
return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
438+
max);
362439
}
363440

364441
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -367,7 +444,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
367444

368445
// tosa::MinimumOp
369446
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
370-
return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
447+
auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
448+
return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
449+
min);
371450
}
372451

373452
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -395,7 +474,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
395474
loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
396475
auto max = rewriter.create<arith::ConstantOp>(
397476
loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
398-
return clampFloatHelper(loc, args[0], min, max, rewriter);
477+
auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
478+
// TOSA specifies that in "ignore" NaN mode the result is "min" if the input
479+
// is NaN.
480+
return materializeReductionNanCheckIfRequired(op, rewriter, args[0], min,
481+
result);
399482
}
400483

401484
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1042,15 +1125,19 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
10421125
}
10431126

10441127
if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1045-
return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1128+
auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1129+
return materializeReductionNanCheckIfRequired(op, rewriter, args[0],
1130+
args[1], min);
10461131
}
10471132

10481133
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
10491134
return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
10501135
}
10511136

10521137
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1053-
return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1138+
auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1139+
return materializeReductionNanCheckIfRequired(op, rewriter, args[0],
1140+
args[1], max);
10541141
}
10551142

10561143
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
@@ -2078,6 +2165,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
20782165
nestedLoc, predicate, newValue, oldValue);
20792166
auto resultIndex = rewriter.create<arith::SelectOp>(
20802167
nestedLoc, predicate, newIndex, oldIndex);
2168+
2169+
// Check if we need to materialize compare and select for the given
2170+
// NaN propagation mode.
2171+
const auto nanMode = getNanMode(argmaxOp, rewriter);
2172+
if (!nanMode) {
2173+
didEncounterError = true;
2174+
return;
2175+
}
2176+
2177+
// "PROPAGATE" matches the default NaN propagation mode of the arith
2178+
// dialect so no compare and select is required.
2179+
//
2180+
// In the case "IGNORE" we check if the current argument is NaN and
2181+
// select the old index and value otherwise take the updated index and
2182+
// value.
2183+
if (*nanMode == "IGNORE") {
2184+
// Unordered comparison of NaN against itself will always return
2185+
// true.
2186+
Value isNaN = rewriter.create<arith::CmpFOp>(
2187+
argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
2188+
newValue);
2189+
resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
2190+
oldValue, resultMax);
2191+
resultIndex = rewriter.create<arith::SelectOp>(
2192+
nestedLoc, isNaN, oldIndex, resultIndex);
2193+
}
20812194
nestedBuilder.create<linalg::YieldOp>(
20822195
nestedLoc, ValueRange({resultIndex, resultMax}));
20832196
});

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -807,11 +807,47 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
807807
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
808808
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
809809
filledEmptyTensor, strideAttr, dilationAttr);
810-
} else {
811-
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
812-
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
813-
filledEmptyTensor, strideAttr, dilationAttr);
810+
return llvm::success();
814811
}
812+
813+
auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
814+
op->getLoc(), ArrayRef<Type>{resultTy},
815+
ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
816+
dilationAttr);
817+
818+
// Check the NaN propagation mode is present.
819+
const auto nanMode = getNanMode(op, rewriter);
820+
if (!nanMode)
821+
return failure();
822+
823+
// "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
824+
// compare and select materialization is required.
825+
//
826+
// In the case of "IGNORE" we need to insert a compare and select. Since
827+
// we've already produced a named op we will just take its body and modify
828+
// it to include the appropriate checks. If the current value is NaN the
829+
// old value of pool will be taken otherwise we use the result.
830+
if (nanMode == "IGNORE") {
831+
auto *block = resultOp.getBlock();
832+
rewriter.setInsertionPointToEnd(block);
833+
834+
auto in = block->getArgument(0);
835+
auto out = block->getArgument(1);
836+
837+
auto *oldYieldOp = &*block->rbegin();
838+
auto result = oldYieldOp->getOperand(0);
839+
840+
Value isNaN = rewriter.create<arith::CmpFOp>(
841+
op->getLoc(), arith::CmpFPredicate::UNO, in, in);
842+
843+
auto selectOp =
844+
rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, out, result);
845+
auto newYieldOp = rewriter.create<linalg::YieldOp>(oldYieldOp->getLoc(),
846+
selectOp.getResult());
847+
rewriter.replaceOp(oldYieldOp, newYieldOp);
848+
}
849+
850+
rewriter.replaceOp(op, resultOp);
815851
return success();
816852
}
817853
};

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
22
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
33
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
4+
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,linalg-generalize-named-ops))" %s -verify-diagnostics -o -| FileCheck %s --check-prefix="CHECK-NAN"
45

56
// CHECK-LABEL: @matmul
67
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -977,3 +978,24 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
977978
%1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
978979
return
979980
}
981+
982+
// -----
983+
984+
// CHECK-NAN-LABEL: @nan_propagation_modes
985+
func.func @nan_propagation_modes(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>) {
986+
// CHECK-NAN: linalg.generic
987+
// CHECK-NAN-NOT: arith.maximumf
988+
// CHECK-NAN-NOT: arith.cmpf uno
989+
// CHECK-NAN-NOT: arith.select
990+
// CHECK-NAN: linalg.yield
991+
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
992+
993+
// CHECK-NAN: linalg.generic
994+
// CHECK-NAN: arith.maximumf
995+
// CHECK-NAN: arith.cmpf uno
996+
// CHECK-NAN: arith.select
997+
// CHECK-NAN: linalg.yield
998+
%1 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
999+
1000+
return %0, %1 : tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>
1001+
}

0 commit comments

Comments
 (0)