Skip to content

Commit 6497e4a

Browse files
committed
Implement constant clamp folding
* Introduce a global heuristic for deciding when to fold unary operators
* Add folding of constant clamps, plus a test case
1 parent aa8d96a commit 6497e4a

File tree

9 files changed

+285
-9
lines changed

9 files changed

+285
-9
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ void populateTosaDecomposeDepthwise(MLIRContext *ctx,
3232
RewritePatternSet &patterns);
3333
void populateTosaFoldConstantAddPatterns(MLIRContext *ctx,
3434
RewritePatternSet &patterns);
35+
void populateTosaFoldConstantClampPatterns(MLIRContext *ctx,
36+
RewritePatternSet &patterns);
3537
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
3638
RewritePatternSet &patterns,
3739
bool enableIntCastFolding);

mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,16 @@ bool constantBinaryOpShouldBeFolded(TosaOp binaryOp,
9393
DenseElementsAttr valuesFirst,
9494
DenseElementsAttr valuesSecond);
9595

96+
/// Heuristic to decide when to replace a unary operation on a constant with the
97+
/// folded value.
98+
/// Folding operations on constants can lead to an increased memory usage
99+
/// whenever the input cannot be replaced but a new constant is inserted. Hence,
100+
/// this will currently only suggest folding when the memory impact is
101+
/// negligible.
102+
/// Takes the \p unaryOp and the constant input \p values.
103+
/// \returns Whether folding should be applied.
104+
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values);
105+
96106
/// Function to compute the reciprocal.
97107
APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);
98108

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
55
TosaFoldCommon.cpp
66
TosaFoldConstantAdd.cpp
77
TosaFoldConstantCast.cpp
8+
TosaFoldConstantClamp.cpp
89
TosaFoldConstantPow.cpp
910
TosaFoldConstantReciprocal.cpp
1011
TosaFoldConstantRSQRT.cpp

mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,21 @@ bool mlir::tosa::constantBinaryOpShouldBeFolded(
243243
return firstOp == secondOp && numUsers == 2;
244244
}
245245

246+
/// Heuristic deciding whether folding a unary op on the constant \p values
/// pays off. Folding inserts a new constant; if the original input cannot be
/// removed, memory usage grows. Hence only fold when the impact is negligible.
bool mlir::tosa::constantUnaryOpShouldBeFolded(TosaOp unaryOp,
                                               DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // Fold if either the input is a splat (duplicating it is cheap regardless
  // of the number of users), or this op is the input's only user, in which
  // case the input tensor is replaced and no additional memory is required.
  return isa<SplatElementsAttr>(values) || inputOp.hasOneUse();
}
260+
246261
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
247262
FloatType floatTy) {
248263
auto recipAttr = FloatAttr::get(floatTy, 1.0);
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
//===- TosaFoldConstantClamp.cpp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA Clamp operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <llvm/ADT/APFloat.h>
19+
#include <llvm/ADT/APInt.h>
20+
#include <llvm/Support/Debug.h>
21+
#include <mlir/IR/BuiltinAttributes.h>
22+
#include <mlir/IR/BuiltinTypes.h>
23+
#include <mlir/Support/LogicalResult.h>
24+
25+
using namespace mlir;
26+
using namespace mlir::tosa;
27+
28+
namespace {
29+
30+
struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {
31+
32+
using OpRewritePattern::OpRewritePattern;
33+
34+
static void
35+
changeSemanticsLossless(APFloat &floatVal,
36+
const llvm::fltSemantics *floatSemantics) {
37+
bool losesInfo;
38+
floatVal.convert(*floatSemantics, tosaRoundingMode, &losesInfo);
39+
assert(!losesInfo);
40+
}
41+
42+
DenseElementsAttr applyClamp(DenseElementsAttr inputValues,
43+
const APInt &lowerBound, const APInt &upperBound,
44+
TensorType resultType) const {
45+
46+
// Determine the width for the APInt comparison
47+
auto comparisonWidth =
48+
std::max(inputValues.getElementType().getIntOrFloatBitWidth(),
49+
lowerBound.getBitWidth());
50+
51+
auto resultingIntType = cast<IntegerType>(resultType.getElementType());
52+
53+
// Ensure that the value is larger than the lower bound
54+
auto clampLower = [&lowerBound, &comparisonWidth](const APInt &val,
55+
IntegerType type) {
56+
auto clampedLower = llvm::APIntOps::smax(
57+
val.sext(comparisonWidth), lowerBound.sext(comparisonWidth));
58+
// Make sure the output value has the correct type
59+
assert(type.getWidth() >= clampedLower.getSignificantBits());
60+
return clampedLower.trunc(type.getWidth());
61+
};
62+
auto newTensor = applyElementWise<APInt, APInt, IntegerType>(
63+
inputValues, clampLower, resultingIntType);
64+
65+
// Next, make sure the upper bound is adhered to
66+
auto clampUpper = [&upperBound, &comparisonWidth](const APInt &val,
67+
IntegerType type) {
68+
auto clampedUpper = llvm::APIntOps::smin(
69+
val.sext(comparisonWidth), upperBound.sext(comparisonWidth));
70+
assert(type.getWidth() >= clampedUpper.getSignificantBits());
71+
return clampedUpper.trunc(type.getWidth());
72+
};
73+
newTensor = applyElementWise<APInt, APInt, IntegerType>(
74+
newTensor, clampUpper, resultingIntType);
75+
76+
return newTensor;
77+
}
78+
79+
DenseElementsAttr applyClamp(DenseElementsAttr inputValues,
80+
APFloat lowerBound, APFloat upperBound,
81+
TensorType resultType) const {
82+
auto inputValType = cast<FloatType>(inputValues.getElementType());
83+
auto inputWidth = inputValType.getWidth();
84+
auto bWidth = APFloat::semanticsSizeInBits(lowerBound.getSemantics());
85+
auto *comparisonSem = inputWidth < bWidth
86+
? &lowerBound.getSemantics()
87+
: &inputValType.getFloatSemantics();
88+
89+
changeSemanticsLossless(lowerBound, comparisonSem);
90+
changeSemanticsLossless(upperBound, comparisonSem);
91+
92+
auto resultingFloatType = cast<FloatType>(resultType.getElementType());
93+
94+
// Ensure that the value is larger than the lower bound
95+
auto clampLower = [&lowerBound, &comparisonSem](APFloat val,
96+
FloatType type) {
97+
if (val.isNaN()) {
98+
return APFloat::getNaN(type.getFloatSemantics());
99+
}
100+
changeSemanticsLossless(val, comparisonSem);
101+
auto clampedLower = val < lowerBound ? lowerBound : val;
102+
changeSemanticsLossless(clampedLower, &type.getFloatSemantics());
103+
return clampedLower;
104+
};
105+
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
106+
inputValues, clampLower, resultingFloatType);
107+
108+
// Next, make sure the upper bound is adhered to
109+
auto clampUpper = [&upperBound, &comparisonSem](APFloat val,
110+
FloatType type) {
111+
if (val.isNaN()) {
112+
return APFloat::getNaN(type.getFloatSemantics());
113+
}
114+
changeSemanticsLossless(val, comparisonSem);
115+
auto clampedUpper = val < upperBound ? val : upperBound;
116+
changeSemanticsLossless(clampedUpper, &type.getFloatSemantics());
117+
return clampedUpper;
118+
};
119+
newTensor = applyElementWise<APFloat, APFloat, FloatType>(
120+
newTensor, clampUpper, resultingFloatType);
121+
122+
return newTensor;
123+
}
124+
125+
LogicalResult matchAndRewrite(ClampOp clampOp,
126+
PatternRewriter &rewriter) const override {
127+
auto valsToClamp = clampOp.getInput();
128+
auto inputElementType = valsToClamp.getType().getElementType();
129+
130+
// Check if the input is constant
131+
if (failed(notifyIfNoTosaDenseConstantTensor(valsToClamp, clampOp,
132+
rewriter))) {
133+
return failure();
134+
}
135+
136+
if (isa<IntegerType>(inputElementType) &&
137+
cast<IntegerType>(inputElementType).isUnsigned()) {
138+
return rewriter.notifyMatchFailure(
139+
clampOp, "Currently, unsigned integer clamps are unsupported.");
140+
}
141+
142+
// Extract the tensor values
143+
DenseElementsAttr inputValues;
144+
matchPattern(valsToClamp, m_Constant(&inputValues));
145+
146+
if (!constantUnaryOpShouldBeFolded(clampOp, inputValues)) {
147+
return rewriter.notifyMatchFailure(
148+
clampOp,
149+
"Currently, clamps will only be folded if this requires only "
150+
"little additional memory usage.");
151+
}
152+
153+
// Apply the clamp to all values of the int/float tensor
154+
auto resultType = clampOp.getType();
155+
DenseElementsAttr newTensor;
156+
if (isa<IntegerType>(inputElementType)) {
157+
auto lowerBoundVal = clampOp.getMinIntAttr().getValue();
158+
auto upperBoundVal = clampOp.getMaxIntAttr().getValue();
159+
assert(lowerBoundVal.getBitWidth() == upperBoundVal.getBitWidth());
160+
161+
newTensor =
162+
applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType);
163+
} else {
164+
assert(isa<FloatType>(inputElementType));
165+
auto lowerBoundVal = clampOp.getMinFp();
166+
auto upperBoundVal = clampOp.getMaxFp();
167+
assert(APFloat::getSizeInBits(lowerBoundVal.getSemantics()) ==
168+
APFloat::getSizeInBits(upperBoundVal.getSemantics()));
169+
170+
newTensor =
171+
applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType);
172+
}
173+
174+
rewriter.replaceOpWithNewOp<ConstOp>(clampOp, newTensor.getType(),
175+
newTensor);
176+
177+
return success();
178+
}
179+
};
180+
181+
} // namespace
182+
183+
/// Registers the constant-clamp folding pattern in \p patterns.
void mlir::tosa::populateTosaFoldConstantClampPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantClamp>(ctx);
}

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
6363
DenseElementsAttr inputValues;
6464
matchPattern(inputTensor, m_Constant(&inputValues));
6565

66-
// Only fold splat tensors and those used only once to avoid duplicating
67-
// them.
68-
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(inputValues)) {
66+
// Check whether this should be folded.
67+
if (!constantUnaryOpShouldBeFolded(rsqrt, inputValues)) {
6968
return rewriter.notifyMatchFailure(
7069
rsqrt, "Currently, reciprocals will only be folded if the input "
7170
"tensor has a single user");

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,8 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
4545
DenseElementsAttr inputValues;
4646
matchPattern(inputTensor, m_Constant(&inputValues));
4747

48-
// Our transformation replaces the input tensor with the transformed tensor.
49-
// If the input has several users we need to keep the input. This can
50-
// result in a significantly increased memory usage, such that we currently
51-
// refrain from applying the transformation in that case.
52-
// Allow this only for splat values, because the amount of data is small.
53-
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(inputValues)) {
48+
// Check whether this should be folded.
49+
if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
5450
return rewriter.notifyMatchFailure(
5551
recip, "Currently, reciprocals will only be folded if the input "
5652
"tensor has a single user");

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ struct TosaLayerwiseConstantFoldPass
5353
mlir::tosa::populateTosaFoldConstantAddPatterns(ctx, patterns);
5454
mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
5555
enableIntCastFolding);
56+
mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns);
5657
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
5758
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
5859
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s

// Int clamp

// -12 is raised to min_int = -2, 5 is lowered to max_int = 1, 0 is unchanged.
// CHECK-LABEL: @clamp_fold_integer
func.func @clamp_fold_integer() -> tensor<3xi16> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2, 0, 1{{.*}}tensor<3xi16>
  // CHECK-NOT: tosa.clamp
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<[-12, 0, 5]> : tensor<3xi16>} : () -> tensor<3xi16>
  %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = -2 : i64}
          : (tensor<3xi16>) -> tensor<3xi16>
  return %1 : tensor<3xi16>
}

// With min_int == max_int every element collapses to that value (splat).
// CHECK-LABEL: @clamp_fold_integer_equal_lower_upper
func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<17>{{.*}}tensor<3xi8>
  // CHECK-NOT: tosa.clamp
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<[2, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8>
  %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 17 : i64, min_fp = 0.0 : f32, min_int = 17 : i64}
          : (tensor<3xi8>) -> tensor<3xi8>
  return %1 : tensor<3xi8>
}

// Float clamp

// CHECK-LABEL: @clamp_fold_float
func.func @clamp_fold_float() -> tensor<3xf16> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2.{{0*}}e+00, {{[8-9]}}.{{[0-9]*}}e-01, 1.{{0*}}e+00{{.*}}tensor<3xf16>
  // CHECK-NOT: tosa.clamp
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value = dense<[-12.4, 0.9, 5.2]> : tensor<3xf16>} : () -> tensor<3xf16>
  %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64}
          : (tensor<3xf16>) -> tensor<3xf16>
  return %1 : tensor<3xf16>
}

// +inf clamps to the upper bound, -inf to the lower bound; NaN is preserved.
// CHECK-LABEL: @clamp_fold_float_infty_nan
func.func @clamp_fold_float_infty_nan() -> tensor<5xf32> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.{{0*}}e+00, -2.{{0*}}e+00, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32>
  // CHECK-NOT: tosa.clamp
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
        dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> :
        tensor<5xf32>
      } : () -> tensor<5xf32>
  %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64}
          : (tensor<5xf32>) -> tensor<5xf32>
  return %1 : tensor<5xf32>
}

// An infinite upper bound leaves +inf unchanged.
// CHECK-LABEL: @clamp_fold_float_infinity_upper
func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> {
  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, -2.{{0*}}e+00, 9.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32>
  // CHECK-NOT: tosa.clamp
  // CHECK: return [[RES]]
  %0 = "tosa.const"() {value =
        dense<[0x7F800000, 0xFF800000, 9.0, -0.0, 0x7FC00000]> :
        tensor<5xf32>
      } : () -> tensor<5xf32>
  %1 = "tosa.clamp"(%0) {max_fp = 0x7F800000 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64}
          : (tensor<5xf32>) -> tensor<5xf32>
  return %1 : tensor<5xf32>
}

0 commit comments

Comments
 (0)