18
18
#include " mlir/Dialect/Tosa/IR/TosaOps.h"
19
19
#include " mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20
20
#include " mlir/IR/Builders.h"
21
+ #include " mlir/IR/ImplicitLocOpBuilder.h"
21
22
#include " mlir/Interfaces/InferTypeOpInterface.h"
22
23
#include " mlir/Pass/Pass.h"
23
24
#include " mlir/Transforms/DialectConversion.h"
@@ -39,30 +40,27 @@ namespace {
39
40
// type-inference related interface.
40
41
// When a non-replaceable use is encountered, the value is wrapped in a
41
42
// cast back to the original type after inference.
42
- bool isReplaceableUser (Operation *user) {
43
- // Handle unregistered dialects.
43
+ bool canBeRefined (Operation *user) {
44
44
if (!user->getDialect ())
45
45
return false ;
46
-
47
- return user->getDialect ()->getNamespace () ==
48
- TosaDialect::getDialectNamespace () ||
46
+ return user->getDialect ()->getTypeID () == TypeID::get<TosaDialect>() ||
49
47
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
50
48
}
51
49
52
50
// During type propagation, the types of values in the operator graph are
53
51
// updated. For the tosa.while_loop operation, types are speculatively updated
54
52
// within the body region to determine the output type of the while_loop. This
55
53
// process is performed until a fixed point is reached, then the types are
56
- // reverted .
54
+ // rolled back .
57
55
//
58
- // This class encapsulates the state information needed to perform the reversion
56
+ // This class encapsulates the state information needed to perform the roll back
59
57
// process or to commit to the final changes.
60
58
class TypeModificationState {
61
59
public:
62
60
TypeModificationState () = default ;
63
61
64
62
~TypeModificationState () {
65
- // Ensure the recorded modifications are either committed or reverted .
63
+ // Ensure the recorded modifications are either committed or rolled back .
66
64
assert (oldTypes.empty () && " unhandled type modifications" );
67
65
}
68
66
@@ -74,10 +72,9 @@ class TypeModificationState {
74
72
}
75
73
}
76
74
77
- // Revert changes made to the types in the IR by setting all the affected
75
+ // Roll back changes made to the types in the IR by setting all the affected
78
76
// values to their old types.
79
- void revert () {
80
- // Otherwise revert the changes.
77
+ void rollBack () {
81
78
for (auto [value, type] : oldTypes)
82
79
value.setType (type);
83
80
@@ -91,15 +88,18 @@ class TypeModificationState {
91
88
// For each use whose type changed, cast the value with the new type back to
92
89
// the old type.
93
90
for (auto [value, oldType] : oldTypes) {
91
+ tensor::CastOp castedValue;
94
92
for (auto &use : value.getUses ()) {
95
- if (isReplaceableUser (use.getOwner ()))
93
+ if (canBeRefined (use.getOwner ()))
96
94
continue ;
97
95
98
- OpBuilder builder (value.getContext ());
99
- builder.setInsertionPoint (use.getOwner ());
96
+ // Cache the cast to avoid generating duplicates
97
+ if (!castedValue) {
98
+ ImplicitLocOpBuilder builder{value.getLoc (), use.getOwner ()};
99
+ castedValue = builder.create <tensor::CastOp>(oldType, value);
100
+ }
100
101
101
- Location loc = value.getLoc ();
102
- use.set (builder.create <tensor::CastOp>(loc, oldType, value));
102
+ use.set (castedValue);
103
103
}
104
104
}
105
105
@@ -211,8 +211,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
211
211
argTypes[i] = newType;
212
212
}
213
213
214
- // Revert all changes made during the speculative part of the algorithm.
215
- localState.revert ();
214
+ // Roll back all changes made during the speculative part of the algorithm.
215
+ localState.rollBack ();
216
216
}
217
217
218
218
// We now set the block arguments according to the most recent shape
@@ -228,10 +228,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
228
228
}
229
229
230
230
void propagateShapesInRegion (Region ®ion, TypeModificationState &state) {
231
+ Dialect *tosaDialect = region.getContext ()->getLoadedDialect <TosaDialect>();
232
+
231
233
for (auto &block : region) {
232
234
for (Operation &op : block) {
233
- if (!op.getDialect () ||
234
- op.getDialect ()->getNamespace () != TosaDialect::getDialectNamespace ())
235
+ if (op.getDialect () != tosaDialect)
235
236
continue ;
236
237
237
238
propagateShapesToTosaIf (op, state);
0 commit comments