@@ -119,9 +119,21 @@ static void updateReturnOps(func::FuncOp func,
119
119
120
120
// Updates all CallOps in the scope of the given ModuleOp by allocating
121
121
// temporary buffers for newly introduced out params.
122
- static LogicalResult updateCalls (ModuleOp module ) {
122
+ static LogicalResult
123
+ updateCalls (ModuleOp module ,
124
+ const bufferization::BufferResultsToOutParamsOptions &options) {
123
125
bool didFail = false ;
126
+ SymbolTable symtab (module );
124
127
module .walk ([&](func::CallOp op) {
128
+ auto callee = symtab.lookup <func::FuncOp>(op.getCallee ());
129
+ if (!callee) {
130
+ op.emitError () << " cannot find callee '" << op.getCallee () << " ' in "
131
+ << " symbol table" ;
132
+ didFail = true ;
133
+ return ;
134
+ }
135
+ if (!options.filterFn (&callee))
136
+ return ;
125
137
SmallVector<Value, 6 > replaceWithNewCallResults;
126
138
SmallVector<Value, 6 > replaceWithOutParams;
127
139
for (OpResult result : op.getResults ()) {
@@ -169,17 +181,20 @@ static LogicalResult updateCalls(ModuleOp module) {
169
181
return failure (didFail);
170
182
}
171
183
172
- LogicalResult
173
- mlir::bufferization::promoteBufferResultsToOutParams (ModuleOp module ) {
184
+ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
185
+ ModuleOp module ,
186
+ const bufferization::BufferResultsToOutParamsOptions &options) {
174
187
for (auto func : module .getOps <func::FuncOp>()) {
188
+ if (!options.filterFn (&func))
189
+ continue ;
175
190
SmallVector<BlockArgument, 6 > appendedEntryArgs;
176
191
if (failed (updateFuncOp (func, appendedEntryArgs)))
177
192
return failure ();
178
193
if (func.isExternal ())
179
194
continue ;
180
195
updateReturnOps (func, appendedEntryArgs);
181
196
}
182
- if (failed (updateCalls (module )))
197
+ if (failed (updateCalls (module , options )))
183
198
return failure ();
184
199
return success ();
185
200
}
@@ -188,14 +203,22 @@ namespace {
188
203
struct BufferResultsToOutParamsPass
189
204
: bufferization::impl::BufferResultsToOutParamsBase<
190
205
BufferResultsToOutParamsPass> {
206
+ explicit BufferResultsToOutParamsPass (
207
+ const bufferization::BufferResultsToOutParamsOptions &options)
208
+ : options(options) {}
209
+
191
210
void runOnOperation () override {
192
- if (failed (bufferization::promoteBufferResultsToOutParams (getOperation ())))
211
+ if (failed (bufferization::promoteBufferResultsToOutParams (getOperation (),
212
+ options)))
193
213
return signalPassFailure ();
194
214
}
215
+
216
+ private:
217
+ bufferization::BufferResultsToOutParamsOptions options;
195
218
};
196
219
} // namespace
197
220
198
- std::unique_ptr<Pass>
199
- mlir:: bufferization::createBufferResultsToOutParamsPass ( ) {
200
- return std::make_unique<BufferResultsToOutParamsPass>();
221
+ std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass (
222
+ const bufferization::BufferResultsToOutParamsOptions &options ) {
223
+ return std::make_unique<BufferResultsToOutParamsPass>(options );
201
224
}
0 commit comments