Skip to content

Commit 6b18842

Browse files
authored
Merge pull request #15 from Xilinx/tina.tosafoldpow
Implement folding for the TOSA power operation
2 parents 0125996 + 2f88299 commit 6b18842

File tree

7 files changed

+335
-4
lines changed

7 files changed

+335
-4
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
2929
RewritePatternSet &patterns);
3030
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
3131
RewritePatternSet &patterns);
32+
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
33+
RewritePatternSet &patterns);
3234
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
3335
RewritePatternSet &patterns);
3436
void populateTosaFoldConstantRSQRTPatterns(MLIRContext *ctx,

mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,68 @@
1313
#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H
1414

1515
#include <llvm/ADT/APFloat.h>
16+
#include <llvm/ADT/ArrayRef.h>
1617
#include <functional>
1718
#include <mlir/Dialect/Tosa/IR/TosaOps.h>
1819
#include <mlir/IR/PatternMatch.h>
1920

2021
namespace mlir {
2122
namespace tosa {
2223

24+
/// Type that represents tensor dimensions.
25+
using DimensionType = ArrayRef<int64_t>;
26+
27+
/// Type for tensor offsets.
28+
using OffsetType = size_t;
29+
2330
/// Transform a tensor with the given transformation function.
2431
DenseElementsAttr applyElementWise(
2532
const DenseElementsAttr &toTransform,
2633
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply);
2734

35+
/// Apply the binary function \p toApply to each pair of elements taken from
36+
/// \p first and \p second and collect the results in a new tensor of type
37+
/// \p targetType. An input whose shape differs from the shape of \p targetType
/// is broadcast: every target offset is mapped to the matching offset in that
/// input before the element is read.
/// \p toApply receives the element of \p first as its first argument and the
/// element of \p second as its second argument.
/// \returns a tensor of type \p targetType holding one result per element.
38+
DenseElementsAttr applyElementWise(
39+
const DenseElementsAttr &first, const DenseElementsAttr &second,
40+
TensorType targetType,
41+
const std::function<APFloat(const APFloat &, const APFloat &)> &toApply);
42+
2843
/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
2944
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
3045
TosaOp location,
31-
PatternRewriter &);
46+
PatternRewriter &rewriter);
3247

3348
/// Function that checks if \p toCheck is a dense TOSA constant tensor.
3449
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
3550
TosaOp location,
36-
PatternRewriter &);
51+
PatternRewriter &rewriter);
3752

3853
/// Function that checks if the type contained in \p toCheck is float.
3954
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
40-
PatternRewriter &);
55+
PatternRewriter &rewriter);
56+
57+
/// Compute the linear offset in \p shape which corresponds to the given
/// multi-dimensional \p index. The offset is row-major, i.e. the last dimension
/// varies fastest. \p index must provide one entry per dimension of \p shape.
58+
OffsetType indexToOffset(DimensionType shape, DimensionType index);
59+
60+
/// Compute the multi-dimensional index into \p shape which corresponds to the
/// given linear (row-major) \p offset. The result has one entry per dimension
/// of \p shape. Every dimension size is used as a divisor, so \p shape must not
/// contain a dimension of size zero.
61+
SmallVector<int64_t> offsetToIndex(DimensionType shape, OffsetType offset);
62+
63+
/// Given an \p index into \p desiredShape, compute the corresponding index into
64+
/// \p toBeBroadcastedShape.
/// For every dimension in which both shapes agree the entry of \p index is
/// kept; in every other dimension the entry becomes 0. Both shapes and
/// \p index are read at the same positions and must therefore have the same
/// rank.
65+
/// \returns broadcasted index into \p toBeBroadcastedShape.
66+
SmallVector<int64_t> getBroadcastedIndex(DimensionType desiredShape,
67+
DimensionType toBeBroadcastedShape,
68+
DimensionType index);
69+
/// Given an \p offset into \p desiredShape, compute the corresponding offset
70+
/// into \p toBeBroadcastedShape.
/// If both shapes are equal, \p offset is returned unchanged. Otherwise the
/// offset is converted to an index, that index is broadcast, and the result is
/// converted back to an offset.
71+
/// \returns broadcasted offset into \p toBeBroadcastedShape.
72+
OffsetType getBroadcastedOffset(DimensionType desiredShape,
73+
DimensionType toBeBroadcastedShape,
74+
OffsetType offset);
4175

4276
/// Function to compute the reciprocal.
43-
APFloat computeReciprocal(const APFloat &, Type);
77+
APFloat computeReciprocal(const APFloat &floatVal, Type floatTy);
4478

4579
} // namespace tosa
4680
} // namespace mlir

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
33
TosaDecomposeConv2D.cpp
44
TosaDecomposeDepthwise.cpp
55
TosaFoldCommon.cpp
6+
TosaFoldConstantPow.cpp
67
TosaFoldConstantReciprocal.cpp
78
TosaFoldConstantRSQRT.cpp
89
TosaFoldConstantTranspose.cpp

mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
1414
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1515
#include <llvm/ADT/APFloat.h>
16+
#include <llvm/ADT/SmallVector.h>
17+
#include <algorithm>
1618
#include <mlir/IR/BuiltinAttributes.h>
1719
#include <mlir/IR/BuiltinTypes.h>
1820
#include <mlir/IR/Matchers.h>
@@ -44,6 +46,42 @@ DenseElementsAttr mlir::tosa::applyElementWise(
4446
return newTensor;
4547
}
4648

49+
/// Apply \p toApply to every pair of elements of \p first and \p second and
/// return the results as a new tensor of type \p targetType.
///
/// The result is produced offset by offset over the shape of \p targetType.
/// For each target offset the matching offset in each input is looked up with
/// getBroadcastedOffset, so an input whose shape differs from the target shape
/// is broadcast. The element of \p first is passed as the first argument of
/// \p toApply, the element of \p second as the second one.
///
/// The number of elements is computed as the product of the dimensions of
/// \p targetType, so \p targetType is expected to be statically shaped.
DenseElementsAttr mlir::tosa::applyElementWise(
    const DenseElementsAttr &first, const DenseElementsAttr &second,
    TensorType targetType,
    const std::function<APFloat(const APFloat &, const APFloat &)> &toApply) {
  // Count the elements of the result with 64 bits. Dimensions are int64_t and
  // their product must not be narrowed to a 32 bit int.
  const auto targetShape = targetType.getShape();
  int64_t targetSize = 1;
  for (const int64_t dimSize : targetShape) {
    targetSize *= dimSize;
  }

  // The number of results is known up front, reserve space for all of them to
  // avoid dynamic resizing.
  SmallVector<APFloat> transformedValues;
  transformedValues.reserve(targetSize);

  // Apply the given function to each pair of values from the input tensors.
  // Make sure to broadcast the offsets properly.
  const auto firstValues = first.getValues<APFloat>();
  const auto firstShape = first.getType().getShape();
  const auto secondValues = second.getValues<APFloat>();
  const auto secondShape = second.getType().getShape();
  for (int64_t offset = 0; offset < targetSize; ++offset) {
    const auto targetOffset = static_cast<OffsetType>(offset);
    const OffsetType offsetInFirst =
        getBroadcastedOffset(targetShape, firstShape, targetOffset);
    const OffsetType offsetInSecond =
        getBroadcastedOffset(targetShape, secondShape, targetOffset);
    transformedValues.push_back(
        toApply(firstValues[offsetInFirst], secondValues[offsetInSecond]));
  }

  // Generate a tensor with the computed values.
  return DenseElementsAttr::get(targetType, transformedValues);
}
84+
4785
LogicalResult
4886
mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
4987
TosaOp location,
@@ -91,6 +129,59 @@ LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue<TensorType> toCheck,
91129
"TOSA spec only allows floats");
92130
}
93131

132+
/// Linearize \p index into a row-major offset for a tensor of the given
/// \p shape. Dimensions are folded in from the first to the last, so the last
/// dimension varies fastest. \p index is read once per dimension of \p shape.
OffsetType mlir::tosa::indexToOffset(DimensionType shape, DimensionType index) {
  OffsetType linear = 0;
  const size_t rank = shape.size();
  for (size_t dim = 0; dim != rank; ++dim) {
    // Make room for the current dimension, then add the position within it.
    linear *= shape[dim];
    linear += index[dim];
  }
  return linear;
}
139+
140+
/// Convert the row-major \p offset into a multi-dimensional index for a tensor
/// of the given \p shape. The returned index has one entry per dimension of
/// \p shape.
SmallVector<int64_t> mlir::tosa::offsetToIndex(DimensionType shape,
                                               OffsetType offset) {
  const size_t rank = shape.size();
  // The index has the same rank as the shape; allocate it up front so it can
  // be filled from the innermost dimension outwards.
  SmallVector<int64_t> resultIndex(rank, 0);
  for (size_t dim = rank; dim-- > 0;) {
    // The remainder is the position inside this dimension, the quotient is
    // what is left for the outer dimensions.
    resultIndex[dim] = offset % shape[dim];
    offset /= shape[dim];
  }
  return resultIndex;
}
155+
156+
/// Map \p index, which addresses an element of \p desiredShape, to the index
/// of the element of \p toBeBroadcastedShape that is broadcast to it.
///
/// In every dimension where both shapes agree the entry of \p index is kept.
/// In every other dimension the entry becomes 0, i.e. the single element of
/// that dimension is reused. Both shapes and \p index are read at the same
/// positions and must therefore have the same rank.
///
/// \returns the broadcasted index into \p toBeBroadcastedShape.
SmallVector<int64_t>
mlir::tosa::getBroadcastedIndex(DimensionType desiredShape,
                                DimensionType toBeBroadcastedShape,
                                DimensionType index) {
  SmallVector<int64_t> broadCasted;
  broadCasted.reserve(desiredShape.size());
  for (size_t i = 0; i < desiredShape.size(); i++) {
    // Keep the full 64 bit width of the index entry; an `int` temporary would
    // truncate indices that do not fit into 32 bits.
    const int64_t toInsert =
        toBeBroadcastedShape[i] == desiredShape[i] ? index[i] : 0;
    broadCasted.push_back(toInsert);
  }
  return broadCasted;
}
171+
172+
/// Map \p offset, which addresses an element of \p desiredShape, to the offset
/// of the element of \p toBeBroadcastedShape that is broadcast to it.
/// \returns \p offset itself if both shapes are equal, the broadcasted offset
/// otherwise.
OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
                                            DimensionType toBeBroadcastedShape,
                                            OffsetType offset) {
  const bool needsBroadcast = !desiredShape.equals(toBeBroadcastedShape);
  if (needsBroadcast) {
    // Go through index space: offset -> index in the target, index in the
    // target -> index in the source, index in the source -> offset.
    const SmallVector<int64_t> targetIndex =
        offsetToIndex(desiredShape, offset);
    const SmallVector<int64_t> sourceIndex =
        getBroadcastedIndex(desiredShape, toBeBroadcastedShape, targetIndex);
    return indexToOffset(toBeBroadcastedShape, sourceIndex);
  }
  // Equal shapes share the same layout, nothing to translate.
  return offset;
}
184+
94185
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
95186
auto recipAttr = FloatAttr::get(floatTy, 1.0);
96187
APFloat recip = recipAttr.getValue();
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- TosaFoldConstantPow.cpp --------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA Pow operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <cmath>
19+
#include <llvm/ADT/APFloat.h>
20+
#include <llvm/ADT/FloatingPointMode.h>
21+
#include <llvm/ADT/SmallVector.h>
22+
#include <mlir/IR/BuiltinAttributes.h>
23+
#include <mlir/Support/LogicalResult.h>
24+
25+
using namespace mlir;
26+
using namespace mlir::tosa;
27+
28+
namespace {
29+
30+
/// Rewrite pattern that replaces a tosa.pow whose operands are both constant
/// float tensors by a tosa.const holding the element-wise result.
struct TosaFoldConstantPow : public OpRewritePattern<PowOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Compute \p base raised to the power of \p exp.
  ///
  /// NaN (in the semantics of \p base) is returned if either operand is NaN,
  /// if both operands are zero, or if \p base is negative (and not zero) while
  /// \p exp is not an integer. Otherwise the power is computed in single
  /// precision and converted to the semantics of \p base.
  static APFloat computePower(const APFloat &base, const APFloat &exp) {
    const auto &semantics = base.getSemantics();

    // The checks are evaluated in this order:
    //  - NaN operands propagate,
    //  - TOSA defines 0.0**0.0 as NaN,
    //  - a negative base requires an integer exponent.
    if (base.isNaN() || exp.isNaN() || (base.isZero() && exp.isZero()) ||
        (base.isNegative() && !base.isZero() && !exp.isInteger())) {
      return APFloat::getNaN(semantics);
    }

    // Actually compute base**exp. Special cases for [-]infinity and [-]0 are
    // already handled in accordance with the TOSA spec.
    const float power = std::pow(base.convertToFloat(), exp.convertToFloat());
    APFloat converted(power);

    // Bring the result back to the float semantics of the inputs. Whether
    // precision was lost on the way is not inspected.
    bool lostPrecision;
    converted.convert(semantics, APFloat::rmNearestTiesToEven, &lostPrecision);
    return converted;
  }

  /// Fold \p powOp if both of its operands are constant float tensors.
  ///
  /// Unless both constants are splats, the fold is only performed if at least
  /// one of the operands has a single use. On success \p powOp is replaced by
  /// a ConstOp holding the element-wise power.
  LogicalResult matchAndRewrite(PowOp powOp,
                                PatternRewriter &rewriter) const override {
    auto baseOp = powOp.getInput1();
    auto expOp = powOp.getInput2();

    // Both operands have to be constant float tensors; bail out with the
    // result of the first failing check otherwise.
    const LogicalResult baseCheck =
        notifyIfNotConstantFloatTosaTensor(baseOp, powOp, rewriter);
    if (failed(baseCheck)) {
      return baseCheck;
    }
    const LogicalResult expCheck =
        notifyIfNotConstantFloatTosaTensor(expOp, powOp, rewriter);
    if (failed(expCheck)) {
      return expCheck;
    }

    // Extract the tensor values
    DenseElementsAttr baseValues;
    DenseElementsAttr expValues;
    matchPattern(baseOp, m_Constant(&baseValues));
    matchPattern(expOp, m_Constant(&expValues));

    // If both tensors are splat, we don't care for the number of users.
    // Otherwise make sure that at least one of the constant input tensors can
    // be replaced (i.e. only has a single user).
    const bool bothSplat = isa<SplatElementsAttr>(baseValues) &&
                           isa<SplatElementsAttr>(expValues);
    if (!bothSplat && !baseOp.hasOneUse() && !expOp.hasOneUse()) {
      return rewriter.notifyMatchFailure(
          powOp, "Currently, pows will only be folded if at least one input "
                 "tensor only has a single user");
    }

    auto folded =
        applyElementWise(baseValues, expValues, powOp.getType(), &computePower);
    rewriter.replaceOpWithNewOp<ConstOp>(powOp, folded.getType(), folded);

    return success();
  }
};
104+
105+
} // namespace
106+
107+
/// Add the TosaFoldConstantPow rewrite pattern, constructed with \p ctx, to
/// \p patterns.
void mlir::tosa::populateTosaFoldConstantPowPatterns(
108+
MLIRContext *ctx, RewritePatternSet &patterns) {
109+
patterns.add<TosaFoldConstantPow>(ctx);
110+
}

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
5050
RewritePatternSet patterns(ctx);
5151
auto func = getOperation();
5252

53+
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
5354
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
5455
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
5556
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);

0 commit comments

Comments
 (0)