[mlir][tosa] Improve tosa-infer-shapes for ops consumed by non-TOSA operators #72715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 1 commit on Dec 1, 2023
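
As an illustration (not part of the patch itself), here is a minimal before/after sketch of the new behavior, adapted from the test case added below; the function name and SSA value names are illustrative, and the exact IR the pass emits may differ in detail.

Before --tosa-infer-shapes:

func.func @example(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
  %shape = arith.constant dense<[1, 1]> : tensor<2xindex>
  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
  %1 = tensor.reshape %0(%shape) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

Previously the pass left the tosa.matmul result untouched because one of its users (tensor.reshape) is not a TOSA op. With this change the matmul result is still refined, and a cast back to the original type is inserted for the non-TOSA use, roughly:

func.func @example(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
  %shape = arith.constant dense<[1, 1]> : tensor<2xindex>
  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
  %cast = tensor.cast %0 : tensor<1x1x1xf32> to tensor<?x1x1xf32>
  %1 = tensor.reshape %cast(%shape) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}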
75 changes: 32 additions & 43 deletions mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -183,17 +183,27 @@ void propagateShapesToTosaWhile(Operation &op) {
   }
 }
 
+// Track the old type for each operand whose type was updated
+// during inference. This information is used to introduce casts
+// back to the type expected by the operand after inference.
+struct TypeRewriteInfo {
+  OpOperand *operand;
+  Type oldType;
+};
+
 void propagateShapesInRegion(Region &region) {
   // Check whether this use case is replaceable. We define an op as
-  // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
+  // being replaceable if it is used by a TosaOp, or an op with a
   // type-inference related interface.
+  // When a non-replaceable use is encountered, the value is wrapped in a
+  // cast back to the original type after inference.
   auto isReplaceableUser = [](Operation *user) -> bool {
-    return isa<func::ReturnOp>(user) ||
-           user->getDialect()->getNamespace() ==
+    return user->getDialect()->getNamespace() ==
                TosaDialect::getDialectNamespace() ||
            isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
   };
 
+  llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
   for (auto &block : region) {
     for (Operation &op : block) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -219,9 +229,6 @@ void propagateShapesInRegion(Region &region) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
 
-          if (!llvm::all_of(result.getUsers(), isReplaceableUser))
-            continue;
-
           // Determine the knowledge based on the output type.
           // TODO: should also query WIP type probably
           Type resultTy = result.getType();
@@ -246,10 +253,29 @@

           // Set new type
           result.setType(newKnowledge.getType());
+
+          // Collect all uses of the operation which require update.
+          for (auto &user : result.getUses()) {
+            if (!isReplaceableUser(user.getOwner()))
+              requiresUpdate.push_back({&user, resultTy});
+          }
         }
       }
     }
   }
+
+  // For each use whose type changed, cast the value with the new type back to
+  // the old type.
+  IRRewriter rewriter(region.getContext());
+  for (auto [operand, oldType] : requiresUpdate) {
+    rewriter.setInsertionPoint(operand->getOwner());
+
+    auto oldValue = operand->get();
+
+    auto loc = oldValue.getLoc();
+    auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
+    operand->set(castOp);
+  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
@@ -259,44 +285,7 @@ struct TosaInferShapes
 public:
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-
-    IRRewriter rewriter(func.getContext());
-
     propagateShapesInRegion(func.getBody());
-
-    // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
-    // the FuncOp type.
-    func.walk([&](func::ReturnOp op) {
-      func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp());
-      if (!parent)
-        return;
-
-      rewriter.setInsertionPoint(op);
-      FunctionType funcTy = func.getFunctionType();
-      auto resultTys = funcTy.getResults();
-
-      bool castAdded = false;
-      SmallVector<Value> castedValues;
-      for (auto it : llvm::zip(op->getOperands(), resultTys)) {
-        auto operand = std::get<0>(it);
-        auto currentTy = operand.getType();
-        auto castTy = std::get<1>(it);
-        if (currentTy == castTy) {
-          castedValues.push_back(operand);
-          continue;
-        }
-
-        castedValues.push_back(
-            rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
-                .getResult());
-
-        castAdded = true;
-      }
-
-      if (castAdded) {
-        rewriter.replaceOpWithNewOp<func::ReturnOp>(op, castedValues);
-      }
-    });
   }
 };
 } // namespace
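The func.walk fixup removed above is subsumed by the same mechanism: func.return is no longer treated as a replaceable user, so the pass inserts a cast back to the declared result type in front of the return rather than rewriting the return specially. An illustrative sketch (not a test from this patch; names and exact output are assumptions):

func.func @keeps_signature(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x1x1xf32> {
  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
  return %0 : tensor<?x1x1xf32>
}

After the pass, the function type is unchanged while the matmul result is refined, roughly:

func.func @keeps_signature(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x1x1xf32> {
  // The refined tensor<1x1x1xf32> result is cast back to the tensor<?x1x1xf32>
  // type expected by the function signature.
  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
  %1 = tensor.cast %0 : tensor<1x1x1xf32> to tensor<?x1x1xf32>
  return %1 : tensor<?x1x1xf32>
}
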
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1262,6 +1262,17 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)

 // -----
 
+// CHECK-LABEL: test_non_tosa_consumer_still_propagates
+func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
+  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  %1 = arith.constant dense<[1, 1]> : tensor<2xindex>
+  %2 = tensor.reshape %0(%1) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_tosa_use_def_chain
 func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
   // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
@@ -1298,3 +1309,4 @@ func.func @test_large_constant_permutation() {
%72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
return
}