Skip to content

Commit 852f6be

Browse files
authored
[mlir][tosa] Improve tosa-infer-shapes for ops consumed by non-TOSA operators (llvm#72715)
TOSA operators consumed by non-TOSA ops generally do not have their types inferred, as that would alter the types expected by their consumers. This prevents type refinement on many TOSA operators when the IR contains a mix of dialects. This change modifies tosa-infer-shapes to update the types of all TOSA operators during inference. When a consumer of that TOSA op is not safe to update, a tensor.cast is inserted back to the original type. This behavior is similar to how TOSA ops consumed by func.return are handled. This allows for more type refinement of TOSA ops, and the additional tensor.cast operators may be removed by later canonicalizations.
1 parent 8c13099 commit 852f6be

File tree

2 files changed

+44
-43
lines changed

2 files changed

+44
-43
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,27 @@ void propagateShapesToTosaWhile(Operation &op) {
183183
}
184184
}
185185

186+
// Track the old type for each operand whose type was updated
187+
// during inference. This information is used to introduce casts
188+
// back to the type expected by the operand after inference.
189+
struct TypeRewriteInfo {
190+
OpOperand *operand;
191+
Type oldType;
192+
};
193+
186194
void propagateShapesInRegion(Region &region) {
187195
// Check whether this use case is replaceable. We define an op as
188-
// being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
196+
// being replaceable if it is used by a TosaOp, or an op with a
189197
// type-inference related interface.
198+
// When a non-replaceable use is encountered, the value is wrapped in a
199+
// cast back to the original type after inference.
190200
auto isReplaceableUser = [](Operation *user) -> bool {
191-
return isa<func::ReturnOp>(user) ||
192-
user->getDialect()->getNamespace() ==
201+
return user->getDialect()->getNamespace() ==
193202
TosaDialect::getDialectNamespace() ||
194203
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
195204
};
196205

206+
llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
197207
for (auto &block : region) {
198208
for (Operation &op : block) {
199209
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -219,9 +229,6 @@ void propagateShapesInRegion(Region &region) {
219229
Value result = std::get<0>(it);
220230
ShapedTypeComponents predictedShape = std::get<1>(it);
221231

222-
if (!llvm::all_of(result.getUsers(), isReplaceableUser))
223-
continue;
224-
225232
// Determine the knowledge based on the output type.
226233
// TODO: should also query WIP type probably
227234
Type resultTy = result.getType();
@@ -246,10 +253,29 @@ void propagateShapesInRegion(Region &region) {
246253

247254
// Set new type
248255
result.setType(newKnowledge.getType());
256+
257+
// Collect all uses of the operation which require update.
258+
for (auto &user : result.getUses()) {
259+
if (!isReplaceableUser(user.getOwner()))
260+
requiresUpdate.push_back({&user, resultTy});
261+
}
249262
}
250263
}
251264
}
252265
}
266+
267+
// For each use whose type changed, cast the value with the new type back to
268+
// the old type.
269+
IRRewriter rewriter(region.getContext());
270+
for (auto [operand, oldType] : requiresUpdate) {
271+
rewriter.setInsertionPoint(operand->getOwner());
272+
273+
auto oldValue = operand->get();
274+
275+
auto loc = oldValue.getLoc();
276+
auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
277+
operand->set(castOp);
278+
}
253279
}
254280

255281
/// Pass that performs shape propagation across TOSA operations. This includes
@@ -259,44 +285,7 @@ struct TosaInferShapes
259285
public:
260286
void runOnOperation() override {
261287
func::FuncOp func = getOperation();
262-
263-
IRRewriter rewriter(func.getContext());
264-
265288
propagateShapesInRegion(func.getBody());
266-
267-
// Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
268-
// the FuncOp type.
269-
func.walk([&](func::ReturnOp op) {
270-
func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp());
271-
if (!parent)
272-
return;
273-
274-
rewriter.setInsertionPoint(op);
275-
FunctionType funcTy = func.getFunctionType();
276-
auto resultTys = funcTy.getResults();
277-
278-
bool castAdded = false;
279-
SmallVector<Value> castedValues;
280-
for (auto it : llvm::zip(op->getOperands(), resultTys)) {
281-
auto operand = std::get<0>(it);
282-
auto currentTy = operand.getType();
283-
auto castTy = std::get<1>(it);
284-
if (currentTy == castTy) {
285-
castedValues.push_back(operand);
286-
continue;
287-
}
288-
289-
castedValues.push_back(
290-
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
291-
.getResult());
292-
293-
castAdded = true;
294-
}
295-
296-
if (castAdded) {
297-
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, castedValues);
298-
}
299-
});
300289
}
301290
};
302291
} // namespace

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,17 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
12621262

12631263
// -----
12641264

1265+
// CHECK-LABEL: test_non_tosa_consumer_still_propagates
1266+
func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
1267+
// CHECK: tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
1268+
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
1269+
%1 = arith.constant dense<[1, 1]> : tensor<2xindex>
1270+
%2 = tensor.reshape %0(%1) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
1271+
return %2 : tensor<?x?xf32>
1272+
}
1273+
1274+
// -----
1275+
12651276
// CHECK-LABEL: test_tosa_use_def_chain
12661277
func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
12671278
// CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
@@ -1298,3 +1309,4 @@ func.func @test_large_constant_permutation() {
12981309
%72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
12991310
return
13001311
}
1312+

0 commit comments

Comments
 (0)