Skip to content

[mlir][bufferization] Add BufferizableOpInterface::hasTensorSemantics #75273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,12 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
SmallVector<Value> &invocationStack);

/// Return "true" if the given op has tensor semantics and should be bufferized.
/// If the op is bufferizable, the BufferizableOpInterface is queried.
/// Otherwise, an op has tensor semantics if it has tensor operands, tensor
/// op results and/or tensor block arguments.
bool hasTensorSemantics(Operation *op);

/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
Expand Down Expand Up @@ -694,6 +700,10 @@ AliasingOpOperandList unknownGetAliasingOpOperands(Value value);
/// This is the default implementation of getAliasingValues in case the owner
/// op does not implement the BufferizableOpInterface.
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);

/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics
bool defaultHasTensorSemantics(Operation *op);
} // namespace detail

} // namespace bufferization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,25 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return "true" if the this op has tensor semantics and should be
bufferized. By default, ops with tensor operands, tensor op results
and/or tensor block arguments have tensor semantics.

This interface methods can be implemented by ops that should be
bufferized but do not have tensor semantics according to the above
definition. E.g., this function can return "true" for symbols.
}],
/*retType=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::bufferization::detail
::defaultHasTensorSemantics($_op.getOperation());
}]
>,
StaticInterfaceMethod<
/*desc=*/[{
Return `true` if the op and this interface implementation supports
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
*options.defaultMemorySpace);
}

bool bufferization::hasTensorSemantics(Operation *op) {
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.hasTensorSemantics();
return detail::defaultHasTensorSemantics(op);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
Operation *op,
ValueRange values) {
Expand Down Expand Up @@ -989,3 +995,20 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}

bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
auto isaTensor = [](Type t) { return isa<TensorType>(t); };
bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
return any_of(r.getBlocks(), [&](Block &b) {
return any_of(b.getArguments(), [&](BlockArgument bbArg) {
return isaTensor(bbArg.getType());
});
});
});
if (hasTensorBlockArgument)
return true;

if (any_of(op->getResultTypes(), isaTensor))
return true;
return any_of(op->getOperandTypes(), isaTensor);
}
25 changes: 0 additions & 25 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,31 +350,6 @@ mlir::bufferization::createFinalizingBufferizePass() {
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return isa<TensorType>(t); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
bool hasTensorBlockArgument = any_of(op->getRegions(), [](Region &r) {
return any_of(r.getBlocks(), [](Block &b) {
return any_of(b.getArguments(), [](BlockArgument bbArg) {
return isaTensor(bbArg.getType());
});
});
});
if (hasTensorBlockArgument)
return true;

if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
return hasTensorArg || hasTensorResult;
}

bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
return hasTensorResult || hasTensorOperand;
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,28 @@ struct FuncOpInterface

static bool supportsUnstructuredControlFlow() { return true; }

bool hasTensorSemantics(Operation *op) const {
auto isaTensor = [](Type type) { return isa<TensorType>(type); };

// A function has tensor semantics if it has tensor arguments/results.
auto funcOp = cast<FuncOp>(op);
bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
if (hasTensorArg || hasTensorResult)
return true;

// It also has tensor semantics if it has tensor block arguments.
// TODO: Decouple bufferization of unstructured control flow from
// BufferizableOpInterface implementations. We should only care about
// region entry block arguments here (which are already covered by the
// argument types of the function).
for (Block &block : funcOp.getBody())
if (any_of(block.getArgumentTypes(), isaTensor))
return true;

return false;
}

AliasingOpOperandList
getAliasingOpOperands(Operation *op, Value value,
const AnalysisState &state) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1030,13 +1030,6 @@ OneShotAnalysisState::analyzeSingleOp(Operation *op,
return success();
}

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
return hasTensorResult || hasTensorOperand;
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
OneShotAnalysisState &state) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,15 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
foldMemRefCasts(funcOp);
}

// Bufferize all other ops.
for (Operation &op : moduleOp.getOps()) {
// Functions were already bufferized.
if (isa<func::FuncOp>(&op))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be on function like interface?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency reasons, func::FuncOp would be better here. Other function-like ops are not supported at the moment (they will fail to bufferize) and we exclusively use func::FuncOp, func::CallOp, func::ReturnOp in this file at the moment.

continue;
if (failed(bufferizeOp(&op, options, statistics)))
return failure();
}

// Post-pass cleanup of function argument attributes.
removeBufferizationAttributesInModule(moduleOp);

Expand Down