17
17
#include " mlir/Pass/Pass.h"
18
18
#include < llvm/ADT/APFloat.h>
19
19
#include < llvm/ADT/APInt.h>
20
- #include < llvm/Support/Debug.h>
21
20
#include < mlir/IR/BuiltinAttributes.h>
22
21
#include < mlir/IR/BuiltinTypes.h>
23
22
#include < mlir/Support/LogicalResult.h>
@@ -47,31 +46,24 @@ struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {
47
46
auto comparisonWidth =
48
47
std::max (inputValues.getElementType ().getIntOrFloatBitWidth (),
49
48
lowerBound.getBitWidth ());
49
+ // Sign-extend the upper and lower bound
50
+ auto extUpperBound = upperBound.sext (comparisonWidth);
51
+ auto extLowerBound = lowerBound.sext (comparisonWidth);
50
52
53
+ // Determine the result type
51
54
auto resultingIntType = cast<IntegerType>(resultType.getElementType ());
52
55
53
- // Ensure that the value is larger than the lower bound
54
- auto clampLower = [&lowerBound , &comparisonWidth]( const APInt &val ,
55
- IntegerType type) {
56
- auto clampedLower = llvm::APIntOps::smax (
57
- val.sext (comparisonWidth), lowerBound. sext (comparisonWidth) );
58
- // Make sure the output value has the correct type
59
- assert (type.getWidth () >= clampedLower .getSignificantBits ());
60
- return clampedLower .trunc (type.getWidth ());
56
+ // Lambda to perform the clamp
57
+ auto clampUpper = [&extLowerBound , &extUpperBound ,
58
+ &comparisonWidth]( const APInt &val, IntegerType type) {
59
+ auto clampedUpper =
60
+ llvm::APIntOps::smin ( val.sext (comparisonWidth), extUpperBound );
61
+ auto fullyClamped = llvm::APIntOps::smax (clampedUpper, extLowerBound);
62
+ assert (type.getWidth () >= fullyClamped .getSignificantBits ());
63
+ return fullyClamped .trunc (type.getWidth ());
61
64
};
62
65
auto newTensor = applyElementWise<APInt, APInt, IntegerType>(
63
- inputValues, clampLower, resultingIntType);
64
-
65
- // Next, make sure the upper bound is adhered to
66
- auto clampUpper = [&upperBound, &comparisonWidth](const APInt &val,
67
- IntegerType type) {
68
- auto clampedUpper = llvm::APIntOps::smin (
69
- val.sext (comparisonWidth), upperBound.sext (comparisonWidth));
70
- assert (type.getWidth () >= clampedUpper.getSignificantBits ());
71
- return clampedUpper.trunc (type.getWidth ());
72
- };
73
- newTensor = applyElementWise<APInt, APInt, IntegerType>(
74
- newTensor, clampUpper, resultingIntType);
66
+ inputValues, clampUpper, resultingIntType);
75
67
76
68
return newTensor;
77
69
}
@@ -91,34 +83,22 @@ struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {
91
83
92
84
auto resultingFloatType = cast<FloatType>(resultType.getElementType ());
93
85
94
- // Ensure that the value is larger than the lower bound
95
- auto clampLower = [&lowerBound, &comparisonSem](APFloat val,
96
- FloatType type) {
86
+ // Ensure that the value is larger than the lower bound and smaller than the
87
+ // upper bound
88
+ auto clampLower = [&lowerBound, &upperBound,
89
+ &comparisonSem](APFloat val, FloatType type) {
97
90
if (val.isNaN ()) {
98
91
return APFloat::getNaN (type.getFloatSemantics ());
99
92
}
100
93
changeSemanticsLossless (val, comparisonSem);
101
- auto clampedLower = val < lowerBound ? lowerBound : val;
102
- changeSemanticsLossless (clampedLower, &type.getFloatSemantics ());
103
- return clampedLower;
94
+ auto clampedUpper = val < upperBound ? val : upperBound;
95
+ auto fullyClamped = clampedUpper < lowerBound ? lowerBound : clampedUpper;
96
+ changeSemanticsLossless (fullyClamped, &type.getFloatSemantics ());
97
+ return fullyClamped;
104
98
};
105
99
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
106
100
inputValues, clampLower, resultingFloatType);
107
101
108
- // Next, make sure the upper bound is adhered to
109
- auto clampUpper = [&upperBound, &comparisonSem](APFloat val,
110
- FloatType type) {
111
- if (val.isNaN ()) {
112
- return APFloat::getNaN (type.getFloatSemantics ());
113
- }
114
- changeSemanticsLossless (val, comparisonSem);
115
- auto clampedUpper = val < upperBound ? val : upperBound;
116
- changeSemanticsLossless (clampedUpper, &type.getFloatSemantics ());
117
- return clampedUpper;
118
- };
119
- newTensor = applyElementWise<APFloat, APFloat, FloatType>(
120
- newTensor, clampUpper, resultingFloatType);
121
-
122
102
return newTensor;
123
103
}
124
104
0 commit comments