Skip to content

Commit a8712f6

Browse files
committed
Implement folding for the power operation.
Use the TOSA spec for POW ([0], [1]).
[0] https://www.mlplatform.org/tosa/tosa_spec.html#_pow
[1] https://www.mlplatform.org/tosa/tosa_spec.html#_main_inference_profile
tosa.pow can be applied to tensors of different shapes, in which case broadcasting is applied. The TOSA broadcasting helpers are implemented as well, as specified in [2].
[2] https://www.mlplatform.org/tosa/tosa_spec.html#_tensor_access_helpers
1 parent 351076f commit a8712f6

File tree

7 files changed

+324
-0
lines changed

7 files changed

+324
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
2929
RewritePatternSet &patterns);
3030
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
3131
RewritePatternSet &patterns);
32+
void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
33+
RewritePatternSet &patterns);
3234
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
3335
RewritePatternSet &patterns);
3436
void populateTosaFoldConstantRSQRTPatterns(MLIRContext *ctx,

mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,32 @@
1313
#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H
1414

1515
#include <llvm/ADT/APFloat.h>
16+
#include <llvm/ADT/ArrayRef.h>
1617
#include <functional>
1718
#include <mlir/Dialect/Tosa/IR/TosaOps.h>
1819
#include <mlir/IR/PatternMatch.h>
1920

2021
namespace mlir {
2122
namespace tosa {
2223

24+
/// Type that represents tensor dimensions.
25+
using DimensionType = ArrayRef<int64_t>;
26+
27+
/// Type for tensor offsets.
28+
using OffsetType = size_t;
29+
2330
/// Transform a tensor with the given transformation function.
2431
DenseElementsAttr applyElementWise(
2532
const DenseElementsAttr &toTransform,
2633
const std::function<llvm::APFloat(const llvm::APFloat &, Type)> &toApply);
2734

35+
/// Apply the given transformation function on the elements of the given
36+
/// tensors. If the input tensors do not match \p targetType, broadcasting is
37+
/// applied.
38+
DenseElementsAttr applyElementWise(
39+
const DenseElementsAttr &, const DenseElementsAttr &, TensorType targetType,
40+
const std::function<APFloat(const APFloat &, const APFloat &)> &toApply);
41+
2842
/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
2943
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
3044
TosaOp location,
@@ -39,6 +53,23 @@ LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
3953
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
4054
PatternRewriter &);
4155

56+
/// Compute the offset in \p shape which corresponds to the given \p index.
57+
OffsetType indexToOffset(DimensionType shape, DimensionType index);
58+
59+
/// Compute the index into \p shape which corresponds to the given \p offset.
60+
SmallVector<int64_t> offsetToIndex(DimensionType shape, OffsetType offset);
61+
62+
/// Given an \p index into \p desiredShape, compute the corresponding index into
63+
/// \p toBeBroadcasted.
64+
SmallVector<int64_t> getBroadcastedIndex(DimensionType desiredShape,
65+
DimensionType toBeBroadcasted,
66+
DimensionType index);
67+
/// Given an \p offset into \p desiredShape, compute the corresponding offset
68+
/// into \p toBeBroadcasted.
69+
OffsetType getBroadcastedOffset(DimensionType desiredShape,
70+
DimensionType toBeBroadcasted,
71+
OffsetType offset);
72+
4273
/// Function to compute the reciprocal.
4374
APFloat computeReciprocal(const APFloat &, Type);
4475

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
33
TosaDecomposeConv2D.cpp
44
TosaDecomposeDepthwise.cpp
55
TosaFoldCommon.cpp
6+
TosaFoldConstantPow.cpp
67
TosaFoldConstantReciprocal.cpp
78
TosaFoldConstantRSQRT.cpp
89
TosaFoldConstantTranspose.cpp

mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
1414
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
15+
#include <algorithm>
1516
#include <llvm/ADT/APFloat.h>
17+
#include <llvm/ADT/SmallVector.h>
1618
#include <mlir/IR/BuiltinAttributes.h>
1719
#include <mlir/IR/BuiltinTypes.h>
1820
#include <mlir/IR/Matchers.h>
@@ -44,6 +46,42 @@ DenseElementsAttr mlir::tosa::applyElementWise(
4446
return newTensor;
4547
}
4648

49+
/// Apply \p toApply to every pair of elements of \p first and \p second and
/// return the results as a dense constant of type \p targetType.
///
/// For each linear offset into \p targetType the matching offsets into the two
/// inputs are obtained with getBroadcastedOffset, so an input whose shape
/// differs from the target shape in some dimension always contributes its
/// element at coordinate 0 of that dimension.
///
/// NOTE(review): assumes \p targetType has a static shape and that both inputs
/// have the same rank as \p targetType -- callers should confirm this before
/// calling.
DenseElementsAttr mlir::tosa::applyElementWise(
    const DenseElementsAttr &first, const DenseElementsAttr &second,
    TensorType targetType,
    const std::function<APFloat(const APFloat &, const APFloat &)> &toApply) {
  // The number of result elements is the product of the target extents. Shape
  // extents are int64_t, so the product (and the loop counter below) must be
  // int64_t as well; a plain `auto x = 1` would deduce int and narrow.
  auto targetShape = targetType.getShape();
  int64_t targetSize = 1;
  for (const int64_t dimSize : targetShape) {
    targetSize *= dimSize;
  }

  // The amount of values is known up front, reserve space for all of them to
  // avoid dynamic resizing.
  SmallVector<APFloat> transformedValues;
  transformedValues.reserve(targetSize);

  // Apply the given function to each pair of values from the input tensors.
  // The offsets are mapped through the broadcasting helpers so that inputs
  // smaller than the target are read at the right position.
  auto firstValues = first.getValues<APFloat>();
  auto firstShape = first.getType().getShape();
  auto secondValues = second.getValues<APFloat>();
  auto secondShape = second.getType().getShape();
  for (int64_t offset = 0; offset < targetSize; ++offset) {
    OffsetType offsetInFirst =
        getBroadcastedOffset(targetShape, firstShape, offset);
    OffsetType offsetInSecond =
        getBroadcastedOffset(targetShape, secondShape, offset);
    transformedValues.push_back(
        toApply(firstValues[offsetInFirst], secondValues[offsetInSecond]));
  }

  // Generate a tensor with the computed values.
  return DenseElementsAttr::get(targetType, transformedValues);
}
84+
4785
LogicalResult
4886
mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
4987
TosaOp location,
@@ -91,6 +129,55 @@ LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue<TensorType> toCheck,
91129
"TOSA spec only allows floats");
92130
}
93131

132+
/// Linearize the multi-dimensional \p index into an offset for a tensor of
/// the given \p shape, with the last dimension varying fastest.
///
/// One entry of \p index is read per entry of \p shape.
OffsetType mlir::tosa::indexToOffset(DimensionType shape, DimensionType index) {
  // Horner-style accumulation: scale what has been accumulated so far by the
  // extent of the current dimension, then add the coordinate in it.
  OffsetType linearized = 0;
  size_t dim = 0;
  for (const auto extent : shape) {
    linearized = linearized * extent + index[dim];
    ++dim;
  }
  return linearized;
}
139+
140+
/// Convert the linear \p offset into a multi-dimensional index for a tensor of
/// the given \p shape (inverse of indexToOffset). The result has one entry per
/// dimension of \p shape.
SmallVector<int64_t> mlir::tosa::offsetToIndex(DimensionType shape,
                                               OffsetType offset) {
  const size_t rank = shape.size();
  // Allocate all coordinates up front and fill them from the innermost
  // dimension outwards, peeling one coordinate off the offset per step.
  SmallVector<int64_t> coords(rank);
  for (size_t dim = rank; dim-- > 0;) {
    coords[dim] = offset % shape[dim];
    offset /= shape[dim];
  }
  return coords;
}
155+
156+
/// Map \p index, an index into a tensor of shape \p desiredShape, to the
/// corresponding index into a tensor of shape \p toBeBroadcasted.
///
/// In every dimension where the two shapes agree the coordinate is kept; in
/// every other dimension the coordinate becomes 0.
///
/// NOTE(review): all three arguments are read at the same positions, so they
/// are assumed to have the same rank -- confirm at the call sites.
SmallVector<int64_t>
mlir::tosa::getBroadcastedIndex(DimensionType desiredShape,
                                DimensionType toBeBroadcasted,
                                DimensionType index) {
  SmallVector<int64_t> broadCasted;
  broadCasted.reserve(desiredShape.size());
  for (size_t i = 0; i < desiredShape.size(); i++) {
    // Keep the coordinate as int64_t: deducing the type from the literal 0
    // would yield int and silently truncate large coordinates.
    const int64_t toInsert =
        toBeBroadcasted[i] == desiredShape[i] ? index[i] : 0;
    broadCasted.push_back(toInsert);
  }
  return broadCasted;
}
171+
172+
/// Translate \p offset, a linear offset into a tensor of shape
/// \p desiredShape, into the linear offset of the corresponding element of a
/// tensor of shape \p toBeBroadcasted.
OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
                                            DimensionType toBeBroadcasted,
                                            OffsetType offset) {
  // offset -> index in the target shape -> index in the source shape ->
  // offset in the source shape.
  const SmallVector<int64_t> targetCoords =
      offsetToIndex(desiredShape, offset);
  return indexToOffset(
      toBeBroadcasted,
      getBroadcastedIndex(desiredShape, toBeBroadcasted, targetCoords));
}
180+
94181
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
95182
auto recipAttr = FloatAttr::get(floatTy, 1.0);
96183
APFloat recip = recipAttr.getValue();
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- TosaFoldConstantPow.cpp --------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Fold TOSA Pow operation on constant data
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
14+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
15+
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
16+
#include "mlir/IR/Matchers.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include <cmath>
19+
#include <llvm/ADT/APFloat.h>
20+
#include <llvm/ADT/FloatingPointMode.h>
21+
#include <llvm/ADT/SmallVector.h>
22+
#include <mlir/IR/BuiltinAttributes.h>
23+
#include <mlir/Support/LogicalResult.h>
24+
25+
using namespace mlir;
26+
using namespace mlir::tosa;
27+
28+
namespace {
29+
30+
/// Rewrite pattern that folds a tosa.pow whose two operands are constant
/// float tensors into a single tosa.const holding the element-wise result.
struct TosaFoldConstantPow : public OpRewritePattern<PowOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Compute \p base raised to the power of \p exp. The result uses the
  /// floating point semantics of \p base.
  ///
  /// NaN is returned if either operand is NaN, if both operands are zero, or
  /// if \p base is negative (and non-zero) while \p exp is not an integer.
  ///
  /// NOTE(review): the computation goes through the host `float` type, which
  /// presumably relies on the operands being f32 or narrower -- confirm.
  static APFloat computePower(const APFloat &base, const APFloat &exp) {
    // Propagate NaN
    if (base.isNaN() || exp.isNaN()) {
      return APFloat::getNaN(base.getSemantics());
    }
    // TOSA defines 0.0**0.0 as NaN
    if (base.isZero() && exp.isZero()) {
      return APFloat::getNaN(base.getSemantics());
    }
    // In case the value is negative, the exponent needs to be an integer
    if (base.isNegative() && !base.isZero() && !exp.isInteger()) {
      return APFloat::getNaN(base.getSemantics());
    }

    // Actually compute base**exp. Special cases for [-]infinity and [-]0 are
    // already handled in accordance with the TOSA spec.
    auto powFloat = std::pow(base.convertToFloat(), exp.convertToFloat());
    auto res = APFloat(powFloat);

    // Convert the single precision result back to the semantics of the
    // inputs; a loss of precision is expected and deliberately ignored.
    bool lostPrecision = false;
    res.convert(base.getSemantics(), APFloat::rmNearestTiesToEven,
                &lostPrecision);
    return res;
  }

  /// Replace \p powOp by a constant if both operands are constant float
  /// tensors, the result type is statically shaped and, unless both operands
  /// are splats, at least one operand has a single use. Reports a match
  /// failure through \p rewriter otherwise.
  LogicalResult matchAndRewrite(PowOp powOp,
                                PatternRewriter &rewriter) const override {
    auto baseOp = powOp.getInput1();
    auto expOp = powOp.getInput2();

    // Check if both tensors are constant
    auto baseIsConstCheck =
        notifyIfNotConstantFloatTosaTensor(baseOp, powOp, rewriter);
    if (failed(baseIsConstCheck)) {
      return baseIsConstCheck;
    }
    auto expIsConstCheck =
        notifyIfNotConstantFloatTosaTensor(expOp, powOp, rewriter);
    if (failed(expIsConstCheck)) {
      return expIsConstCheck;
    }

    // The folded values are laid out according to the result shape and a
    // dense constant can only be built for a statically shaped type, so bail
    // out on dynamic or unranked results instead of folding them.
    auto resultType = powOp.getType();
    if (!resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          powOp, "Pows are only folded if the result type has a static shape");
    }

    // Extract the tensor values
    DenseElementsAttr baseValues;
    matchPattern(baseOp, m_Constant(&baseValues));

    DenseElementsAttr expValues;
    matchPattern(expOp, m_Constant(&expValues));

    // If both tensors are splat, we don't care for the number of users
    if (!isa<SplatElementsAttr>(baseValues) ||
        !isa<SplatElementsAttr>(expValues)) {
      // Make sure that at least one of the constant input tensors can be
      // replaced (i.e. only has a single user)
      if (!baseOp.hasOneUse() && !expOp.hasOneUse()) {
        return rewriter.notifyMatchFailure(
            powOp, "Currently, pows will only be folded if at least one input "
                   "tensor only has a single user");
      }
    }

    auto newTensor =
        applyElementWise(baseValues, expValues, resultType, &computePower);
    rewriter.replaceOpWithNewOp<ConstOp>(powOp, newTensor.getType(), newTensor);

    return success();
  }
};
104+
105+
} // namespace
106+
107+
/// Add the pattern that folds tosa.pow operations on constant operands to
/// \p patterns, constructing it with \p ctx.
void mlir::tosa::populateTosaFoldConstantPowPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantPow>(ctx);
}

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
5050
RewritePatternSet patterns(ctx);
5151
auto func = getOperation();
5252

53+
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
5354
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
5455
mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
5556
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
2+
3+
// CHECK-LABEL: @pow_fold_tiny
4+
func.func @pow_fold_tiny() -> tensor<f32> {
5+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.6{{0*}}e+01{{.*}}tensor<f32>
6+
// CHECK-NOT: tosa.pow
7+
// CHECK: return [[RES]]
8+
%0 = "tosa.const"() {value = dense<4.0> : tensor<f32>} : () -> tensor<f32>
9+
%1 = "tosa.const"() {value = dense<2.0> : tensor<f32>} : () -> tensor<f32>
10+
%2 = "tosa.pow"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
11+
return %2 : tensor<f32>
12+
}
13+
14+
// CHECK-LABEL: @pow_fold_tensor
15+
func.func @pow_fold_tensor() -> tensor<3xf16> {
16+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.56{{0*}}e+02, 1.191410e+00, -3.099610e+00{{.*}}tensor<3xf16>
17+
// CHECK-NOT: tosa.pow
18+
// CHECK: return [[RES]]
19+
%0 = "tosa.const"() {value = dense<[4.0, 2.22, -3.1]> : tensor<3xf16>} : () -> tensor<3xf16>
20+
%1 = "tosa.const"() {value = dense<[4.0, 0.22, 1.0]> : tensor<3xf16>} : () -> tensor<3xf16>
21+
%2 = "tosa.pow"(%0, %1) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
22+
return %2 : tensor<3xf16>
23+
}
24+
25+
// CHECK-LABEL: @pow_fold_overflow
26+
func.func @pow_fold_overflow() -> tensor<2xf16> {
27+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00{{.*}}tensor<2xf16>
28+
// CHECK-NOT: tosa.pow
29+
// CHECK: return [[RES]]
30+
%0 = "tosa.const"() {value = dense<[65500.0, -65500.0]> : tensor<2xf16>} : () -> tensor<2xf16>
31+
%1 = "tosa.const"() {value = dense<[2.0, 3.0]> : tensor<2xf16>} : () -> tensor<2xf16>
32+
%2 = "tosa.pow"(%0, %1) : (tensor<2xf16>, tensor<2xf16>) -> tensor<2xf16>
33+
return %2 : tensor<2xf16>
34+
}
35+
36+
// CHECK-LABEL: @pow_fold_underflow
37+
func.func @pow_fold_underflow() -> tensor<2xf16> {
38+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}[0.0{{0*}}e+00, -0.0{{0*}}e+00{{.*}}tensor<2xf16>
39+
// CHECK-NOT: tosa.pow
40+
// CHECK: return [[RES]]
41+
%0 = "tosa.const"() {value = dense<[0.000001, -0.000001]> : tensor<2xf16>} : () -> tensor<2xf16>
42+
%1 = "tosa.const"() {value = dense<[10.0, 9.0]> : tensor<2xf16>} : () -> tensor<2xf16>
43+
%2 = "tosa.pow"(%0, %1) : (tensor<2xf16>, tensor<2xf16>) -> tensor<2xf16>
44+
return %2 : tensor<2xf16>
45+
}
46+
47+
// CHECK-LABEL: @pow_fold_nan_cases
48+
func.func @pow_fold_nan_cases() -> tensor<3xf32> {
49+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0x7FC00000>{{.*}}tensor<3xf32>
50+
// CHECK-NOT: tosa.pow
51+
// CHECK: return [[RES]]
52+
%0 = "tosa.const"() {value = dense<[0.0, -1.25, 0x7FC00000]> : tensor<3xf32>} : () -> tensor<3xf32>
53+
%1 = "tosa.const"() {value = dense<[0.0, 0.745, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
54+
%2 = "tosa.pow"(%0, %1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
55+
return %2 : tensor<3xf32>
56+
}
57+
58+
// CHECK-LABEL: @pow_fold_tensor_broadcast_exp
59+
func.func @pow_fold_tensor_broadcast_exp() -> tensor<3xf16> {
60+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.6{{0*}}e+01, 4.929690e+00, 9.609370e+00{{.*}}tensor<3xf16>
61+
// CHECK-NOT: tosa.pow
62+
// CHECK: return [[RES]]
63+
%0 = "tosa.const"() {value = dense<[4.0, 2.22, -3.1]> : tensor<3xf16>} : () -> tensor<3xf16>
64+
%1 = "tosa.const"() {value = dense<2.0> : tensor<1xf16>} : () -> tensor<1xf16>
65+
%2 = "tosa.pow"(%0, %1) : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16>
66+
return %2 : tensor<3xf16>
67+
}
68+
69+
// CHECK-LABEL: @pow_fold_tensor_broadcast_base
70+
func.func @pow_fold_tensor_broadcast_base() -> tensor<3xf16> {
71+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.6{{0*}}e+01, 4.660160e+00, 1.166380e-01{{.*}}tensor<3xf16>
72+
// CHECK-NOT: tosa.pow
73+
// CHECK: return [[RES]]
74+
%0 = "tosa.const"() {value = dense<[4.0, 2.22, -3.1]> : tensor<3xf16>} : () -> tensor<3xf16>
75+
%1 = "tosa.const"() {value = dense<2.0> : tensor<1xf16>} : () -> tensor<1xf16>
76+
%2 = "tosa.pow"(%1, %0) : (tensor<1xf16>, tensor<3xf16>) -> tensor<3xf16>
77+
return %2 : tensor<3xf16>
78+
}
79+
80+
// CHECK-LABEL: @pow_fold_broadcast_two_dimensions
81+
func.func @pow_fold_broadcast_two_dimensions() -> tensor<3x3xf32> {
82+
// CHECK: [[RES:]] ={{.*}}tosa.const
83+
// CHECK-SAME{LITERAL}: [[388.023529, 1.102940e+03, 2554.37329],
84+
// CHECK-SAME{LITERAL}: [75281.1328, 538664.813, 0x4A1FF040],
85+
// CHECK-SAME{LITERAL}: [24.2514629, 42.4044418, 66.4508896]]
86+
// CHECK-NOT: tosa.pow
87+
// CHECK: return [[RES]]
88+
%0 = "tosa.const"() {value = dense<[[4.0, 5.1, 6.2]]> : tensor<1x3xf32>} : () -> tensor<1x3xf32>
89+
%1 = "tosa.const"() {value = dense<[[4.3], [8.1], [2.3]]> : tensor<3x1xf32>} : () -> tensor<3x1xf32>
90+
%2 = "tosa.pow"(%0, %1) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<3x3xf32>
91+
return %2 : tensor<3x3xf32>
92+
}

0 commit comments

Comments
 (0)