Skip to content

Commit 18c4a2e

Browse files
committed
Avoid multiple calls to applyElementWise
* Merge lambdas that clamp to the upper and lower bound into a single one performing both * Add tests with clamp boundaries which cannot be represented in the type of the value to be clamped
1 parent 6497e4a commit 18c4a2e

File tree

2 files changed

+46
-41
lines changed

2 files changed

+46
-41
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "mlir/Pass/Pass.h"
1818
#include <llvm/ADT/APFloat.h>
1919
#include <llvm/ADT/APInt.h>
20-
#include <llvm/Support/Debug.h>
2120
#include <mlir/IR/BuiltinAttributes.h>
2221
#include <mlir/IR/BuiltinTypes.h>
2322
#include <mlir/Support/LogicalResult.h>
@@ -47,31 +46,24 @@ struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {
4746
auto comparisonWidth =
4847
std::max(inputValues.getElementType().getIntOrFloatBitWidth(),
4948
lowerBound.getBitWidth());
49+
// Sign-extend the upper and lower bound
50+
auto extUpperBound = upperBound.sext(comparisonWidth);
51+
auto extLowerBound = lowerBound.sext(comparisonWidth);
5052

53+
// Determine the result type
5154
auto resultingIntType = cast<IntegerType>(resultType.getElementType());
5255

53-
// Ensure that the value is larger than the lower bound
54-
auto clampLower = [&lowerBound, &comparisonWidth](const APInt &val,
55-
IntegerType type) {
56-
auto clampedLower = llvm::APIntOps::smax(
57-
val.sext(comparisonWidth), lowerBound.sext(comparisonWidth));
58-
// Make sure the output value has the correct type
59-
assert(type.getWidth() >= clampedLower.getSignificantBits());
60-
return clampedLower.trunc(type.getWidth());
56+
// Lambda to perform the clamp
57+
auto clampUpper = [&extLowerBound, &extUpperBound,
58+
&comparisonWidth](const APInt &val, IntegerType type) {
59+
auto clampedUpper =
60+
llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound);
61+
auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound);
62+
assert(type.getWidth() >= fullyClamped.getSignificantBits());
63+
return fullyClamped.trunc(type.getWidth());
6164
};
6265
auto newTensor = applyElementWise<APInt, APInt, IntegerType>(
63-
inputValues, clampLower, resultingIntType);
64-
65-
// Next, make sure the upper bound is adhered to
66-
auto clampUpper = [&upperBound, &comparisonWidth](const APInt &val,
67-
IntegerType type) {
68-
auto clampedUpper = llvm::APIntOps::smin(
69-
val.sext(comparisonWidth), upperBound.sext(comparisonWidth));
70-
assert(type.getWidth() >= clampedUpper.getSignificantBits());
71-
return clampedUpper.trunc(type.getWidth());
72-
};
73-
newTensor = applyElementWise<APInt, APInt, IntegerType>(
74-
newTensor, clampUpper, resultingIntType);
66+
inputValues, clampUpper, resultingIntType);
7567

7668
return newTensor;
7769
}
@@ -91,34 +83,22 @@ struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {
9183

9284
auto resultingFloatType = cast<FloatType>(resultType.getElementType());
9385

94-
// Ensure that the value is larger than the lower bound
95-
auto clampLower = [&lowerBound, &comparisonSem](APFloat val,
96-
FloatType type) {
86+
// Ensure that the value is larger than the lower bound and smaller than the
87+
// upper bound
88+
auto clampLower = [&lowerBound, &upperBound,
89+
&comparisonSem](APFloat val, FloatType type) {
9790
if (val.isNaN()) {
9891
return APFloat::getNaN(type.getFloatSemantics());
9992
}
10093
changeSemanticsLossless(val, comparisonSem);
101-
auto clampedLower = val < lowerBound ? lowerBound : val;
102-
changeSemanticsLossless(clampedLower, &type.getFloatSemantics());
103-
return clampedLower;
94+
auto clampedUpper = val < upperBound ? val : upperBound;
95+
auto fullyClamped = clampedUpper < lowerBound ? lowerBound : clampedUpper;
96+
changeSemanticsLossless(fullyClamped, &type.getFloatSemantics());
97+
return fullyClamped;
10498
};
10599
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
106100
inputValues, clampLower, resultingFloatType);
107101

108-
// Next, make sure the upper bound is adhered to
109-
auto clampUpper = [&upperBound, &comparisonSem](APFloat val,
110-
FloatType type) {
111-
if (val.isNaN()) {
112-
return APFloat::getNaN(type.getFloatSemantics());
113-
}
114-
changeSemanticsLossless(val, comparisonSem);
115-
auto clampedUpper = val < upperBound ? val : upperBound;
116-
changeSemanticsLossless(clampedUpper, &type.getFloatSemantics());
117-
return clampedUpper;
118-
};
119-
newTensor = applyElementWise<APFloat, APFloat, FloatType>(
120-
newTensor, clampUpper, resultingFloatType);
121-
122102
return newTensor;
123103
}
124104

mlir/test/Dialect/Tosa/constant-clamp-opt.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> {
2424
return %1 : tensor<3xi8>
2525
}
2626

27+
// CHECK-LABEL: @clamp_fold_integer_maximum_larger_than_result_type
28+
func.func @clamp_fold_integer_maximum_larger_than_result_type() -> tensor<3xi8> {
29+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9, 4, 4{{.*}}tensor<3xi8>
30+
// CHECK-NOT: tosa.clamp
31+
// CHECK: return [[RES]]
32+
%0 = "tosa.const"() {value = dense<[9, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8>
33+
%1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, min_int = 4 : i64}
34+
: (tensor<3xi8>) -> tensor<3xi8>
35+
return %1 : tensor<3xi8>
36+
}
37+
2738
// Float clamp
2839

2940
// CHECK-LABEL: @clamp_fold_float
@@ -64,3 +75,17 @@ func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> {
6475
: (tensor<5xf32>) -> tensor<5xf32>
6576
return %1 : tensor<5xf32>
6677
}
78+
79+
// CHECK-LABEL: @clamp_fold_float_maximum_larger_than_result_type
80+
func.func @clamp_fold_float_maximum_larger_than_result_type() -> tensor<2xf16> {
81+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.83{{[0-9]*}}e+01, -5.{{0*}}e-01
82+
// CHECK-NOT: tosa.clamp
83+
// CHECK: return [[RES]]
84+
%0 = "tosa.const"() {value =
85+
dense<[18.32, -0.98747]> :
86+
tensor<2xf16>
87+
} : () -> tensor<2xf16>
88+
%1 = "tosa.clamp"(%0) {max_fp = 3.4028234e+38 : f32, max_int = 1594 : i64, min_fp = -0.5 : f32, min_int = -17 : i64}
89+
: (tensor<2xf16>) -> tensor<2xf16>
90+
return %1 : tensor<2xf16>
91+
}

0 commit comments

Comments
 (0)