18
18
#include " mlir/Dialect/Tosa/IR/TosaOps.h"
19
19
#include " mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20
20
#include " mlir/IR/Builders.h"
21
+ #include " mlir/IR/ImplicitLocOpBuilder.h"
21
22
#include " mlir/Interfaces/InferTypeOpInterface.h"
22
23
#include " mlir/Pass/Pass.h"
23
24
#include " mlir/Transforms/DialectConversion.h"
@@ -39,30 +40,26 @@ namespace {
39
40
// type-inference related interface.
40
41
// When a non-replaceable use is encountered, the value is wrapped in a
41
42
// cast back to the original type after inference.
42
- bool isReplaceableUser (Operation *user) {
43
- // Handle unregistered dialects.
44
- if (!user->getDialect ())
45
- return false ;
46
-
47
- return user->getDialect ()->getNamespace () ==
48
- TosaDialect::getDialectNamespace () ||
43
+ bool canBeRefined (Operation *user) {
44
+ Dialect *tosaDialect = user->getContext ()->getLoadedDialect <TosaDialect>();
45
+ return user->getDialect () == tosaDialect ||
49
46
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
50
47
}
51
48
52
49
// During type propagation, the types of values in the operator graph are
53
50
// updated. For the tosa.while_loop operation, types are speculatively updated
54
51
// within the body region to determine the output type of the while_loop. This
55
52
// process is performed until a fixed point is reached, then the types are
56
- // reverted .
53
+ // rolled back .
57
54
//
58
- // This class encapsulates the state information needed to perform the reversion
55
+ // This class encapsulates the state information needed to perform the roll back
59
56
// process or to commit to the final changes.
60
57
class TypeModificationState {
61
58
public:
62
59
TypeModificationState () = default ;
63
60
64
61
~TypeModificationState () {
65
- // Ensure the recorded modifications are either committed or reverted .
62
+ // Ensure the recorded modifications are either committed or rolled back .
66
63
assert (oldTypes.empty () && " unhandled type modifications" );
67
64
}
68
65
@@ -74,10 +71,9 @@ class TypeModificationState {
74
71
}
75
72
}
76
73
77
- // Revert changes made to the types in the IR by setting all the affected
74
+ // Roll back changes made to the types in the IR by setting all the affected
78
75
// values to their old types.
79
- void revert () {
80
- // Otherwise revert the changes.
76
+ void rollBack () {
81
77
for (auto [value, type] : oldTypes)
82
78
value.setType (type);
83
79
@@ -91,15 +87,18 @@ class TypeModificationState {
91
87
// For each use whose type changed, cast the value with the new type back to
92
88
// the old type.
93
89
for (auto [value, oldType] : oldTypes) {
90
+ tensor::CastOp castedValue;
94
91
for (auto &use : value.getUses ()) {
95
- if (isReplaceableUser (use.getOwner ()))
92
+ if (canBeRefined (use.getOwner ()))
96
93
continue ;
97
94
98
- OpBuilder builder (value.getContext ());
99
- builder.setInsertionPoint (use.getOwner ());
95
+ // Cache the cast to avoid generating duplicates
96
+ if (!castedValue) {
97
+ ImplicitLocOpBuilder builder{value.getLoc (), use.getOwner ()};
98
+ castedValue = builder.create <tensor::CastOp>(oldType, value);
99
+ }
100
100
101
- Location loc = value.getLoc ();
102
- use.set (builder.create <tensor::CastOp>(loc, oldType, value));
101
+ use.set (castedValue);
103
102
}
104
103
}
105
104
@@ -211,8 +210,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
211
210
argTypes[i] = newType;
212
211
}
213
212
214
- // Revert all changes made during the speculative part of the algorithm.
215
- localState.revert ();
213
+ // Roll back all changes made during the speculative part of the algorithm.
214
+ localState.rollBack ();
216
215
}
217
216
218
217
// We now set the block arguments according to the most recent shape
@@ -228,10 +227,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
228
227
}
229
228
230
229
void propagateShapesInRegion (Region ®ion, TypeModificationState &state) {
230
+ Dialect *tosaDialect = region.getContext ()->getLoadedDialect <TosaDialect>();
231
+
231
232
for (auto &block : region) {
232
233
for (Operation &op : block) {
233
- if (!op.getDialect () ||
234
- op.getDialect ()->getNamespace () != TosaDialect::getDialectNamespace ())
234
+ if (op.getDialect () != tosaDialect)
235
235
continue ;
236
236
237
237
propagateShapesToTosaIf (op, state);
0 commit comments