Skip to content

Commit c697ece

Browse files
authored
Merge pull request #25 from Xilinx/tina.tosaclampfolding
[FXML-1931] Implement constant clamp folding
2 parents 45db2af + 44349a7 commit c697ece

File tree

9 files changed

+290
-9
lines changed

9 files changed

+290
-9
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ void populateTosaDecomposeDepthwise(MLIRContext *ctx,
3232
RewritePatternSet &patterns);
3333
void populateTosaFoldConstantAddPatterns(MLIRContext *ctx,
3434
RewritePatternSet &patterns);
35+
void populateTosaFoldConstantClampPatterns(MLIRContext *ctx,
36+
RewritePatternSet &patterns);
3537
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
3638
RewritePatternSet &patterns,
3739
bool enableIntCastFolding);

mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,16 @@ bool constantBinaryOpShouldBeFolded(TosaOp binaryOp,
9393
DenseElementsAttr valuesFirst,
9494
DenseElementsAttr valuesSecond);
9595

96+
/// Heuristic to decide when to replace a unary operation on a constant with the
97+
/// folded value.
98+
/// Folding operations on constants can lead to an increased memory usage
99+
/// whenever the input cannot be replaced but a new constant is inserted. Hence,
100+
/// this will currently only suggest folding when the memory impact is
101+
/// negligible.
102+
/// Takes the \p unaryOp and the constant input \p values.
103+
/// \returns Whether folding should be applied.
104+
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values);
105+
96106
/// Function to compute the reciprocal.
97107
APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);
98108

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
55
TosaFoldCommon.cpp
66
TosaFoldConstantAdd.cpp
77
TosaFoldConstantCast.cpp
8+
TosaFoldConstantClamp.cpp
89
TosaFoldConstantPow.cpp
910
TosaFoldConstantReciprocal.cpp
1011
TosaFoldConstantRSQRT.cpp

mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,21 @@ bool mlir::tosa::constantBinaryOpShouldBeFolded(
243243
return firstOp == secondOp && numUsers == 2;
244244
}
245245

246+
bool mlir::tosa::constantUnaryOpShouldBeFolded(TosaOp unaryOp,
247+
DenseElementsAttr values) {
248+
assert(unaryOp->getNumOperands() == 1);
249+
auto inputOp = unaryOp->getOperand(0);
250+
251+
// If the input is a splat, we don't care for the number of users
252+
if (isa<SplatElementsAttr>(values)) {
253+
return true;
254+
}
255+
256+
// If this is the only use of the tensor it will be replaced and no
257+
// additional memory is required.
258+
return inputOp.hasOneUse();
259+
}
260+
246261
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
247262
FloatType floatTy) {
248263
auto recipAttr = FloatAttr::get(floatTy, 1.0);
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
//===- TosaFoldConstantClamp.cpp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA Clamp operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <llvm/ADT/APFloat.h>
19+
#include <llvm/ADT/APInt.h>
20+
#include <mlir/IR/BuiltinAttributes.h>
21+
#include <mlir/IR/BuiltinTypes.h>
22+
#include <mlir/Support/LogicalResult.h>
23+
24+
using namespace mlir;
25+
using namespace mlir::tosa;
26+
27+
namespace {
28+
29+
struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {
30+
31+
using OpRewritePattern::OpRewritePattern;
32+
33+
static void
34+
changeSemanticsLossless(APFloat &floatVal,
35+
const llvm::fltSemantics *floatSemantics) {
36+
bool losesInfo;
37+
floatVal.convert(*floatSemantics, tosaRoundingMode, &losesInfo);
38+
assert(!losesInfo);
39+
}
40+
41+
DenseElementsAttr applyClamp(DenseElementsAttr inputValues,
42+
const APInt &lowerBound, const APInt &upperBound,
43+
TensorType resultType) const {
44+
45+
// Determine the width for the APInt comparison
46+
auto comparisonWidth =
47+
std::max(inputValues.getElementType().getIntOrFloatBitWidth(),
48+
lowerBound.getBitWidth());
49+
// Sign-extend the upper and lower bound
50+
auto extUpperBound = upperBound.sext(comparisonWidth);
51+
auto extLowerBound = lowerBound.sext(comparisonWidth);
52+
53+
// Determine the result type
54+
auto resultingIntType = cast<IntegerType>(resultType.getElementType());
55+
56+
// Lambda to perform the clamp
57+
auto clampFun = [&extLowerBound, &extUpperBound,
58+
&comparisonWidth](const APInt &val, IntegerType type) {
59+
auto clampedUpper =
60+
llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound);
61+
auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound);
62+
assert(type.getWidth() >= fullyClamped.getSignificantBits());
63+
return fullyClamped.trunc(type.getWidth());
64+
};
65+
auto newTensor = applyElementWise<APInt, APInt, IntegerType>(
66+
inputValues, clampFun, resultingIntType);
67+
68+
return newTensor;
69+
}
70+
71+
DenseElementsAttr applyClamp(DenseElementsAttr inputValues,
72+
APFloat lowerBound, APFloat upperBound,
73+
TensorType resultType) const {
74+
auto inputValType = cast<FloatType>(inputValues.getElementType());
75+
auto inputWidth = inputValType.getWidth();
76+
auto bWidth = APFloat::semanticsSizeInBits(lowerBound.getSemantics());
77+
auto *comparisonSem = inputWidth < bWidth
78+
? &lowerBound.getSemantics()
79+
: &inputValType.getFloatSemantics();
80+
81+
changeSemanticsLossless(lowerBound, comparisonSem);
82+
changeSemanticsLossless(upperBound, comparisonSem);
83+
84+
auto resultingFloatType = cast<FloatType>(resultType.getElementType());
85+
86+
// Ensure that the value is larger than the lower bound and smaller than the
87+
// upper bound
88+
auto clampFun = [&lowerBound, &upperBound, &comparisonSem](APFloat val,
89+
FloatType type) {
90+
if (val.isNaN()) {
91+
return APFloat::getNaN(type.getFloatSemantics());
92+
}
93+
changeSemanticsLossless(val, comparisonSem);
94+
auto clampedUpper = val < upperBound ? val : upperBound;
95+
auto fullyClamped = clampedUpper < lowerBound ? lowerBound : clampedUpper;
96+
changeSemanticsLossless(fullyClamped, &type.getFloatSemantics());
97+
return fullyClamped;
98+
};
99+
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
100+
inputValues, clampFun, resultingFloatType);
101+
102+
return newTensor;
103+
}
104+
105+
LogicalResult matchAndRewrite(ClampOp clampOp,
106+
PatternRewriter &rewriter) const override {
107+
auto valsToClamp = clampOp.getInput();
108+
auto inputElementType = valsToClamp.getType().getElementType();
109+
110+
// Check if the input is constant
111+
if (failed(notifyIfNoTosaDenseConstantTensor(valsToClamp, clampOp,
112+
rewriter))) {
113+
return failure();
114+
}
115+
116+
if (isa<IntegerType>(inputElementType) &&
117+
cast<IntegerType>(inputElementType).isUnsigned()) {
118+
return rewriter.notifyMatchFailure(
119+
clampOp, "Currently, unsigned integer clamps are unsupported.");
120+
}
121+
122+
// Extract the tensor values
123+
DenseElementsAttr inputValues;
124+
matchPattern(valsToClamp, m_Constant(&inputValues));
125+
126+
if (!constantUnaryOpShouldBeFolded(clampOp, inputValues)) {
127+
return rewriter.notifyMatchFailure(
128+
clampOp,
129+
"Currently, clamps will only be folded if this requires only "
130+
"little additional memory usage.");
131+
}
132+
133+
// Apply the clamp to all values of the int/float tensor
134+
auto resultType = clampOp.getType();
135+
DenseElementsAttr newTensor;
136+
if (isa<IntegerType>(inputElementType)) {
137+
auto lowerBoundVal = clampOp.getMinIntAttr().getValue();
138+
auto upperBoundVal = clampOp.getMaxIntAttr().getValue();
139+
assert(lowerBoundVal.getBitWidth() == upperBoundVal.getBitWidth());
140+
141+
newTensor =
142+
applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType);
143+
} else {
144+
assert(isa<FloatType>(inputElementType));
145+
auto lowerBoundVal = clampOp.getMinFp();
146+
auto upperBoundVal = clampOp.getMaxFp();
147+
assert(APFloat::getSizeInBits(lowerBoundVal.getSemantics()) ==
148+
APFloat::getSizeInBits(upperBoundVal.getSemantics()));
149+
150+
newTensor =
151+
applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType);
152+
}
153+
154+
rewriter.replaceOpWithNewOp<ConstOp>(clampOp, newTensor.getType(),
155+
newTensor);
156+
157+
return success();
158+
}
159+
};
160+
161+
} // namespace
162+
163+
void mlir::tosa::populateTosaFoldConstantClampPatterns(
164+
MLIRContext *ctx, RewritePatternSet &patterns) {
165+
patterns.add<TosaFoldConstantClamp>(ctx);
166+
}

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
6363
DenseElementsAttr inputValues;
6464
matchPattern(inputTensor, m_Constant(&inputValues));
6565

66-
// Only fold splat tensors and those used only once to avoid duplicating
67-
// them.
68-
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(inputValues)) {
66+
// Check whether this should be folded.
67+
if (!constantUnaryOpShouldBeFolded(rsqrt, inputValues)) {
6968
return rewriter.notifyMatchFailure(
7069
rsqrt, "Currently, reciprocals will only be folded if the input "
7170
"tensor has a single user");

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,8 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
4545
DenseElementsAttr inputValues;
4646
matchPattern(inputTensor, m_Constant(&inputValues));
4747

48-
// Our transformation replaces the input tensor with the transformed tensor.
49-
// If the input has several users we need to keep the input. This can
50-
// result in a significantly increased memory usage, such that we currently
51-
// refrain from applying the transformation in that case.
52-
// Allow this only for splat values, because the amount of data is small.
53-
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(inputValues)) {
48+
// Check whether this should be folded.
49+
if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
5450
return rewriter.notifyMatchFailure(
5551
recip, "Currently, reciprocals will only be folded if the input "
5652
"tensor has a single user");

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ struct TosaLayerwiseConstantFoldPass
5353
mlir::tosa::populateTosaFoldConstantAddPatterns(ctx, patterns);
5454
mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
5555
enableIntCastFolding);
56+
mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns);
5657
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
5758
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
5859
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
2+
3+
// Int clamp
4+
5+
// CHECK-LABEL: @clamp_fold_integer
6+
func.func @clamp_fold_integer() -> tensor<3xi16> {
7+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2, 0, 1{{.*}}tensor<3xi16>
8+
// CHECK-NOT: tosa.clamp
9+
// CHECK: return [[RES]]
10+
%0 = "tosa.const"() {value = dense<[-12, 0, 5]> : tensor<3xi16>} : () -> tensor<3xi16>
11+
%1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = -2 : i64}
12+
: (tensor<3xi16>) -> tensor<3xi16>
13+
return %1 : tensor<3xi16>
14+
}
15+
16+
// CHECK-LABEL: @clamp_fold_integer_equal_lower_upper
17+
func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> {
18+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<17>{{.*}}tensor<3xi8>
19+
// CHECK-NOT: tosa.clamp
20+
// CHECK: return [[RES]]
21+
%0 = "tosa.const"() {value = dense<[2, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8>
22+
%1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 17 : i64, min_fp = 0.0 : f32, min_int = 17 : i64}
23+
: (tensor<3xi8>) -> tensor<3xi8>
24+
return %1 : tensor<3xi8>
25+
}
26+
27+
// CHECK-LABEL: @clamp_fold_integer_maximum_larger_than_result_type
28+
func.func @clamp_fold_integer_maximum_larger_than_result_type() -> tensor<3xi8> {
29+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9, 4, 4{{.*}}tensor<3xi8>
30+
// CHECK-NOT: tosa.clamp
31+
// CHECK: return [[RES]]
32+
%0 = "tosa.const"() {value = dense<[9, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8>
33+
%1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, min_int = 4 : i64}
34+
: (tensor<3xi8>) -> tensor<3xi8>
35+
return %1 : tensor<3xi8>
36+
}
37+
38+
// Float clamp
39+
40+
// CHECK-LABEL: @clamp_fold_float
41+
func.func @clamp_fold_float() -> tensor<3xf16> {
42+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2.{{0*}}e+00, {{[8-9]}}.{{[0-9]*}}e-01, 1.{{0*}}e+00{{.*}}tensor<3xf16>
43+
// CHECK-NOT: tosa.clamp
44+
// CHECK: return [[RES]]
45+
%0 = "tosa.const"() {value = dense<[-12.4, 0.9, 5.2]> : tensor<3xf16>} : () -> tensor<3xf16>
46+
%1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64}
47+
: (tensor<3xf16>) -> tensor<3xf16>
48+
return %1 : tensor<3xf16>
49+
}
50+
51+
// CHECK-LABEL: @clamp_fold_float_infty_nan
52+
func.func @clamp_fold_float_infty_nan() -> tensor<5xf32> {
53+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.{{0*}}e+00, -2.{{0*}}e+00, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32>
54+
// CHECK-NOT: tosa.clamp
55+
// CHECK: return [[RES]]
56+
%0 = "tosa.const"() {value =
57+
dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> :
58+
tensor<5xf32>
59+
} : () -> tensor<5xf32>
60+
%1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64}
61+
: (tensor<5xf32>) -> tensor<5xf32>
62+
return %1 : tensor<5xf32>
63+
}
64+
65+
// CHECK-LABEL: @clamp_fold_float_infinity_upper
66+
func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> {
67+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, -2.{{0*}}e+00, 9.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32>
68+
// CHECK-NOT: tosa.clamp
69+
// CHECK: return [[RES]]
70+
%0 = "tosa.const"() {value =
71+
dense<[0x7F800000, 0xFF800000, 9.0, -0.0, 0x7FC00000]> :
72+
tensor<5xf32>
73+
} : () -> tensor<5xf32>
74+
%1 = "tosa.clamp"(%0) {max_fp = 0x7F800000 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64}
75+
: (tensor<5xf32>) -> tensor<5xf32>
76+
return %1 : tensor<5xf32>
77+
}
78+
79+
// CHECK-LABEL: @clamp_fold_float_maximum_larger_than_result_type
80+
func.func @clamp_fold_float_maximum_larger_than_result_type() -> tensor<2xf16> {
81+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.83{{[0-9]*}}e+01, -5.{{0*}}e-01
82+
// CHECK-NOT: tosa.clamp
83+
// CHECK: return [[RES]]
84+
%0 = "tosa.const"() {value =
85+
dense<[18.32, -0.98747]> :
86+
tensor<2xf16>
87+
} : () -> tensor<2xf16>
88+
%1 = "tosa.clamp"(%0) {max_fp = 3.4028234e+38 : f32, max_int = 1594 : i64, min_fp = -0.5 : f32, min_int = -17 : i64}
89+
: (tensor<2xf16>) -> tensor<2xf16>
90+
return %1 : tensor<2xf16>
91+
}

0 commit comments

Comments
 (0)