-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Fix tosa-infer-shapes crash #87234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,14 +18,9 @@ | |
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/IRMapping.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Interfaces/InferTypeOpInterface.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "llvm/Support/FormatVariadic.h" | ||
|
||
namespace mlir { | ||
namespace tosa { | ||
|
@@ -39,9 +34,87 @@ using namespace mlir::tosa; | |
|
||
namespace { | ||
|
||
void propagateShapesInRegion(Region ®ion); | ||
// Check whether this use case is replaceable. We define an op as | ||
// being replaceable if it is used by a TosaOp, or an op with a | ||
// type-inference related interface. | ||
// When a non-replaceable use is encountered, the value is wrapped in a | ||
// cast back to the original type after inference. | ||
bool isReplaceableUser(Operation *user) { | ||
// Handle unregistered dialects. | ||
if (!user->getDialect()) | ||
return false; | ||
|
||
return user->getDialect()->getNamespace() == | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: this is a string compare for dialects, comparing the id's would be cheaper. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code already existed in the pass. I just moved it into a separate function so it could be used elsewhere. |
||
TosaDialect::getDialectNamespace() || | ||
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user); | ||
} | ||
|
||
// During type propagation, the types of values in the operator graph are | ||
// updated. For the tosa.while_loop operation, types are speculatively updated | ||
// within the body region to determine the output type of the while_loop. This | ||
// process is performed until a fixed point is reached, then the types are | ||
// reverted. | ||
// | ||
// This class encapsulates the state information needed to perform the reversion | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/reversion/roll back/ ? (just given that's the term used in dialect conversion) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can update that. Given that I've already merged the change, I'll create a followup PR to address your comments. |
||
// process or to commit to the final changes. | ||
class TypeModificationState { | ||
public: | ||
TypeModificationState() = default; | ||
|
||
~TypeModificationState() { | ||
// Ensure the recorded modifications are either committed or reverted. | ||
assert(oldTypes.empty() && "unhandled type modifications"); | ||
} | ||
|
||
// Update the state of the value and record the old type. | ||
void setType(Value value, Type type) { | ||
if (value.getType() != type) { | ||
oldTypes.emplace_back(value, value.getType()); | ||
value.setType(type); | ||
} | ||
} | ||
|
||
void propagateShapesToTosaIf(Operation &op) { | ||
// Revert changes made to the types in the IR by setting all the affected | ||
// values to their old types. | ||
void revert() { | ||
// Otherwise revert the changes. | ||
for (auto [value, type] : oldTypes) | ||
value.setType(type); | ||
|
||
oldTypes.clear(); | ||
} | ||
|
||
// Commit the changes to the types in the IR. | ||
// This requires inserting tensor.cast operations to mediate the newly | ||
// inferred result types with users that do not support type inference. | ||
void commit() { | ||
// For each use whose type changed, cast the value with the new type back to | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cast is only where type information can't be propagated/refined (e.g., There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. |
||
// the old type. | ||
for (auto [value, oldType] : oldTypes) { | ||
for (auto &use : value.getUses()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if you have a user that uses the result of the op multiple times? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic will generate 1 |
||
if (isReplaceableUser(use.getOwner())) | ||
continue; | ||
|
||
OpBuilder builder(value.getContext()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can combine these 3 instructions by using ImplicitLocOpBuilder builder(value.getLoc(), use.getOwner()); |
||
builder.setInsertionPoint(use.getOwner()); | ||
|
||
Location loc = value.getLoc(); | ||
use.set(builder.create<tensor::CastOp>(loc, oldType, value)); | ||
} | ||
} | ||
|
||
oldTypes.clear(); | ||
} | ||
|
||
private: | ||
// A record of each value whose type was updated along with that value's | ||
// previous type. | ||
llvm::SmallVector<std::pair<Value, Type>> oldTypes; | ||
}; | ||
|
||
void propagateShapesInRegion(Region ®ion, TypeModificationState &state); | ||
|
||
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) { | ||
IfOp ifOp = dyn_cast<IfOp>(op); | ||
if (!ifOp) | ||
return; | ||
|
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) { | |
|
||
if (inferredTy.hasRank()) { | ||
Type newType = oldType.clone(inferredTy.getShape()); | ||
blockArg.setType(newType); | ||
state.setType(blockArg, newType); | ||
} | ||
} | ||
|
||
|
@@ -71,64 +144,44 @@ void propagateShapesToTosaIf(Operation &op) { | |
ValueKnowledge::join(operandKnowledge, blockKnowledge); | ||
if (!joinedKnowledge) | ||
continue; | ||
frontBlock.getArgument(i).setType(joinedKnowledge.getType()); | ||
state.setType(frontBlock.getArgument(i), joinedKnowledge.getType()); | ||
} | ||
|
||
propagateShapesInRegion(region); | ||
propagateShapesInRegion(region, state); | ||
} | ||
} | ||
|
||
void propagateShapesToTosaWhile(Operation &op) { | ||
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) { | ||
WhileOp whileOp = dyn_cast<WhileOp>(op); | ||
if (!whileOp) | ||
return; | ||
|
||
// Determine what the expected argument types are to the cond/body blocks. | ||
// The expected arguments should be compatible with ever iteration of the | ||
// loop body / condition for tosa.while. | ||
llvm::SmallVector<Type> argTypes; | ||
for (auto operand : op.getOperands()) { | ||
auto operandTy = cast<ShapedType>(operand.getType()); | ||
if (operandTy.hasRank()) { | ||
auto newTy = operandTy.clone(operandTy.getShape()); | ||
argTypes.push_back(newTy); | ||
} else { | ||
argTypes.push_back(operand.getType()); | ||
} | ||
} | ||
|
||
// Save out the type information so we can restore at the end. | ||
llvm::DenseMap<Value, Type> originalTypeMap; | ||
for (auto &block : op.getRegion(1)) { | ||
for (auto arg : block.getArguments()) | ||
originalTypeMap[arg] = arg.getType(); | ||
for (auto &op : block) | ||
for (auto result : op.getResults()) | ||
originalTypeMap[result] = result.getType(); | ||
} | ||
SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes()); | ||
|
||
bool hasNewTypes = true; | ||
while (hasNewTypes) { | ||
TypeModificationState localState; | ||
|
||
// Set types on the block args. | ||
Region &bodyRegion = op.getRegion(1); | ||
Block &block = bodyRegion.front(); | ||
for (int i = 0, s = argTypes.size(); i < s; i++) { | ||
block.getArgument(i).setType(argTypes[i]); | ||
localState.setType(block.getArgument(i), argTypes[i]); | ||
} | ||
|
||
// Propagate to the end. | ||
propagateShapesInRegion(bodyRegion); | ||
propagateShapesInRegion(bodyRegion, localState); | ||
|
||
// Find all the tosa yield types and verify there is atleast one. | ||
// Find all the tosa yield types and verify there is a single one. | ||
llvm::SmallVector<YieldOp> yieldOps; | ||
for (auto &block : bodyRegion) | ||
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator())) | ||
yieldOps.push_back(yieldOp); | ||
|
||
if (yieldOps.empty()) | ||
return; | ||
|
||
assert(yieldOps.size() == 1 && "missing or non-unique yield op"); | ||
// Using the new tosa.yield operand types, infer the new subtypes. | ||
llvm::SmallVector<ValueKnowledge> yieldTypeInfo; | ||
for (auto ty : argTypes) { | ||
|
@@ -158,59 +211,31 @@ void propagateShapesToTosaWhile(Operation &op) { | |
argTypes[i] = newType; | ||
} | ||
|
||
// The types inferred in the block assume the operand types specified for | ||
// this iteration. We need to restore the original types to ensure that | ||
// future iterations only use the already specified types, not possible | ||
// types from previous iterations. | ||
for (auto &block : bodyRegion) { | ||
for (auto arg : block.getArguments()) | ||
arg.setType(originalTypeMap[arg]); | ||
for (auto &op : block) | ||
for (auto result : op.getResults()) | ||
result.setType(originalTypeMap[result]); | ||
} | ||
// Revert all changes made during the speculative part of the algorithm. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this speculative? Is it too aggressive in assuming equality but lacking a way to relax it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe 'speculative' is the wrong word. Happy to change if you have a better term. The inference step for while loops, refines the block argument types and propagates those types through the body block multiple times until a fixed-point is reached. This is to infer consistent types for the After each propagation, the types of each operation in the body are reset to their previous values to avoid leaking possibly overly refined results between propagation steps. I think that, ideally, the TosaInferShapes pass would avoid modifying the IR until it has fully inferred the desired types, and only then update the IR. That would avoid these statefulness issues. I initially attempted to do that, but the |
||
localState.revert(); | ||
} | ||
|
||
// We now set the block arguments according to the most recent shape | ||
// inference results. This gives us the block arg types for the next | ||
// iteration. | ||
for (auto ®ion : op.getRegions()) { | ||
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) { | ||
region.front().getArgument(i).setType(argTypes[i]); | ||
state.setType(region.front().getArgument(i), argTypes[i]); | ||
} | ||
|
||
propagateShapesInRegion(region); | ||
propagateShapesInRegion(region, state); | ||
} | ||
} | ||
|
||
// Track the old type for each operand whose type was updated | ||
// during inference. This information is used to introduce casts | ||
// back to the type expected by the operand after inference. | ||
struct TypeRewriteInfo { | ||
OpOperand *operand; | ||
Type oldType; | ||
}; | ||
|
||
void propagateShapesInRegion(Region ®ion) { | ||
// Check whether this use case is replaceable. We define an op as | ||
// being replaceable if it is used by a TosaOp, or an op with a | ||
// type-inference related interface. | ||
// When a non-replaceable use is encountered, the value is wrapped in a | ||
// cast back to the original type after inference. | ||
auto isReplaceableUser = [](Operation *user) -> bool { | ||
return user->getDialect()->getNamespace() == | ||
TosaDialect::getDialectNamespace() || | ||
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user); | ||
}; | ||
|
||
llvm::SmallVector<TypeRewriteInfo> requiresUpdate; | ||
void propagateShapesInRegion(Region ®ion, TypeModificationState &state) { | ||
for (auto &block : region) { | ||
for (Operation &op : block) { | ||
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) | ||
if (!op.getDialect() || | ||
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is string compare too. |
||
continue; | ||
|
||
propagateShapesToTosaIf(op); | ||
propagateShapesToTosaWhile(op); | ||
propagateShapesToTosaIf(op, state); | ||
propagateShapesToTosaWhile(op, state); | ||
|
||
InferShapedTypeOpInterface shapeInterface = | ||
dyn_cast<InferShapedTypeOpInterface>(op); | ||
|
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region ®ion) { | |
continue; | ||
|
||
// Set new type | ||
result.setType(newKnowledge.getType()); | ||
|
||
// Collect all uses of the operation which require update. | ||
for (auto &user : result.getUses()) { | ||
if (!isReplaceableUser(user.getOwner())) | ||
requiresUpdate.push_back({&user, resultTy}); | ||
} | ||
state.setType(result, newKnowledge.getType()); | ||
} | ||
} | ||
} | ||
} | ||
|
||
// For each use whose type changed, cast the value with the new type back to | ||
// the old type. | ||
IRRewriter rewriter(region.getContext()); | ||
for (auto [operand, oldType] : requiresUpdate) { | ||
rewriter.setInsertionPoint(operand->getOwner()); | ||
|
||
auto oldValue = operand->get(); | ||
|
||
auto loc = oldValue.getLoc(); | ||
auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue); | ||
operand->set(castOp); | ||
} | ||
} | ||
|
||
/// Pass that performs shape propagation across TOSA operations. This includes | ||
|
@@ -285,7 +291,9 @@ struct TosaInferShapes | |
public: | ||
void runOnOperation() override { | ||
func::FuncOp func = getOperation(); | ||
propagateShapesInRegion(func.getBody()); | ||
TypeModificationState state; | ||
propagateShapesInRegion(func.getBody(), state); | ||
state.commit(); | ||
} | ||
}; | ||
} // namespace | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OOC what TOSA ops do not have the type inference interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just tried removing that condition. Based on the how the tests failed, at least
tosa.yield
. Everthing inmlir/Dialect/Tosa/IR/TosaUtilOps.td
lacks an inference interface as well.