Skip to content

Commit 2026501

Browse files
[MLIR] Make OneShotModuleBufferize use OpInterface (#110322)
**Description:** This PR replaces a part of `FuncOp` and `CallOp` with `FunctionOpInterface` and `CallOpInterface` in `OneShotModuleBufferize`. Also fix the error from an integration test in the a previous PR attempt. (#107295) The below fixes skip `CallOpInterface` so that the assertions are not triggered. https://github.com/llvm/llvm-project/blob/8d780007625108a7f34e40efb8604b858e04c60c/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp#L254-L259 https://github.com/llvm/llvm-project/blob/8d780007625108a7f34e40efb8604b858e04c60c/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp#L311-L315 **Related Discord Discussion:** [Link](https://discord.com/channels/636084430946959380/642426447167881246/1280556809911799900) --------- Co-authored-by: erick-xanadu <[email protected]>
1 parent 60b604a commit 2026501

File tree

11 files changed

+316
-281
lines changed

11 files changed

+316
-281
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/IR/Operation.h"
1313
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
1415
#include "mlir/Support/LLVM.h"
1516
#include "llvm/ADT/DenseMapInfoVariant.h"
1617
#include "llvm/ADT/SetVector.h"
@@ -260,9 +261,9 @@ struct BufferizationOptions {
260261
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
261262
/// Tensor -> MemRef type converter.
262263
/// Parameters: Value, memory space, func op, bufferization options
263-
using FunctionArgTypeConverterFn =
264-
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
265-
func::FuncOp, const BufferizationOptions &)>;
264+
using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
265+
TensorType, Attribute memorySpace, FunctionOpInterface,
266+
const BufferizationOptions &)>;
266267
/// Tensor -> MemRef type converter.
267268
/// Parameters: Value, memory space, bufferization options
268269
using UnknownTypeConverterFn = std::function<BaseMemRefType(

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
5050

5151
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
5252
/// indices.
53-
DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
53+
DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
5454

5555
/// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
56-
DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
56+
DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
5757

5858
/// A set of all read BlockArguments of FuncOps.
59-
DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
59+
DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
6060

6161
/// A set of all written-to BlockArguments of FuncOps.
62-
DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
62+
DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
6363

6464
/// Keep track of which FuncOps are fully analyzed or currently being
6565
/// analyzed.
66-
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
66+
DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
6767

6868
/// This function is called right before analyzing the given FuncOp. It
6969
/// initializes the data structures for the FuncOp in this state object.
70-
void startFunctionAnalysis(FuncOp funcOp);
70+
void startFunctionAnalysis(FunctionOpInterface funcOp);
7171
};
7272

7373
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/TypeUtilities.h"
1919
#include "mlir/IR/Value.h"
2020
#include "mlir/Interfaces/ControlFlowInterfaces.h"
21+
#include "mlir/Interfaces/FunctionInterfaces.h"
2122
#include "llvm/ADT/ScopeExit.h"
2223
#include "llvm/Support/Debug.h"
2324

@@ -314,7 +315,7 @@ namespace {
314315
/// Default function arg type converter: Use a fully dynamic layout map.
315316
BaseMemRefType
316317
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
317-
func::FuncOp funcOp,
318+
FunctionOpInterface funcOp,
318319
const BufferizationOptions &options) {
319320
return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
320321
}
@@ -361,7 +362,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
361362
void BufferizationOptions::setFunctionBoundaryTypeConversion(
362363
LayoutMapOption layoutMapOption) {
363364
functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
364-
func::FuncOp funcOp,
365+
FunctionOpInterface funcOp,
365366
const BufferizationOptions &options) {
366367
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
367368
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace mlir {
2222
namespace bufferization {
2323
namespace func_ext {
2424

25-
void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
25+
void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
2626
analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
2727
auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
2828
auto createdAliasingResults =

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
7575
using namespace mlir::bufferization::func_ext;
7676

7777
/// A mapping of FuncOps to their callers.
78-
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
78+
using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
7979

8080
/// Get or create FuncAnalysisState.
8181
static FuncAnalysisState &
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8888

8989
/// Return the unique ReturnOp that terminates `funcOp`.
9090
/// Return nullptr if there is no such unique ReturnOp.
91-
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
92-
func::ReturnOp returnOp;
93-
for (Block &b : funcOp.getBody()) {
94-
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
91+
static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
92+
Operation *returnOp = nullptr;
93+
for (Block &b : funcOp.getFunctionBody()) {
94+
auto candidateOp = b.getTerminator();
95+
if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
9596
if (returnOp)
9697
return nullptr;
9798
returnOp = candidateOp;
@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
126127
/// Store function BlockArguments that are equivalent to/aliasing a returned
127128
/// value in FuncAnalysisState.
128129
static LogicalResult
129-
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
130+
aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
131+
OneShotAnalysisState &state,
130132
FuncAnalysisState &funcState) {
131-
if (funcOp.getBody().empty()) {
133+
if (funcOp.getFunctionBody().empty()) {
132134
// No function body available. Conservatively assume that every tensor
133135
// return value may alias with any tensor bbArg.
134-
FunctionType type = funcOp.getFunctionType();
135-
for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
136+
for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
136137
if (!isa<TensorType>(inputIt.value()))
137138
continue;
138-
for (const auto &resultIt : llvm::enumerate(type.getResults())) {
139+
for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
139140
if (!isa<TensorType>(resultIt.value()))
140141
continue;
141142
int64_t returnIdx = resultIt.index();
@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
147148
}
148149

149150
// Support only single return-terminated block in the function.
150-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
151+
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
151152
assert(returnOp && "expected func with single return op");
152153

153154
for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
168169
return success();
169170
}
170171

171-
static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
172-
bool isWritten) {
172+
static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
173+
bool isRead, bool isWritten) {
173174
OpBuilder b(funcOp.getContext());
174175
Attribute accessType;
175176
if (isRead && isWritten) {
@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
189190
/// function with unknown ops, we conservatively assume that such ops bufferize
190191
/// to a read + write.
191192
static LogicalResult
192-
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
193+
funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
194+
OneShotAnalysisState &state,
193195
FuncAnalysisState &funcState) {
194-
for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
195-
++idx) {
196+
for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
196197
// Skip non-tensor arguments.
197-
if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
198+
if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
198199
continue;
199200
bool isRead;
200201
bool isWritten;
@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
204205
StringRef str = accessAttr.getValue();
205206
isRead = str == "read" || str == "read-write";
206207
isWritten = str == "write" || str == "read-write";
207-
} else if (funcOp.getBody().empty()) {
208+
} else if (funcOp.getFunctionBody().empty()) {
208209
// If the function has no body, conservatively assume that all args are
209210
// read + written.
210211
isRead = true;
@@ -230,33 +231,33 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
230231

231232
/// Remove bufferization attributes on FuncOp arguments.
232233
static void removeBufferizationAttributes(BlockArgument bbArg) {
233-
auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
234+
auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
234235
funcOp.removeArgAttr(bbArg.getArgNumber(),
235236
BufferizationDialect::kBufferLayoutAttrName);
236237
funcOp.removeArgAttr(bbArg.getArgNumber(),
237238
BufferizationDialect::kWritableAttrName);
238239
}
239240

240-
/// Return the func::FuncOp called by `callOp`.
241-
static func::FuncOp getCalledFunction(func::CallOp callOp) {
241+
static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
242242
SymbolRefAttr sym =
243243
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
244244
if (!sym)
245245
return nullptr;
246-
return dyn_cast_or_null<func::FuncOp>(
246+
return dyn_cast_or_null<FunctionOpInterface>(
247247
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
248248
}
249249

250250
/// Gather equivalence info of CallOps.
251251
/// Note: This only adds new equivalence info if the called function was already
252252
/// analyzed.
253253
// TODO: This does not handle cyclic function call graphs etc.
254-
static void equivalenceAnalysis(func::FuncOp funcOp,
254+
static void equivalenceAnalysis(FunctionOpInterface funcOp,
255255
OneShotAnalysisState &state,
256256
FuncAnalysisState &funcState) {
257-
funcOp->walk([&](func::CallOp callOp) {
258-
func::FuncOp calledFunction = getCalledFunction(callOp);
259-
assert(calledFunction && "could not retrieved called func::FuncOp");
257+
funcOp->walk([&](CallOpInterface callOp) {
258+
FunctionOpInterface calledFunction = getCalledFunction(callOp);
259+
if (!calledFunction)
260+
return WalkResult::skip();
260261

261262
// No equivalence info available for the called function.
262263
if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -267,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
267268
int64_t bbargIdx = it.second;
268269
if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
269270
continue;
270-
Value returnVal = callOp.getResult(returnIdx);
271+
Value returnVal = callOp->getResult(returnIdx);
271272
Value argVal = callOp->getOperand(bbargIdx);
272273
state.unionEquivalenceClasses(returnVal, argVal);
273274
}
@@ -277,11 +278,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
277278
}
278279

279280
/// Return "true" if the given function signature has tensor semantics.
280-
static bool hasTensorSignature(func::FuncOp funcOp) {
281-
return llvm::any_of(funcOp.getFunctionType().getInputs(),
282-
llvm::IsaPred<TensorType>) ||
283-
llvm::any_of(funcOp.getFunctionType().getResults(),
284-
llvm::IsaPred<TensorType>);
281+
static bool hasTensorSignature(FunctionOpInterface funcOp) {
282+
return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
283+
llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
285284
}
286285

287286
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -291,16 +290,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
291290
/// retrieve the called FuncOp from any func::CallOp.
292291
static LogicalResult
293292
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
294-
SmallVectorImpl<func::FuncOp> &orderedFuncOps,
293+
SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
295294
FuncCallerMap &callerMap) {
296295
// For each FuncOp, the set of functions called by it (i.e. the union of
297296
// symbols of all nested func::CallOp).
298-
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
297+
DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
299298
// For each FuncOp, the number of func::CallOp it contains.
300-
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
301-
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
302-
if (!funcOp.getBody().empty()) {
303-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
299+
DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
300+
WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
301+
if (!funcOp.getFunctionBody().empty()) {
302+
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
304303
if (!returnOp)
305304
return funcOp->emitError()
306305
<< "cannot bufferize a FuncOp with tensors and "
@@ -309,9 +308,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
309308

310309
// Collect function calls and populate the caller map.
311310
numberCallOpsContainedInFuncOp[funcOp] = 0;
312-
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
313-
func::FuncOp calledFunction = getCalledFunction(callOp);
314-
assert(calledFunction && "could not retrieved called func::FuncOp");
311+
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
312+
FunctionOpInterface calledFunction = getCalledFunction(callOp);
313+
if (!calledFunction)
314+
return WalkResult::skip();
315315
// If the called function does not have any tensors in its signature, then
316316
// it is not necessary to bufferize the callee before the caller.
317317
if (!hasTensorSignature(calledFunction))
@@ -349,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
349349
/// most generic layout map as function return types. After bufferizing the
350350
/// entire function body, a more concise memref type can potentially be used for
351351
/// the return type of the function.
352-
static void foldMemRefCasts(func::FuncOp funcOp) {
353-
if (funcOp.getBody().empty())
352+
static void foldMemRefCasts(FunctionOpInterface funcOp) {
353+
if (funcOp.getFunctionBody().empty())
354354
return;
355355

356-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
356+
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
357357
SmallVector<Type> resultTypes;
358358

359359
for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -365,8 +365,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
365365
}
366366
}
367367

368-
auto newFuncType = FunctionType::get(
369-
funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
368+
auto newFuncType = FunctionType::get(funcOp.getContext(),
369+
funcOp.getArgumentTypes(), resultTypes);
370370
funcOp.setType(newFuncType);
371371
}
372372

@@ -379,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
379379
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
380380

381381
// A list of functions in the order in which they are analyzed + bufferized.
382-
SmallVector<func::FuncOp> orderedFuncOps;
382+
SmallVector<FunctionOpInterface> orderedFuncOps;
383383

384384
// A mapping of FuncOps to their callers.
385385
FuncCallerMap callerMap;
@@ -388,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
388388
return failure();
389389

390390
// Analyze ops.
391-
for (func::FuncOp funcOp : orderedFuncOps) {
391+
for (FunctionOpInterface funcOp : orderedFuncOps) {
392392
if (!state.getOptions().isOpAllowed(funcOp))
393393
continue;
394394

@@ -416,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
416416

417417
void mlir::bufferization::removeBufferizationAttributesInModule(
418418
ModuleOp moduleOp) {
419-
moduleOp.walk([&](func::FuncOp op) {
419+
moduleOp.walk([&](FunctionOpInterface op) {
420420
for (BlockArgument bbArg : op.getArguments())
421421
removeBufferizationAttributes(bbArg);
422422
});
@@ -430,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
430430
IRRewriter rewriter(moduleOp.getContext());
431431

432432
// A list of functions in the order in which they are analyzed + bufferized.
433-
SmallVector<func::FuncOp> orderedFuncOps;
433+
SmallVector<FunctionOpInterface> orderedFuncOps;
434434

435435
// A mapping of FuncOps to their callers.
436436
FuncCallerMap callerMap;
@@ -439,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
439439
return failure();
440440

441441
// Bufferize functions.
442-
for (func::FuncOp funcOp : orderedFuncOps) {
442+
for (FunctionOpInterface funcOp : orderedFuncOps) {
443443
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
444444
// would be invalidated.
445445

446-
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
446+
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
447447
// This function was not analyzed and RaW conflicts were not resolved.
448448
// Buffer copies must be inserted before every write.
449449
OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
463463
// Bufferize all other ops.
464464
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
465465
// Functions were already bufferized.
466-
if (isa<func::FuncOp>(&op))
466+
if (isa<FunctionOpInterface>(&op))
467467
continue;
468468
if (failed(bufferizeOp(&op, options, statistics)))
469469
return failure();
@@ -490,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
490490
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
491491
// not be analyzed. Ops in these FuncOps will not be analyzed as well.
492492
OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
493-
auto func = dyn_cast<func::FuncOp>(op);
493+
auto func = dyn_cast<FunctionOpInterface>(op);
494494
if (!func)
495-
func = op->getParentOfType<func::FuncOp>();
495+
func = op->getParentOfType<FunctionOpInterface>();
496496
if (func)
497497
return llvm::is_contained(options.noAnalysisFuncFilter,
498-
func.getSymName());
498+
func.getName());
499499
return false;
500500
};
501501
OneShotBufferizationOptions updatedOptions(options);

0 commit comments

Comments
 (0)