Skip to content

Commit e513f2c

Browse files
sabaumaSpenser Bauman
andauthored
[mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes (#87660)
This change addresses some of the additional review feedback on #87234. A summary of the changes: 1. Cleaned up the language to use 'roll back' rather than revert to reduce the chance of confusion. Improved some function names as well. 2. Eliminated string comparisons on dialect names. 3. Prevented the introduction of redundant tensor.cast operations for the same value. --------- Co-authored-by: Spenser Bauman <sabauma@fastmail>
1 parent 514d80b commit e513f2c

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1919
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
2020
#include "mlir/IR/Builders.h"
21+
#include "mlir/IR/ImplicitLocOpBuilder.h"
2122
#include "mlir/Interfaces/InferTypeOpInterface.h"
2223
#include "mlir/Pass/Pass.h"
2324
#include "mlir/Transforms/DialectConversion.h"
@@ -39,30 +40,27 @@ namespace {
3940
// type-inference related interface.
4041
// When a non-replaceable use is encountered, the value is wrapped in a
4142
// cast back to the original type after inference.
42-
bool isReplaceableUser(Operation *user) {
43-
// Handle unregistered dialects.
43+
bool canBeRefined(Operation *user) {
4444
if (!user->getDialect())
4545
return false;
46-
47-
return user->getDialect()->getNamespace() ==
48-
TosaDialect::getDialectNamespace() ||
46+
return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
4947
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
5048
}
5149

5250
// During type propagation, the types of values in the operator graph are
5351
// updated. For the tosa.while_loop operation, types are speculatively updated
5452
// within the body region to determine the output type of the while_loop. This
5553
// process is performed until a fixed point is reached, then the types are
56-
// reverted.
54+
// rolled back.
5755
//
58-
// This class encapsulates the state information needed to perform the reversion
56+
// This class encapsulates the state information needed to perform the roll back
5957
// process or to commit to the final changes.
6058
class TypeModificationState {
6159
public:
6260
TypeModificationState() = default;
6361

6462
~TypeModificationState() {
65-
// Ensure the recorded modifications are either committed or reverted.
63+
// Ensure the recorded modifications are either committed or rolled back.
6664
assert(oldTypes.empty() && "unhandled type modifications");
6765
}
6866

@@ -74,10 +72,9 @@ class TypeModificationState {
7472
}
7573
}
7674

77-
// Revert changes made to the types in the IR by setting all the affected
75+
// Roll back changes made to the types in the IR by setting all the affected
7876
// values to their old types.
79-
void revert() {
80-
// Otherwise revert the changes.
77+
void rollBack() {
8178
for (auto [value, type] : oldTypes)
8279
value.setType(type);
8380

@@ -91,15 +88,18 @@ class TypeModificationState {
9188
// For each use whose type changed, cast the value with the new type back to
9289
// the old type.
9390
for (auto [value, oldType] : oldTypes) {
91+
tensor::CastOp castedValue;
9492
for (auto &use : value.getUses()) {
95-
if (isReplaceableUser(use.getOwner()))
93+
if (canBeRefined(use.getOwner()))
9694
continue;
9795

98-
OpBuilder builder(value.getContext());
99-
builder.setInsertionPoint(use.getOwner());
96+
// Cache the cast to avoid generating duplicates
97+
if (!castedValue) {
98+
ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
99+
castedValue = builder.create<tensor::CastOp>(oldType, value);
100+
}
100101

101-
Location loc = value.getLoc();
102-
use.set(builder.create<tensor::CastOp>(loc, oldType, value));
102+
use.set(castedValue);
103103
}
104104
}
105105

@@ -211,8 +211,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
211211
argTypes[i] = newType;
212212
}
213213

214-
// Revert all changes made during the speculative part of the algorithm.
215-
localState.revert();
214+
// Roll back all changes made during the speculative part of the algorithm.
215+
localState.rollBack();
216216
}
217217

218218
// We now set the block arguments according to the most recent shape
@@ -228,10 +228,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
228228
}
229229

230230
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
231+
Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
232+
231233
for (auto &block : region) {
232234
for (Operation &op : block) {
233-
if (!op.getDialect() ||
234-
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
235+
if (op.getDialect() != tosaDialect)
235236
continue;
236237

237238
propagateShapesToTosaIf(op, state);

0 commit comments

Comments
 (0)