@@ -21,7 +21,7 @@ namespace bufferization {
21
21
} // namespace mlir
22
22
23
23
using namespace mlir ;
24
- using MemCpyFn = bufferization::BufferResultsToOutParamsOptions ::MemCpyFn;
24
+ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts ::MemCpyFn;
25
25
26
26
// / Return `true` if the given MemRef type has a fully dynamic layout.
27
27
static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
47
47
// Any args appended to the entry block are added to `appendedEntryArgs`.
48
48
static LogicalResult
49
49
updateFuncOp (func::FuncOp func,
50
- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
50
+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
51
+ bool addResultAttribute) {
51
52
auto functionType = func.getFunctionType ();
52
53
53
54
// Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
80
81
for (int i = 0 , e = erasedResultTypes.size (); i < e; ++i, ++erasedIndicesIt) {
81
82
func.setArgAttrs (functionType.getNumInputs () + i,
82
83
func.getResultAttrs (*erasedIndicesIt));
84
+ if (addResultAttribute)
85
+ func.setArgAttr (functionType.getNumInputs () + i,
86
+ StringAttr::get (func.getContext (), " bufferize.result" ),
87
+ UnitAttr::get (func.getContext ()));
83
88
}
84
89
85
90
// Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
127
132
// temporary buffers for newly introduced out params.
128
133
static LogicalResult
129
134
updateCalls (ModuleOp module ,
130
- const bufferization::BufferResultsToOutParamsOptions &options) {
135
+ const bufferization::BufferResultsToOutParamsOpts &options) {
131
136
bool didFail = false ;
132
137
SymbolTable symtab (module );
133
138
module .walk ([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
189
194
190
195
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
191
196
ModuleOp module ,
192
- const bufferization::BufferResultsToOutParamsOptions &options) {
197
+ const bufferization::BufferResultsToOutParamsOpts &options) {
193
198
for (auto func : module .getOps <func::FuncOp>()) {
194
199
if (!options.filterFn (&func))
195
200
continue ;
196
201
SmallVector<BlockArgument, 6 > appendedEntryArgs;
197
- if (failed (updateFuncOp (func, appendedEntryArgs)))
202
+ if (failed (
203
+ updateFuncOp (func, appendedEntryArgs, options.addResultAttribute )))
198
204
return failure ();
199
205
if (func.isExternal ())
200
206
continue ;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
218
224
: bufferization::impl::BufferResultsToOutParamsBase<
219
225
BufferResultsToOutParamsPass> {
220
226
explicit BufferResultsToOutParamsPass (
221
- const bufferization::BufferResultsToOutParamsOptions &options)
227
+ const bufferization::BufferResultsToOutParamsOpts &options)
222
228
: options(options) {}
223
229
224
230
void runOnOperation () override {
231
+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
232
+ if (addResultAttribute)
233
+ options.addResultAttribute = true ;
234
+
225
235
if (failed (bufferization::promoteBufferResultsToOutParams (getOperation (),
226
236
options)))
227
237
return signalPassFailure ();
228
238
}
229
239
230
240
private:
231
- bufferization::BufferResultsToOutParamsOptions options;
241
+ bufferization::BufferResultsToOutParamsOpts options;
232
242
};
233
243
} // namespace
234
244
235
245
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass (
236
- const bufferization::BufferResultsToOutParamsOptions &options) {
246
+ const bufferization::BufferResultsToOutParamsOpts &options) {
237
247
return std::make_unique<BufferResultsToOutParamsPass>(options);
238
248
}
0 commit comments