Skip to content

Commit 870ac13

Browse files
TinaAMDGitHub Enterprise
authored and committed
Merge pull request #8 from ACT/tina.tosarsqrtfolding
[FXML-1727] Implement RSQRT and refactor reciprocal
2 parents 999ab5f + 1152911 commit 870ac13

File tree

8 files changed

+392
-70
lines changed

8 files changed

+392
-70
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ void populateTosaDecomposeDepthwise(MLIRContext *ctx,
3131
RewritePatternSet &patterns);
3232
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
3333
RewritePatternSet &patterns);
34+
void populateTosaFoldConstantRSQRTPatterns(MLIRContext *ctx,
35+
RewritePatternSet &patterns);
3436
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
3537
RewritePatternSet &patterns);
3638

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- TosaFoldCommon.h - Helper Functions for Folds ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Helper functions useful for various different TOSA constant folds.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H
13+
#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H
14+
15+
#include <llvm/ADT/APFloat.h>
16+
#include <functional>
17+
#include <mlir/Dialect/Tosa/IR/TosaOps.h>
18+
#include <mlir/IR/PatternMatch.h>
19+
20+
namespace mlir {
21+
namespace tosa {
22+
23+
/// Apply `toApply` to every element of the float tensor `toTransform` and
/// return a new tensor of the same type that holds the results. The callback
/// receives the element value and the element type of `toTransform`.
24+
DenseElementsAttr applyElementWise(
25+
const DenseElementsAttr &toTransform,
26+
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply);
27+
28+
/// Check that `toCheck` has a float element type and is a dense constant
/// produced by a TOSA constant op. The element type is checked first. On
/// failure, a match failure is reported for `location` through the rewriter
/// and a failed result is returned.
29+
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
30+
TosaOp location,
31+
PatternRewriter &);
32+
33+
/// Check that `toCheck` is a dense constant whose defining op is a TOSA
/// constant op. On failure, a match failure is reported for `location`
/// through the rewriter and a failed result is returned.
34+
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
35+
TosaOp location,
36+
PatternRewriter &);
37+
38+
/// Check that the element type of `toCheck` is a float type. On failure, a
/// match failure is reported for `location` through the rewriter and a failed
/// result is returned.
39+
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
40+
PatternRewriter &);
41+
42+
/// Compute 1.0 divided by the given value. The constant 1.0 is materialized
/// in the given float type and the division rounds to nearest, ties to even.
43+
APFloat computeReciprocal(const APFloat &, Type);
44+
45+
} // namespace tosa
46+
} // namespace mlir
47+
48+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ add_mlir_dialect_library(MLIRTosaTransforms
22
TosaDecomposeTransposeConv.cpp
33
TosaDecomposeConv2D.cpp
44
TosaDecomposeDepthwise.cpp
5+
TosaFoldCommon.cpp
56
TosaFoldConstantReciprocal.cpp
7+
TosaFoldConstantRSQRT.cpp
68
TosaFoldConstantTranspose.cpp
79
TosaInferShapes.cpp
810
TosaLayerwiseConstantFoldPass.cpp
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
//===- TosaFoldCommon.cpp -------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Helper functions useful for various different TOSA constant folds.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
14+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
15+
#include <llvm/ADT/APFloat.h>
16+
#include <mlir/IR/BuiltinAttributes.h>
17+
#include <mlir/IR/BuiltinTypes.h>
18+
#include <mlir/IR/Matchers.h>
19+
#include <mlir/Support/LogicalResult.h>
20+
21+
using namespace mlir;
22+
using namespace mlir::tosa;
23+
24+
namespace {
25+
static constexpr llvm::RoundingMode reciprocalRoundingMode =
26+
APFloat::rmNearestTiesToEven;
27+
} // namespace
28+
29+
/// Apply \p toApply to every element of the float tensor \p toTransform.
///
/// \param toTransform Dense tensor whose elements are read as APFloat values.
/// \param toApply     Callback that receives one element together with the
///                    element type of \p toTransform and returns the
///                    transformed element. It is invoked once per element, or
///                    only once in total when \p toTransform is a splat.
/// \returns A new dense tensor with the same type as \p toTransform that holds
///          the transformed values.
DenseElementsAttr mlir::tosa::applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply) {
  // The element type is identical for all elements, so query it only once.
  const Type elementType = toTransform.getElementType();

  // A splat is fully described by a single value: transform that value once
  // and build the result as a splat again instead of expanding the tensor.
  if (toTransform.isSplat()) {
    const llvm::APFloat transformedSplat =
        toApply(toTransform.getSplatValue<llvm::APFloat>(), elementType);
    return DenseElementsAttr::get(toTransform.getType(), transformedSplat);
  }

  // We already know the amount of values we will insert, reserve space for
  // all of them to avoid dynamic resizing
  llvm::SmallVector<llvm::APFloat> transformedValues;
  transformedValues.reserve(toTransform.getNumElements());
  for (const llvm::APFloat &val : toTransform.getValues<llvm::APFloat>()) {
    transformedValues.push_back(toApply(val, elementType));
  }

  // Build a tensor of the original type from the transformed values
  return DenseElementsAttr::get(toTransform.getType(), transformedValues);
}
46+
47+
/// Combined precondition check for float constant folds: \p toCheck must have
/// a float element type and must be a dense constant defined by a TOSA
/// constant op. Any failure is reported for \p location via \p rewriter.
LogicalResult
mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                               TosaOp location,
                                               PatternRewriter &rewriter) {
  // The element type is validated first; if it is rejected, the constant
  // check is not performed at all.
  if (failed(notifyIfNotFloat(toCheck, location, rewriter)))
    return failure();

  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}
57+
58+
/// Check that \p toCheck is a dense constant tensor defined by a TOSA
/// constant op.
///
/// \param toCheck  Value to inspect.
/// \param location Operation the match failure is reported for.
/// \param rewriter Rewriter used to report the match failure.
/// \returns success() if the value is a dense TOSA constant, otherwise the
///          failure produced by notifyMatchFailure.
LogicalResult
mlir::tosa::notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                              TosaOp location,
                                              PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense
  // TODO We currently ensure the tensor is dense by using the correct type for
  // the bind_value, however we do not actually need this value. It would be
  // nicer to only have a check here.
  DenseElementsAttr unusedValues;
  if (!matchPattern(toCheck, m_Constant(&unusedValues))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well). This helper is shared by several folds, so the
  // diagnostic must not name one specific operation.
  if (!toCheck.getDefiningOp<ConstOp>()) {
    return rewriter.notifyMatchFailure(location,
                                       "The operation can only be folded if "
                                       "it operates on a TOSA constant");
  }

  return success();
}
82+
83+
/// Check that the element type of \p toCheck is a float type. Otherwise a
/// match failure is reported for \p location via \p rewriter.
LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue<TensorType> toCheck,
                                           TosaOp location,
                                           PatternRewriter &rewriter) {
  const Type elementType = toCheck.getType().getElementType();
  if (!isa<FloatType>(elementType)) {
    return rewriter.notifyMatchFailure(location,
                                       "Unexpected input tensor type: the "
                                       "TOSA spec only allows floats");
  }
  return success();
}
93+
94+
/// Compute 1.0 / \p floatVal.
///
/// The dividend 1.0 is materialized as a FloatAttr of type \p floatTy, and the
/// division rounds according to reciprocalRoundingMode.
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
  // Start from the constant one and divide it in place by the input value
  APFloat quotient = FloatAttr::get(floatTy, 1.0).getValue();
  quotient.divide(floatVal, reciprocalRoundingMode);
  return quotient;
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- TosaFoldConstantRSQRT.cpp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA RSQRT (reciprocal square root) operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <llvm/ADT/APFloat.h>
19+
#include <llvm/ADT/FloatingPointMode.h>
20+
#include <llvm/IR/Constants.h>
21+
#include <cmath>
22+
#include <mlir/IR/BuiltinAttributes.h>
23+
#include <mlir/IR/BuiltinTypes.h>
24+
#include <mlir/Support/LogicalResult.h>
25+
26+
using namespace mlir;
27+
using namespace mlir::tosa;
28+
29+
namespace {
30+
31+
/// Rewrite pattern that folds an RsqrtOp whose input is a dense float TOSA
/// constant into a new constant holding the element-wise reciprocal square
/// root.
struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Compute 1 / sqrt(\p apFloatVal).
  ///
  /// \param apFloatVal Input value; the result uses the same float semantics.
  /// \param floatTy    Float type matching \p apFloatVal, forwarded to
  ///                   computeReciprocal.
  /// \returns NaN for negative inputs other than negative zero, otherwise the
  ///          reciprocal of the square root.
  static APFloat computeRSQRT(const APFloat &apFloatVal, Type floatTy) {
    const llvm::fltSemantics &semantics = apFloatVal.getSemantics();

    // The result for negative values (apart from zero) is always NaN
    if (apFloatVal.isNegative() && !apFloatVal.isNegZero()) {
      return APFloat::getNaN(semantics);
    }

    // APFloat does not provide a square root, so the value is unpacked into a
    // host double. Converting explicitly (instead of convertToFloat) keeps
    // this valid for every input semantics, not only for IEEE single.
    bool losesInfo = false;
    APFloat widened = apFloatVal;
    widened.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
                    &losesInfo);
    APFloat apSqrtVal(std::sqrt(widened.convertToDouble()));

    // Convert back to the input semantics: the division in computeReciprocal
    // requires both operands to share their semantics, and the folded tensor
    // must contain values of its element type.
    apSqrtVal.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);

    // Compute the reciprocal
    return computeReciprocal(apSqrtVal, floatTy);
  }

  /// Fold \p rsqrt if its input is a dense float TOSA constant that is either
  /// a splat or has a single use; otherwise report a match failure.
  LogicalResult matchAndRewrite(RsqrtOp rsqrt,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = rsqrt.getInput1();

    // Reject non-float or non-dense tensors
    if (failed(
            notifyIfNotConstantFloatTosaTensor(inputTensor, rsqrt, rewriter))) {
      return failure();
    }

    // Extract the tensor values. The precondition check above already matched
    // the constant, so this is only a safeguard against using a null
    // attribute below.
    DenseElementsAttr inputValues;
    if (!matchPattern(inputTensor, m_Constant(&inputValues))) {
      return rewriter.notifyMatchFailure(rsqrt,
                                         "Non-const or non-dense input tensor");
    }

    // Only fold splat tensors and those used only once to avoid duplicating
    // them.
    const bool isSplat = isa<SplatElementsAttr>(inputValues);
    if (!isSplat && !inputTensor.hasOneUse()) {
      return rewriter.notifyMatchFailure(
          rsqrt, "Currently, rsqrt ops will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values
    auto newTensor = applyElementWise(inputValues, &computeRSQRT);

    // Replace the use of the rsqrt with the transformed tensor
    rewriter.replaceOpWithNewOp<ConstOp>(rsqrt, newTensor.getType(), newTensor);

    return success();
  }
};
83+
84+
} // namespace
85+
86+
/// Add the RSQRT constant folding pattern to \p patterns.
void mlir::tosa::populateTosaFoldConstantRSQRTPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  // The pattern is constructed with the given context
  patterns.add<TosaFoldConstantRSQRT>(ctx);
}

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp

Lines changed: 12 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1414
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
1516
#include "mlir/IR/Matchers.h"
1617
#include "mlir/Pass/Pass.h"
1718
#include <llvm/ADT/APFloat.h>
@@ -28,93 +29,35 @@ namespace {
2829
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
2930

3031
using OpRewritePattern::OpRewritePattern;
31-
static constexpr llvm::RoundingMode reciprocalRoundingMode =
32-
APFloat::rmNearestTiesToEven;
33-
34-
APFloat computeReciprocal(const APFloat &floatVal, Type floatTy) const {
35-
auto recipAttr = FloatAttr::get(floatTy, 1.0);
36-
APFloat recip = recipAttr.getValue();
37-
recip.divide(floatVal, reciprocalRoundingMode);
38-
39-
return recip;
40-
}
41-
42-
DenseElementsAttr
43-
replaceTensorWithReciprocal(ConstOp tensorToReplace,
44-
const DenseElementsAttr &inputValues) const {
45-
// TODO it would be nicer to do this in-place
46-
47-
// Compute the reciprocal for each tensor element
48-
llvm::SmallVector<APFloat, 1> transformedValues;
49-
// We already know the amount of values we will insert, reserve space for
50-
// all of them to avoid dynamic resizing
51-
transformedValues.reserve(inputValues.getNumElements());
52-
for (auto val : inputValues.getValues<APFloat>()) {
53-
auto recipVal = computeReciprocal(val, inputValues.getElementType());
54-
transformedValues.push_back(recipVal);
55-
}
56-
57-
// Replace the current tensor with one containing the computed reciprocals
58-
auto newTensor =
59-
DenseElementsAttr::get(inputValues.getType(), transformedValues);
60-
return newTensor;
61-
}
6232

6333
LogicalResult matchAndRewrite(ReciprocalOp recip,
6434
PatternRewriter &rewriter) const override {
6535
auto inputTensor = recip.getInput1();
66-
auto elemType = inputTensor.getType().getElementType();
67-
// TOSA only allows for floats as inputs to the reciprocal operation, so
68-
// bail if anything else is contained
69-
if (!isa<FloatType>(elemType)) {
70-
return rewriter.notifyMatchFailure(recip,
71-
"Unexpected input tensor type: the "
72-
"TOSA spec only allows floats");
73-
}
7436

75-
// Check whether the tensor is constant and dense
76-
DenseElementsAttr inputValues;
77-
if (!matchPattern(inputTensor, m_Constant(&inputValues))) {
78-
return rewriter.notifyMatchFailure(
79-
recip, "Non-const or non-dense input to reciprocal");
80-
}
81-
82-
// In case we have a splat, we only need to calculate the reciprocal once
83-
// and update the tensor to the transformed splat value.
84-
if (auto splatAttrs = dyn_cast<SplatElementsAttr>(inputValues)) {
85-
// Transform the splat value
86-
auto splatVal = splatAttrs.getSplatValue<APFloat>();
87-
auto newSplatRecipAttr = computeReciprocal(splatVal, elemType);
88-
89-
// Create a tensor with the transformed splat value
90-
auto newSplatTensor =
91-
DenseElementsAttr::get(splatAttrs.getType(), newSplatRecipAttr);
92-
93-
// Replace the reciprocal op with the newly constructed tensor
94-
rewriter.replaceOpWithNewOp<ConstOp>(recip, newSplatTensor.getType(),
95-
newSplatTensor);
96-
return success();
37+
// Check that we can apply folding
38+
auto preCondCheck =
39+
notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
40+
if (failed(preCondCheck)) {
41+
return preCondCheck;
9742
}
9843

99-
if (!isa<ConstOp>(inputTensor.getDefiningOp())) {
100-
return rewriter.notifyMatchFailure(recip,
101-
"The reciprocal can only be folded if "
102-
"it operates on a TOSA constant");
103-
}
104-
auto definingConstOp = cast<ConstOp>(inputTensor.getDefiningOp());
44+
// Extract the tensor values
45+
DenseElementsAttr inputValues;
46+
matchPattern(inputTensor, m_Constant(&inputValues));
10547

10648
// Our transformation replaces the input tensor with the transformed tensor.
10749
// If the input has several users we need to keep the input. This can
10850
// result in a significantly increased memory usage, such that we currently
10951
// refrain from applying the transformation in that case.
110-
if (!definingConstOp->hasOneUse()) {
52+
// Allow this only for splat values, because the amount of data is small.
53+
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(inputValues)) {
11154
return rewriter.notifyMatchFailure(
11255
recip, "Currently, reciprocals will only be folded if the input "
11356
"tensor has a single user");
11457
}
11558

11659
// Create a new tensor with the updated values
117-
auto newTensor = replaceTensorWithReciprocal(definingConstOp, inputValues);
60+
auto newTensor = applyElementWise(inputValues, &computeReciprocal);
11861

11962
// Replace the use of the reciprocal with the transformed tensor
12063
rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ struct TosaLayerwiseConstantFoldPass
5151
auto func = getOperation();
5252

5353
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
54+
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
5455
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
5556
populateTosaOpsCanonicalizationPatterns(ctx, patterns);
5657

57-
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
58+
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
5859
signalPassFailure();
5960
}
6061
};

0 commit comments

Comments
 (0)