Skip to content

[mlir][bufferization] Add support for non-unique func.return #114017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
class DialectRegistry;
Expand All @@ -21,6 +22,9 @@ class FuncOp;
} // namespace func

namespace bufferization {
/// Helper function that returns all func.return ops in the given function.
SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);

namespace func_ext {
/// The state of analysis of a FuncOp.
enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
#include <optional>

namespace mlir {
/// Return all func.return ops in the given function.
/// Return all func.return ops in the given function.
SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
  SmallVector<func::ReturnOp> returnOps;
  for (Block &block : funcOp.getBody()) {
    // Only blocks terminated by a func.return contribute to the result.
    if (auto ret = dyn_cast<func::ReturnOp>(block.getTerminator()))
      returnOps.push_back(ret);
  }
  return returnOps;
}

namespace bufferization {
namespace func_ext {

Expand All @@ -41,20 +50,6 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp found;
  for (Block &block : funcOp.getBody()) {
    auto candidate = dyn_cast<func::ReturnOp>(block.getTerminator());
    if (!candidate)
      continue;
    // A second return op means the return op is not unique.
    if (found)
      return nullptr;
    found = candidate;
  }
  return found;
}

/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
Expand Down Expand Up @@ -391,15 +386,6 @@ struct FuncOpInterface
getBufferType(op, value, options, invocationStack);
}

/// Verify that the analyzed function is in a form that bufferization
/// currently supports.
LogicalResult verifyAnalysis(Operation *op,
                             const AnalysisState &state) const {
  auto funcOp = cast<func::FuncOp>(op);
  // External functions have no body, so there is nothing to check.
  if (funcOp.isExternal())
    return success();
  // TODO: func.func with multiple returns are not supported.
  if (!getAssumedUniqueReturnOp(funcOp))
    return op->emitOpError("op without unique func.return is not supported");
  return success();
}

/// Rewrite function bbArgs and return values into buffer form. This function
/// bufferizes the function signature and the ReturnOp. When the entire
/// function body has been bufferized, function return types can be switched
Expand Down Expand Up @@ -446,41 +432,38 @@ struct FuncOpInterface
return success();
}

// TODO: Support functions with multiple returns.
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
assert(returnOp->getNumOperands() == retTypes.size() &&
"incorrect number of return values");
Location loc = returnOp.getLoc();

// 1. Bufferize every block.
for (Block &block : funcOp.getBody())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
options)))
return failure();

// 2. Bufferize all operands of the return op.
SmallVector<Value> returnValues;
for (auto [returnVal, bufferizedType] :
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);

// If not a tensor type just forward it.
if (!tensorType) {
returnValues.push_back(returnVal);
continue;
// 2. Bufferize the operands of all func.return ops.
for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
assert(returnOp->getNumOperands() == retTypes.size() &&
"incorrect number of return values");
SmallVector<Value> returnValues;
for (auto [returnVal, bufferizedType] :
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);

// If not a tensor type just forward it.
if (!tensorType) {
returnValues.push_back(returnVal);
continue;
}

// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
returnOp.getLoc(), bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}

// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc, bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
returnOp.getOperandsMutable().assign(returnValues);
}

returnOp.getOperandsMutable().assign(returnValues);

// 3. Set the new function type.
funcOp.setType(newFuncType);
return success();
Expand Down
174 changes: 128 additions & 46 deletions mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
return state.addExtension<FuncAnalysisState>();
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
  func::ReturnOp uniqueReturn;
  for (Block &block : funcOp.getBody()) {
    Operation *terminator = block.getTerminator();
    auto ret = dyn_cast<func::ReturnOp>(terminator);
    if (!ret)
      continue;
    // Finding a second return op means there is no unique one.
    if (uniqueReturn)
      return nullptr;
    uniqueReturn = ret;
  }
  return uniqueReturn;
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
Expand Down Expand Up @@ -146,24 +132,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}

// Support only single return-terminated block in the function.
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");

for (OpOperand &returnVal : returnOp->getOpOperands())
if (isa<RankedTensorType>(returnVal.get().getType()))
for (BlockArgument bbArg : funcOp.getArguments())
if (isa<RankedTensorType>(bbArg.getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
int64_t bbArgIdx = bbArg.getArgNumber();
if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
if (state.getOptions().testAnalysisOnly)
annotateEquivalentReturnBbArg(returnVal, bbArg);
// Find all func.return ops.
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
assert(!returnOps.empty() && "expected at least one ReturnOp");

// Build alias sets. Merge all aliases from all func.return ops.
for (BlockArgument bbArg : funcOp.getArguments()) {
if (isa<RankedTensorType>(bbArg.getType())) {
int64_t bbArgIdx = bbArg.getArgNumber();
// Store aliases in a set, so that we don't add the same alias twice.
SetVector<int64_t> aliases;
for (func::ReturnOp returnOp : returnOps) {
for (OpOperand &returnVal : returnOp->getOpOperands()) {
if (isa<RankedTensorType>(returnVal.get().getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
aliases.insert(returnIdx);
}
if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
}
}
for (int64_t alias : aliases)
funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
}
}

// Build equivalence sets.
// Helper function that finds an equivalent block argument index for the
// given OpOperand. Return std::nullopt if no equivalent block argument could
// be found.
auto findEquivalentBlockArgIdx =
[&](OpOperand &opOperand) -> std::optional<int64_t> {
Value v = opOperand.get();
if (!isa<TensorType>(v.getType()))
return std::nullopt;
for (BlockArgument bbArg : funcOp.getArguments()) {
if (isa<RankedTensorType>(bbArg.getType())) {
if (state.areEquivalentBufferizedValues(v, bbArg)) {
if (state.getOptions().testAnalysisOnly)
annotateEquivalentReturnBbArg(opOperand, bbArg);
return bbArg.getArgNumber();
}
}
}
return std::nullopt;
};

int64_t numResults = returnOps.front()->getNumOperands();
for (int64_t i = 0; i < numResults; ++i) {
// Find the equivalent block argument index for the i-th operand of the
// first func.return op.
std::optional<int64_t> maybeEquiv =
findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
if (!maybeEquiv.has_value())
continue;
int64_t bbArgIdx = *maybeEquiv;
bool allEquiv = true;

// Check if all other func.return ops have the same equivalent block
// argument for the i-th operand. In contrast to aliasing information,
// which is just "merged", equivalence information must match across all
// func.return ops.
for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
std::optional<int64_t> maybeEquiv =
findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
if (maybeEquiv != bbArgIdx) {
allEquiv = false;
break;
}
}

// All func.return ops have the same equivalent block argument for the i-th
// operand.
if (allEquiv)
funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
}

return success();
}
Expand Down Expand Up @@ -302,14 +344,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
if (!funcOp.getBody().empty()) {
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
}

// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
Expand Down Expand Up @@ -351,6 +385,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
return success();
}

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  if (auto castOp = v.getDefiningOp<memref::CastOp>())
    return castOp.getSource();
  return v;
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the i-th operand is not the same for all func.return ops, then
/// the i-th returned type is an "empty" type.
///
/// Note: Takes an ArrayRef instead of a SmallVector by value to avoid copying
/// the container on every call (LLVM convention for read-only sequences).
static SmallVector<Type> getReturnTypes(ArrayRef<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  result.reserve(numOperands);
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // If any other func.return op has a different operand type, record an
    // "empty" type so the caller knows there is no common type.
    for (func::ReturnOp returnOp : returnOps.drop_front())
      if (getSourceType(returnOp->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
Expand All @@ -359,21 +429,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
// There is nothing to do for bodiless ops.
if (funcOp.getBody().empty())
return;

func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
SmallVector<Type> resultTypes;
// Compute the common result types of all return ops.
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
SmallVector<Type> resultTypes = getReturnTypes(returnOps);

for (OpOperand &operand : returnOp->getOpOperands()) {
if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
operand.set(castOp.getSource());
resultTypes.push_back(castOp.getSource().getType());
} else {
resultTypes.push_back(operand.get().getType());
// Remove direct casts.
for (func::ReturnOp returnOp : returnOps) {
for (OpOperand &operand : returnOp->getOpOperands()) {
// Bail if no common result type was found.
if (resultTypes[operand.getOperandNumber()]) {
operand.set(unpackCast(operand.get()));
}
}
}

// Fill in the missing result types that were not the same among all
// func.return ops.
for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
if (resultTypes[i])
continue;
resultTypes[i] = funcOp.getFunctionType().getResult(i);
}

// Update the function type.
auto newFuncType = FunctionType::get(
funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
funcOp.setType(newFuncType);
Expand Down
Loading
Loading