Skip to content

Commit 2f88299

Browse files
committed
Clean-up for TOSA pow folding
Address review comments: * Add names to all function arguments in header * `toBeBroadcasted` -> `toBeBroadcastedShape` * Describe return values for broadcasting functions * Add shortcut for the offset computation if no broadcasting is needed
1 parent a8712f6 commit 2f88299

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,23 @@ DenseElementsAttr applyElementWise(
3636
/// tensors. If the input tensors do not match \p targetType, broadcasting is
3737
/// applied.
3838
DenseElementsAttr applyElementWise(
39-
const DenseElementsAttr &, const DenseElementsAttr &, TensorType targetType,
39+
const DenseElementsAttr &first, const DenseElementsAttr &second,
40+
TensorType targetType,
4041
const std::function<APFloat(const APFloat &, const APFloat &)> &toApply);
4142

4243
/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
4344
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
4445
TosaOp location,
45-
PatternRewriter &);
46+
PatternRewriter &rewriter);
4647

4748
/// Function that checks if \p toCheck is a dense TOSA constant tensor.
4849
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
4950
TosaOp location,
50-
PatternRewriter &);
51+
PatternRewriter &rewriter);
5152

5253
/// Function that checks if the type contained in \p toCheck is float.
5354
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
54-
PatternRewriter &);
55+
PatternRewriter &rewriter);
5556

5657
/// Compute the offset in \p shape which corresponds to the given \p index.
5758
OffsetType indexToOffset(DimensionType shape, DimensionType index);
@@ -60,18 +61,20 @@ OffsetType indexToOffset(DimensionType shape, DimensionType index);
6061
SmallVector<int64_t> offsetToIndex(DimensionType shape, OffsetType offset);
6162

6263
/// Given an \p index into \p desiredShape, compute the corresponding index into
63-
/// \p toBeBroadcasted.
64+
/// \p toBeBroadcastedShape.
65+
/// \returns broadcasted index into \p toBeBroadcastedShape.
6466
SmallVector<int64_t> getBroadcastedIndex(DimensionType desiredShape,
65-
DimensionType toBeBroadcasted,
67+
DimensionType toBeBroadcastedShape,
6668
DimensionType index);
6769
/// Given an \p offset into \p desiredShape, compute the corresponding offset
68-
/// into \p toBeBroadcasted.
70+
/// into \p toBeBroadcastedShape.
71+
/// \returns broadcasted offset into \p toBeBroadcastedShape.
6972
OffsetType getBroadcastedOffset(DimensionType desiredShape,
70-
DimensionType toBeBroadcasted,
73+
DimensionType toBeBroadcastedShape,
7174
OffsetType offset);
7275

7376
/// Function to compute the reciprocal.
74-
APFloat computeReciprocal(const APFloat &, Type);
77+
APFloat computeReciprocal(const APFloat &floatVal, Type floatTy);
7578

7679
} // namespace tosa
7780
} // namespace mlir

mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
1414
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
15-
#include <algorithm>
1615
#include <llvm/ADT/APFloat.h>
1716
#include <llvm/ADT/SmallVector.h>
17+
#include <algorithm>
1818
#include <mlir/IR/BuiltinAttributes.h>
1919
#include <mlir/IR/BuiltinTypes.h>
2020
#include <mlir/IR/Matchers.h>
@@ -155,13 +155,13 @@ SmallVector<int64_t> mlir::tosa::offsetToIndex(DimensionType shape,
155155

156156
SmallVector<int64_t>
157157
mlir::tosa::getBroadcastedIndex(DimensionType desiredShape,
158-
DimensionType toBeBroadcasted,
158+
DimensionType toBeBroadcastedShape,
159159
DimensionType index) {
160160
SmallVector<int64_t> broadCasted;
161161
broadCasted.reserve(desiredShape.size());
162162
for (size_t i = 0; i < desiredShape.size(); i++) {
163163
auto toInsert = 0;
164-
if (toBeBroadcasted[i] == desiredShape[i]) {
164+
if (toBeBroadcastedShape[i] == desiredShape[i]) {
165165
toInsert = index[i];
166166
}
167167
broadCasted.push_back(toInsert);
@@ -170,12 +170,16 @@ mlir::tosa::getBroadcastedIndex(DimensionType desiredShape,
170170
}
171171

172172
OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
173-
DimensionType toBeBroadcasted,
173+
DimensionType toBeBroadcastedShape,
174174
OffsetType offset) {
175+
// Simply return the offset if the shapes are equal.
176+
if (desiredShape.equals(toBeBroadcastedShape)) {
177+
return offset;
178+
}
175179
auto indexInTarget = offsetToIndex(desiredShape, offset);
176180
auto indexBroadcasted =
177-
getBroadcastedIndex(desiredShape, toBeBroadcasted, indexInTarget);
178-
return indexToOffset(toBeBroadcasted, indexBroadcasted);
181+
getBroadcastedIndex(desiredShape, toBeBroadcastedShape, indexInTarget);
182+
return indexToOffset(toBeBroadcastedShape, indexBroadcasted);
179183
}
180184

181185
APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {

0 commit comments

Comments
 (0)