Skip to content

Commit 5ba05b9

Browse files
tatwaichongFranklandJack
authored and committed
[mlir][tosa] Add NaN Propagation Mode Support
The TOSA-V1.0 specification adds "nan propagation" modes as attributes for several operators. Adjust the ODS definitions of the relevant operations to include this attribute. The defined modes are "PROPAGATE" and "IGNORE" and the PROPAGATE mode is set by default. MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL, CLAMP, and ARGMAX support this attribute. Refactor the clamp + clamp optimization in order to better handle edge cases such as invalid NaN propagation combinations and disjoint clamp ranges. Signed-off-by: Jack Frankland <[email protected]>
1 parent 998bdae commit 5ba05b9

File tree

4 files changed

+91
-25
lines changed

4 files changed

+91
-25
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
4242

4343
let arguments = (ins
4444
Tosa_Tensor: $input,
45-
I32Attr: $axis
45+
I32Attr: $axis,
46+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
4647
);
4748

4849
let results = (outs
@@ -284,7 +285,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
284285

285286
Tosa_IntArrayAttr2:$kernel,
286287
Tosa_IntArrayAttr2:$stride,
287-
Tosa_IntArrayAttr4:$pad
288+
Tosa_IntArrayAttr4:$pad,
289+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
288290
);
289291

290292
let results = (outs
@@ -383,7 +385,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
383385
I64Attr:$min_int,
384386
I64Attr:$max_int,
385387
Tosa_FloatAttr:$min_fp,
386-
Tosa_FloatAttr:$max_fp
388+
Tosa_FloatAttr:$max_fp,
389+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
387390
);
388391

389392
let results = (outs
@@ -747,7 +750,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
747750

748751
let arguments = (ins
749752
Tosa_Tensor:$input1,
750-
Tosa_Tensor:$input2
753+
Tosa_Tensor:$input2,
754+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
751755
);
752756

753757
let results = (outs
@@ -770,7 +774,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
770774

771775
let arguments = (ins
772776
Tosa_Tensor:$input1,
773-
Tosa_Tensor:$input2
777+
Tosa_Tensor:$input2,
778+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
774779
);
775780

776781
let results = (outs
@@ -1377,7 +1382,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
13771382

13781383
let arguments = (ins
13791384
Tosa_Tensor:$input,
1380-
I32Attr:$axis
1385+
I32Attr:$axis,
1386+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
13811387
);
13821388

13831389
let results = (outs
@@ -1412,7 +1418,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
14121418

14131419
let arguments = (ins
14141420
Tosa_Tensor:$input,
1415-
I32Attr:$axis
1421+
I32Attr:$axis,
1422+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
14161423
);
14171424

14181425
let results = (outs

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
202202
//===----------------------------------------------------------------------===//
203203
// Iterable attributes.
204204
//===----------------------------------------------------------------------===//
205+
// Defined in `section 3. Enumerations` of the TOSA specification.
206+
205207
// Supported regimes for tosa.resize.
206208
def Tosa_ResizeTypeAttr : StringBasedAttr<
207209
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
208210
"::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
209211
"Supported resize/upsampling strategies">;
210212

213+
// Supported NaN propagation strategies.
214+
def Tosa_NanPropagationAttr : StringBasedAttr<
215+
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
216+
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
217+
"Supported NaN propagation strategies">;
218+
211219
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
212220

213221
// Tensor to buffer types.

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -339,33 +339,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
339339
}
340340
};
341341

342+
// Attempts the following transformation:
343+
//
344+
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
345+
// tensor X the following identity holds:
346+
//
347+
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
348+
//
349+
// subject to the following valid NaN propagation semantics:
350+
// --------------------------------------------
351+
// | opNanMode | clampNanMode | resultNanMode |
352+
// |-----------|--------------|---------------|
353+
// | PROPAGATE | PROPAGATE | PROPAGATE |
354+
// | PROPAGATE | IGNORE | IGNORE |
355+
// | IGNORE | PROPAGATE | INVALID |
356+
// | IGNORE | IGNORE | IGNORE |
357+
// |------------------------------------------|
358+
342359
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
343360
using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
344361

362+
// Helper structure to describe the range of a clamp operation.
363+
template <typename T>
364+
struct ClampRange {
365+
ClampRange(const T &start, const T &end) : start(start), end(end) {}
366+
T start;
367+
T end;
368+
369+
// Helper function to determine if two Clamp ranges intersect.
370+
bool intersects(const ClampRange<T> &otherRange) {
371+
return start < otherRange.end && otherRange.start < end;
372+
}
373+
};
374+
345375
LogicalResult matchAndRewrite(tosa::ClampOp op,
346376
PatternRewriter &rewriter) const override {
347-
Value input = op.getInput();
348-
349-
Operation *definingOp = input.getDefiningOp();
350-
if (!definingOp)
377+
// Check the input to the CLAMP op is itself a CLAMP.
378+
auto clampOp =
379+
dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
380+
if (!clampOp)
351381
return failure();
352382

353-
if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
354-
auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
355-
auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
383+
// Check we have a valid NaN propagation combination.
384+
const auto opNanMode = op.getNanMode();
385+
const auto clampNanMode = clampOp.getNanMode();
386+
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
387+
return failure();
356388

357-
auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
358-
auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
389+
// Check we have intersecting ranges.
390+
const auto opMinInt = op.getMinInt();
391+
const auto opMaxInt = op.getMaxInt();
392+
const auto clampOpMinInt = clampOp.getMinInt();
393+
const auto clampOpMaxInt = clampOp.getMaxInt();
394+
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
395+
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
396+
if (!opRangeIntRange.intersects(clampRangeIntRange))
397+
return failure();
359398

360-
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
361-
op, op.getType(), clampOp.getInput(),
362-
rewriter.getI64IntegerAttr(minInt),
363-
rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
364-
rewriter.getF32FloatAttr(maxFp));
365-
return success();
366-
}
399+
const auto opMinFloat = op.getMinFp();
400+
const auto opMaxFloat = op.getMaxFp();
401+
const auto clampOpMinFloat = clampOp.getMinFp();
402+
const auto clampOpMaxFloat = clampOp.getMaxFp();
403+
ClampRange opRangeFloatRange(opMinFloat, opMaxFloat);
404+
ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
405+
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
406+
return failure();
367407

368-
return failure();
408+
// Run the transformation.
409+
const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
410+
const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
411+
const auto minInt = std::max(opMinInt, clampOpMinInt);
412+
const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
413+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
414+
op, op.getType(), clampOp.getInput(),
415+
rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
416+
rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
417+
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
418+
: opNanMode));
419+
return success();
369420
}
370421
};
371422

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
130130

131131
// CHECK-LABEL: @clamp_twice_is_single_clamp
132132
func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
133-
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
133+
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"}
134134
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8>
135135
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} : (tensor<4xi8>) -> tensor<4xi8>
136136
return %1 : tensor<4xi8>

0 commit comments

Comments
 (0)