@@ -183,17 +183,27 @@ void propagateShapesToTosaWhile(Operation &op) {
183
183
}
184
184
}
185
185
186
+ // Track the old type for each operand whose type was updated
187
+ // during inference. This information is used to introduce casts
188
+ // back to the type expected by the operand after inference.
189
+ struct TypeRewriteInfo {
190
+ OpOperand *operand;
191
+ Type oldType;
192
+ };
193
+
186
194
void propagateShapesInRegion (Region ®ion) {
187
195
// Check whether this use case is replaceable. We define an op as
188
- // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
196
+ // being replaceable if it is used by a TosaOp, or an op with a
189
197
// type-inference related interface.
198
+ // When a non-replaceable use is encountered, the value is wrapped in a
199
+ // cast back to the original type after inference.
190
200
auto isReplaceableUser = [](Operation *user) -> bool {
191
- return isa<func::ReturnOp>(user) ||
192
- user->getDialect ()->getNamespace () ==
201
+ return user->getDialect ()->getNamespace () ==
193
202
TosaDialect::getDialectNamespace () ||
194
203
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
195
204
};
196
205
206
+ llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
197
207
for (auto &block : region) {
198
208
for (Operation &op : block) {
199
209
if (op.getDialect ()->getNamespace () != TosaDialect::getDialectNamespace ())
@@ -219,9 +229,6 @@ void propagateShapesInRegion(Region ®ion) {
219
229
Value result = std::get<0 >(it);
220
230
ShapedTypeComponents predictedShape = std::get<1 >(it);
221
231
222
- if (!llvm::all_of (result.getUsers (), isReplaceableUser))
223
- continue ;
224
-
225
232
// Determine the knowledge based on the output type.
226
233
// TODO: should also query WIP type probably
227
234
Type resultTy = result.getType ();
@@ -246,10 +253,29 @@ void propagateShapesInRegion(Region ®ion) {
246
253
247
254
// Set new type
248
255
result.setType (newKnowledge.getType ());
256
+
257
+ // Collect all uses of the operation which require update.
258
+ for (auto &user : result.getUses ()) {
259
+ if (!isReplaceableUser (user.getOwner ()))
260
+ requiresUpdate.push_back ({&user, resultTy});
261
+ }
249
262
}
250
263
}
251
264
}
252
265
}
266
+
267
+ // For each use whose type changed, cast the value with the new type back to
268
+ // the old type.
269
+ IRRewriter rewriter (region.getContext ());
270
+ for (auto [operand, oldType] : requiresUpdate) {
271
+ rewriter.setInsertionPoint (operand->getOwner ());
272
+
273
+ auto oldValue = operand->get ();
274
+
275
+ auto loc = oldValue.getLoc ();
276
+ auto castOp = rewriter.create <tensor::CastOp>(loc, oldType, oldValue);
277
+ operand->set (castOp);
278
+ }
253
279
}
254
280
255
281
// / Pass that performs shape propagation across TOSA operations. This includes
@@ -259,44 +285,7 @@ struct TosaInferShapes
259
285
public:
260
286
void runOnOperation () override {
261
287
func::FuncOp func = getOperation ();
262
-
263
- IRRewriter rewriter (func.getContext ());
264
-
265
288
propagateShapesInRegion (func.getBody ());
266
-
267
- // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
268
- // the FuncOp type.
269
- func.walk ([&](func::ReturnOp op) {
270
- func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp ());
271
- if (!parent)
272
- return ;
273
-
274
- rewriter.setInsertionPoint (op);
275
- FunctionType funcTy = func.getFunctionType ();
276
- auto resultTys = funcTy.getResults ();
277
-
278
- bool castAdded = false ;
279
- SmallVector<Value> castedValues;
280
- for (auto it : llvm::zip (op->getOperands (), resultTys)) {
281
- auto operand = std::get<0 >(it);
282
- auto currentTy = operand.getType ();
283
- auto castTy = std::get<1 >(it);
284
- if (currentTy == castTy) {
285
- castedValues.push_back (operand);
286
- continue ;
287
- }
288
-
289
- castedValues.push_back (
290
- rewriter.create <tensor::CastOp>(op.getLoc (), castTy, operand)
291
- .getResult ());
292
-
293
- castAdded = true ;
294
- }
295
-
296
- if (castAdded) {
297
- rewriter.replaceOpWithNewOp <func::ReturnOp>(op, castedValues);
298
- }
299
- });
300
289
}
301
290
};
302
291
} // namespace
0 commit comments