Skip to content

[mlir][tosa] Fix tosa-infer-shapes crash #87234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 103 additions & 95 deletions mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

namespace mlir {
namespace tosa {
Expand All @@ -39,9 +34,87 @@ using namespace mlir::tosa;

namespace {

void propagateShapesInRegion(Region &region);
// Check whether this use case is replaceable. We define an op as
// being replaceable if it is used by a TosaOp, or an op with a
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OOC what TOSA ops do not have the type inference interface?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just tried removing that condition. Based on the how the tests failed, at least tosa.yield. Everthing in mlir/Dialect/Tosa/IR/TosaUtilOps.td lacks an inference interface as well.

// type-inference related interface.
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
bool isReplaceableUser(Operation *user) {
// Handle unregistered dialects.
if (!user->getDialect())
return false;

return user->getDialect()->getNamespace() ==
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is a string compare for dialects, comparing the id's would be cheaper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code already existed in the pass. I just moved it into a separate function so it could be used elsewhere.

TosaDialect::getDialectNamespace() ||
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}

// During type propagation, the types of values in the operator graph are
// updated. For the tosa.while_loop operation, types are speculatively updated
// within the body region to determine the output type of the while_loop. This
// process is performed until a fixed point is reached, then the types are
// reverted.
//
// This class encapsulates the state information needed to perform the reversion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/reversion/roll back/ ? (just given that's the term used in dialect conversion)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can update that. Given that I've already merged the change, I'll create a followup PR to address your comments.

// process or to commit to the final changes.
class TypeModificationState {
public:
TypeModificationState() = default;

~TypeModificationState() {
// Ensure the recorded modifications are either committed or reverted.
assert(oldTypes.empty() && "unhandled type modifications");
}

// Update the state of the value and record the old type.
void setType(Value value, Type type) {
if (value.getType() != type) {
oldTypes.emplace_back(value, value.getType());
value.setType(type);
}
}

void propagateShapesToTosaIf(Operation &op) {
// Revert changes made to the types in the IR by setting all the affected
// values to their old types.
void revert() {
// Otherwise revert the changes.
for (auto [value, type] : oldTypes)
value.setType(type);

oldTypes.clear();
}

// Commit the changes to the types in the IR.
// This requires inserting tensor.cast operations to mediate the newly
// inferred result types with users that do not support type inference.
void commit() {
// For each use whose type changed, cast the value with the new type back to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast is only where type information can't be propagated/refined (e.g., isReplaceableUser == canBeRefined).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. canBeRefined would be a better name.

// the old type.
for (auto [value, oldType] : oldTypes) {
for (auto &use : value.getUses()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you have a user that uses the result of the op multiple times?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic will generate 1 tensor.cast operation per use. That was true of the prior logic, though.

if (isReplaceableUser(use.getOwner()))
continue;

OpBuilder builder(value.getContext());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can combine these 3 instructions by using

ImplicitLocOpBuilder builder(value.getLoc(), use.getOwner());

builder.setInsertionPoint(use.getOwner());

Location loc = value.getLoc();
use.set(builder.create<tensor::CastOp>(loc, oldType, value));
}
}

oldTypes.clear();
}

private:
// A record of each value whose type was updated along with that value's
// previous type.
llvm::SmallVector<std::pair<Value, Type>> oldTypes;
};

void propagateShapesInRegion(Region &region, TypeModificationState &state);

void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
IfOp ifOp = dyn_cast<IfOp>(op);
if (!ifOp)
return;
Expand All @@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {

if (inferredTy.hasRank()) {
Type newType = oldType.clone(inferredTy.getShape());
blockArg.setType(newType);
state.setType(blockArg, newType);
}
}

Expand All @@ -71,64 +144,44 @@ void propagateShapesToTosaIf(Operation &op) {
ValueKnowledge::join(operandKnowledge, blockKnowledge);
if (!joinedKnowledge)
continue;
frontBlock.getArgument(i).setType(joinedKnowledge.getType());
state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
}

propagateShapesInRegion(region);
propagateShapesInRegion(region, state);
}
}

void propagateShapesToTosaWhile(Operation &op) {
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
WhileOp whileOp = dyn_cast<WhileOp>(op);
if (!whileOp)
return;

// Determine what the expected argument types are to the cond/body blocks.
// The expected arguments should be compatible with ever iteration of the
// loop body / condition for tosa.while.
llvm::SmallVector<Type> argTypes;
for (auto operand : op.getOperands()) {
auto operandTy = cast<ShapedType>(operand.getType());
if (operandTy.hasRank()) {
auto newTy = operandTy.clone(operandTy.getShape());
argTypes.push_back(newTy);
} else {
argTypes.push_back(operand.getType());
}
}

// Save out the type information so we can restore at the end.
llvm::DenseMap<Value, Type> originalTypeMap;
for (auto &block : op.getRegion(1)) {
for (auto arg : block.getArguments())
originalTypeMap[arg] = arg.getType();
for (auto &op : block)
for (auto result : op.getResults())
originalTypeMap[result] = result.getType();
}
SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

bool hasNewTypes = true;
while (hasNewTypes) {
TypeModificationState localState;

// Set types on the block args.
Region &bodyRegion = op.getRegion(1);
Block &block = bodyRegion.front();
for (int i = 0, s = argTypes.size(); i < s; i++) {
block.getArgument(i).setType(argTypes[i]);
localState.setType(block.getArgument(i), argTypes[i]);
}

// Propagate to the end.
propagateShapesInRegion(bodyRegion);
propagateShapesInRegion(bodyRegion, localState);

// Find all the tosa yield types and verify there is atleast one.
// Find all the tosa yield types and verify there is a single one.
llvm::SmallVector<YieldOp> yieldOps;
for (auto &block : bodyRegion)
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
yieldOps.push_back(yieldOp);

if (yieldOps.empty())
return;

assert(yieldOps.size() == 1 && "missing or non-unique yield op");
// Using the new tosa.yield operand types, infer the new subtypes.
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
for (auto ty : argTypes) {
Expand Down Expand Up @@ -158,59 +211,31 @@ void propagateShapesToTosaWhile(Operation &op) {
argTypes[i] = newType;
}

// The types inferred in the block assume the operand types specified for
// this iteration. We need to restore the original types to ensure that
// future iterations only use the already specified types, not possible
// types from previous iterations.
for (auto &block : bodyRegion) {
for (auto arg : block.getArguments())
arg.setType(originalTypeMap[arg]);
for (auto &op : block)
for (auto result : op.getResults())
result.setType(originalTypeMap[result]);
}
// Revert all changes made during the speculative part of the algorithm.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this speculative? Is it too aggressive in assuming equality but lacking a way to relax it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe 'speculative' is the wrong word. Happy to change if you have a better term.

The inference step for while loops, refines the block argument types and propagates those types through the body block multiple times until a fixed-point is reached. This is to infer consistent types for the tosa.while_loop's block arguments, yield arguments, and operation results.

After each propagation, the types of each operation in the body are reset to their previous values to avoid leaking possibly overly refined results between propagation steps.

I think that, ideally, the TosaInferShapes pass would avoid modifying the IR until it has fully inferred the desired types, and only then update the IR. That would avoid these statefulness issues. I initially attempted to do that, but the InferShapedTypeOpInterface does not allow enough information to be passed in to the inference method (e.g. refinements made to body region operations).

localState.revert();
}

// We now set the block arguments according to the most recent shape
// inference results. This gives us the block arg types for the next
// iteration.
for (auto &region : op.getRegions()) {
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
region.front().getArgument(i).setType(argTypes[i]);
state.setType(region.front().getArgument(i), argTypes[i]);
}

propagateShapesInRegion(region);
propagateShapesInRegion(region, state);
}
}

// Track the old type for each operand whose type was updated
// during inference. This information is used to introduce casts
// back to the type expected by the operand after inference.
struct TypeRewriteInfo {
OpOperand *operand;
Type oldType;
};

void propagateShapesInRegion(Region &region) {
// Check whether this use case is replaceable. We define an op as
// being replaceable if it is used by a TosaOp, or an op with a
// type-inference related interface.
// When a non-replaceable use is encountered, the value is wrapped in a
// cast back to the original type after inference.
auto isReplaceableUser = [](Operation *user) -> bool {
return user->getDialect()->getNamespace() ==
TosaDialect::getDialectNamespace() ||
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
};

llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
for (auto &block : region) {
for (Operation &op : block) {
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
if (!op.getDialect() ||
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is string compare too.

continue;

propagateShapesToTosaIf(op);
propagateShapesToTosaWhile(op);
propagateShapesToTosaIf(op, state);
propagateShapesToTosaWhile(op, state);

InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
Expand Down Expand Up @@ -252,30 +277,11 @@ void propagateShapesInRegion(Region &region) {
continue;

// Set new type
result.setType(newKnowledge.getType());

// Collect all uses of the operation which require update.
for (auto &user : result.getUses()) {
if (!isReplaceableUser(user.getOwner()))
requiresUpdate.push_back({&user, resultTy});
}
state.setType(result, newKnowledge.getType());
}
}
}
}

// For each use whose type changed, cast the value with the new type back to
// the old type.
IRRewriter rewriter(region.getContext());
for (auto [operand, oldType] : requiresUpdate) {
rewriter.setInsertionPoint(operand->getOwner());

auto oldValue = operand->get();

auto loc = oldValue.getLoc();
auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
operand->set(castOp);
}
}

/// Pass that performs shape propagation across TOSA operations. This includes
Expand All @@ -285,7 +291,9 @@ struct TosaInferShapes
public:
void runOnOperation() override {
func::FuncOp func = getOperation();
propagateShapesInRegion(func.getBody());
TypeModificationState state;
propagateShapesInRegion(func.getBody(), state);
state.commit();
}
};
} // namespace
Expand Down
Loading