Skip to content

[mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes #87660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -39,30 +40,27 @@ namespace {
// type-inference related interface.
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
bool isReplaceableUser(Operation *user) {
// Handle unregistered dialects.
bool canBeRefined(Operation *user) {
if (!user->getDialect())
return false;

return user->getDialect()->getNamespace() ==
TosaDialect::getDialectNamespace() ||
return user->getDialect()->getTypeID() == TypeID::get<TosaDialect>() ||
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}

// During type propagation, the types of values in the operator graph are
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is performed until a fixed point is reached, then the types are
// reverted.
// rolled back.
//
// This class encapsulates the state information needed to perform the reversion
// This class encapsulates the state information needed to perform the roll back
// process or to commit to the final changes.
class TypeModificationState {
public:
TypeModificationState() = default;

~TypeModificationState() {
// Ensure the recorded modifications are either committed or reverted.
// Ensure the recorded modifications are either committed or rolled back.
assert(oldTypes.empty() && "unhandled type modifications");
}

Expand All @@ -74,10 +72,9 @@ class TypeModificationState {
}
}

// Revert changes made to the types in the IR by setting all the affected
// Roll back changes made to the types in the IR by setting all the affected
// values to their old types.
void revert() {
// Otherwise revert the changes.
void rollBack() {
for (auto [value, type] : oldTypes)
value.setType(type);

Expand All @@ -91,15 +88,18 @@ class TypeModificationState {
// For each use whose type changed, cast the value with the new type back to
// the old type.
for (auto [value, oldType] : oldTypes) {
tensor::CastOp castedValue;
for (auto &use : value.getUses()) {
if (isReplaceableUser(use.getOwner()))
if (canBeRefined(use.getOwner()))
continue;

OpBuilder builder(value.getContext());
builder.setInsertionPoint(use.getOwner());
// Cache the cast to avoid generating duplicates
if (!castedValue) {
ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
castedValue = builder.create<tensor::CastOp>(oldType, value);
}

Location loc = value.getLoc();
use.set(builder.create<tensor::CastOp>(loc, oldType, value));
use.set(castedValue);
}
}

Expand Down Expand Up @@ -211,8 +211,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
argTypes[i] = newType;
}

// Revert all changes made during the speculative part of the algorithm.
localState.revert();
// Roll back all changes made during the speculative part of the algorithm.
localState.rollBack();
}

// We now set the block arguments according to the most recent shape
Expand All @@ -228,10 +228,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
}

void propagateShapesInRegion(Region &region, TypeModificationState &state) {
Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();

for (auto &block : region) {
for (Operation &op : block) {
if (!op.getDialect() ||
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
if (op.getDialect() != tosaDialect)
continue;

propagateShapesToTosaIf(op, state);
Expand Down