Skip to content

Commit 32a051d

Browse files
author
Spenser Bauman
committed
Address review feedback from jpiennar
This change addresses some of the additional review feedback on #87234. A summary of the changes: 1. Cleaned up the language to use 'roll back' rather than revert to reduce the chance of confusion. Improved some function names as well. 2. Eliminated string comparisons on dialect names. 3. Prevented the introduction of redundant tensor.cast operations for the same value.
1 parent e329b68 commit 32a051d

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1919
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
2020
#include "mlir/IR/Builders.h"
21+
#include "mlir/IR/ImplicitLocOpBuilder.h"
2122
#include "mlir/Interfaces/InferTypeOpInterface.h"
2223
#include "mlir/Pass/Pass.h"
2324
#include "mlir/Transforms/DialectConversion.h"
@@ -39,30 +40,26 @@ namespace {
3940
// type-inference related interface.
4041
// When a non-replaceable use is encountered, the value is wrapped in a
4142
// cast back to the original type after inference.
42-
bool isReplaceableUser(Operation *user) {
43-
// Handle unregistered dialects.
44-
if (!user->getDialect())
45-
return false;
46-
47-
return user->getDialect()->getNamespace() ==
48-
TosaDialect::getDialectNamespace() ||
43+
bool canBeRefined(Operation *user) {
44+
Dialect *tosaDialect = user->getContext()->getLoadedDialect<TosaDialect>();
45+
return user->getDialect() == tosaDialect ||
4946
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
5047
}
5148

5249
// During type propagation, the types of values in the operator graph are
5350
// updated. For the tosa.while_loop operation, types are speculatively updated
5451
// within the body region to determine the output type of the while_loop. This
5552
// process is performed until a fixed point is reached, then the types are
56-
// reverted.
53+
// rolled back.
5754
//
58-
// This class encapsulates the state information needed to perform the reversion
55+
// This class encapsulates the state information needed to perform the roll back
5956
// process or to commit to the final changes.
6057
class TypeModificationState {
6158
public:
6259
TypeModificationState() = default;
6360

6461
~TypeModificationState() {
65-
// Ensure the recorded modifications are either committed or reverted.
62+
// Ensure the recorded modifications are either committed or rolled back.
6663
assert(oldTypes.empty() && "unhandled type modifications");
6764
}
6865

@@ -74,10 +71,9 @@ class TypeModificationState {
7471
}
7572
}
7673

77-
// Revert changes made to the types in the IR by setting all the affected
74+
// Roll back changes made to the types in the IR by setting all the affected
7875
// values to their old types.
79-
void revert() {
80-
// Otherwise revert the changes.
76+
void rollBack() {
8177
for (auto [value, type] : oldTypes)
8278
value.setType(type);
8379

@@ -91,15 +87,18 @@ class TypeModificationState {
9187
// For each use whose type changed, cast the value with the new type back to
9288
// the old type.
9389
for (auto [value, oldType] : oldTypes) {
90+
tensor::CastOp castedValue;
9491
for (auto &use : value.getUses()) {
95-
if (isReplaceableUser(use.getOwner()))
92+
if (canBeRefined(use.getOwner()))
9693
continue;
9794

98-
OpBuilder builder(value.getContext());
99-
builder.setInsertionPoint(use.getOwner());
95+
// Cache the cast to avoid generating duplicates
96+
if (!castedValue) {
97+
ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
98+
castedValue = builder.create<tensor::CastOp>(oldType, value);
99+
}
100100

101-
Location loc = value.getLoc();
102-
use.set(builder.create<tensor::CastOp>(loc, oldType, value));
101+
use.set(castedValue);
103102
}
104103
}
105104

@@ -211,8 +210,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
211210
argTypes[i] = newType;
212211
}
213212

214-
// Revert all changes made during the speculative part of the algorithm.
215-
localState.revert();
213+
// Roll back all changes made during the speculative part of the algorithm.
214+
localState.rollBack();
216215
}
217216

218217
// We now set the block arguments according to the most recent shape
@@ -228,10 +227,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
228227
}
229228

230229
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
230+
Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
231+
231232
for (auto &block : region) {
232233
for (Operation &op : block) {
233-
if (!op.getDialect() ||
234-
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
234+
if (op.getDialect() != tosaDialect)
235235
continue;
236236

237237
propagateShapesToTosaIf(op, state);

0 commit comments

Comments
 (0)