36
36
using namespace mlir ;
37
37
using namespace mlir ::tosa;
38
38
39
+ // Helper function to materialize the semantically correct compare and select
40
+ // operations a reduction operation with a specific NaN propagation mode.
41
+ //
42
+ // In the case of "PROPAGATE" semantics no compare and selection is required and
43
+ // this function does nothing.
44
+ //
45
+ // In the case of "IGNORE" semantics this function materializes a comparison of
46
+ // the current operand to the reduction which will return true for a NaN
47
+ // argument and then selects between the initial reduction value and the
48
+ // calculated result based on whether the argument is NaN or not. In pseudo
49
+ // code:
50
+ //
51
+ // reduce<op>(x, init):
52
+ // result = op(init, x)
53
+ // return init if x == NaN else result
54
+ static Value materializeReductionNanCheckIfRequired (Operation *op,
55
+ PatternRewriter &rewriter,
56
+ Value in, Value init,
57
+ Value result) {
58
+ const auto nanMode = getNanMode (op, rewriter);
59
+ if (!nanMode)
60
+ return {};
61
+
62
+ if (*nanMode == " PROPAGATE" )
63
+ return result;
64
+
65
+ assert (*nanMode == " IGNORE" && " Unhandled nan-propagation mode" );
66
+
67
+ // Unordered comparison of NaN against itself will always return true.
68
+ Value isNaN = rewriter.create <arith::CmpFOp>(
69
+ op->getLoc (), arith::CmpFPredicate::UNO, in, in);
70
+ return rewriter.create <arith::SelectOp>(op->getLoc (), isNaN, init, result);
71
+ }
72
+
73
+ // Helper function to materialize the semantically correct compare and select
74
+ // operations a binary operation with a specific NaN propagation mode.
75
+ //
76
+ // In the case of "PROPAGATE" semantics no compare and selection is required and
77
+ // this function does nothing.
78
+ //
79
+ // In the case of "IGNORE" semantics this function materializes a comparison of
80
+ // the current operands to the op which will return true for any NaN
81
+ // argument and then selects between the non-NaN operation argument and the
82
+ // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
83
+ // code:
84
+ //
85
+ // binary<op>(lhs, rhs):
86
+ // result = op(lhs, rhs)
87
+ // if lhs == NaN return rhs
88
+ // if rhs == NaN return lhs
89
+ // return result
90
+ static Value materializeBinaryNanCheckIfRequired (Operation *op,
91
+ PatternRewriter &rewriter,
92
+ Value lhs, Value rhs,
93
+ Value result) {
94
+ const auto nanMode = getNanMode (op, rewriter);
95
+ if (!nanMode)
96
+ return {};
97
+
98
+ if (*nanMode == " PROPAGATE" )
99
+ return result;
100
+
101
+ assert (*nanMode == " IGNORE" && " Unhandled nan-propagation mode" );
102
+
103
+ // Unordered comparison of NaN against itself will always return true.
104
+ Value lhsIsNaN = rewriter.create <arith::CmpFOp>(
105
+ op->getLoc (), arith::CmpFPredicate::UNO, lhs, lhs);
106
+ Value rhsIsNaN = rewriter.create <arith::CmpFOp>(
107
+ op->getLoc (), arith::CmpFPredicate::UNO, rhs, rhs);
108
+ Value rhsOrResult =
109
+ rewriter.create <arith::SelectOp>(op->getLoc (), lhsIsNaN, rhs, result);
110
+ return rewriter.create <arith::SelectOp>(op->getLoc (), rhsIsNaN, lhs,
111
+ rhsOrResult);
112
+ }
113
+
39
114
template <typename T>
40
115
static arith::ConstantOp
41
116
createConstFromIntAttribute (Operation *op, const std::string &attrName,
@@ -358,7 +433,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
358
433
359
434
// tosa::MaximumOp
360
435
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
361
- return rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
436
+ auto max = rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
437
+ return materializeBinaryNanCheckIfRequired (op, rewriter, args[0 ], args[1 ],
438
+ max);
362
439
}
363
440
364
441
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -367,7 +444,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
367
444
368
445
// tosa::MinimumOp
369
446
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
370
- return rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
447
+ auto min = rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
448
+ return materializeBinaryNanCheckIfRequired (op, rewriter, args[0 ], args[1 ],
449
+ min);
371
450
}
372
451
373
452
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -395,7 +474,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
395
474
loc, elementTy, rewriter.getFloatAttr (elementTy, minApf));
396
475
auto max = rewriter.create <arith::ConstantOp>(
397
476
loc, elementTy, rewriter.getFloatAttr (elementTy, maxApf));
398
- return clampFloatHelper (loc, args[0 ], min, max, rewriter);
477
+ auto result = clampFloatHelper (loc, args[0 ], min, max, rewriter);
478
+ // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
479
+ // is NaN.
480
+ return materializeReductionNanCheckIfRequired (op, rewriter, args[0 ], min,
481
+ result);
399
482
}
400
483
401
484
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1042,15 +1125,19 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1042
1125
}
1043
1126
1044
1127
if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1045
- return rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
1128
+ auto min = rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
1129
+ return materializeReductionNanCheckIfRequired (op, rewriter, args[0 ],
1130
+ args[1 ], min);
1046
1131
}
1047
1132
1048
1133
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1049
1134
return rewriter.create <arith::MinSIOp>(loc, args[0 ], args[1 ]);
1050
1135
}
1051
1136
1052
1137
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1053
- return rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
1138
+ auto max = rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
1139
+ return materializeReductionNanCheckIfRequired (op, rewriter, args[0 ],
1140
+ args[1 ], max);
1054
1141
}
1055
1142
1056
1143
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
@@ -2078,6 +2165,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2078
2165
nestedLoc, predicate, newValue, oldValue);
2079
2166
auto resultIndex = rewriter.create <arith::SelectOp>(
2080
2167
nestedLoc, predicate, newIndex, oldIndex);
2168
+
2169
+ // Check if we need to materialize compare and select for the given
2170
+ // NaN propagation mode.
2171
+ const auto nanMode = getNanMode (argmaxOp, rewriter);
2172
+ if (!nanMode) {
2173
+ didEncounterError = true ;
2174
+ return ;
2175
+ }
2176
+
2177
+ // "PROPAGATE" matches the default NaN propagation mode of the arith
2178
+ // dialect so no compare and select is required.
2179
+ //
2180
+ // In the case "IGNORE" we check if the current argument is NaN and
2181
+ // select the old index and value otherwise take the updated index and
2182
+ // value.
2183
+ if (*nanMode == " IGNORE" ) {
2184
+ // Unordered comparison of NaN against itself will always return
2185
+ // true.
2186
+ Value isNaN = rewriter.create <arith::CmpFOp>(
2187
+ argmaxOp.getLoc (), arith::CmpFPredicate::UNO, newValue,
2188
+ newValue);
2189
+ resultMax = rewriter.create <arith::SelectOp>(nestedLoc, isNaN,
2190
+ oldValue, resultMax);
2191
+ resultIndex = rewriter.create <arith::SelectOp>(
2192
+ nestedLoc, isNaN, oldIndex, resultIndex);
2193
+ }
2081
2194
nestedBuilder.create <linalg::YieldOp>(
2082
2195
nestedLoc, ValueRange ({resultIndex, resultMax}));
2083
2196
});
0 commit comments