@@ -21,7 +21,7 @@ namespace bufferization {
21
21
} // namespace mlir
22
22
23
23
using namespace mlir ;
24
- using MemCpyFn = bufferization::BufferResultsToOutParamsOptions ::MemCpyFn;
24
+ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts ::MemCpyFn;
25
25
26
26
// / Return `true` if the given MemRef type has a fully dynamic layout.
27
27
static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -45,9 +45,12 @@ static bool hasStaticIdentityLayout(MemRefType type) {
45
45
// Updates the func op and entry block.
46
46
//
47
47
// Any args appended to the entry block are added to `appendedEntryArgs`.
48
+ // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
49
+ // to each newly created function argument.
48
50
static LogicalResult
49
51
updateFuncOp (func::FuncOp func,
50
- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
52
+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
53
+ bool addResultAttribute) {
51
54
auto functionType = func.getFunctionType ();
52
55
53
56
// Collect information about the results will become appended arguments.
@@ -80,6 +83,10 @@ updateFuncOp(func::FuncOp func,
80
83
for (int i = 0 , e = erasedResultTypes.size (); i < e; ++i, ++erasedIndicesIt) {
81
84
func.setArgAttrs (functionType.getNumInputs () + i,
82
85
func.getResultAttrs (*erasedIndicesIt));
86
+ if (addResultAttribute)
87
+ func.setArgAttr (functionType.getNumInputs () + i,
88
+ StringAttr::get (func.getContext (), " bufferize.result" ),
89
+ UnitAttr::get (func.getContext ()));
83
90
}
84
91
85
92
// Erase the results.
@@ -127,7 +134,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
127
134
// temporary buffers for newly introduced out params.
128
135
static LogicalResult
129
136
updateCalls (ModuleOp module ,
130
- const bufferization::BufferResultsToOutParamsOptions &options) {
137
+ const bufferization::BufferResultsToOutParamsOpts &options) {
131
138
bool didFail = false ;
132
139
SymbolTable symtab (module );
133
140
module .walk ([&](func::CallOp op) {
@@ -189,12 +196,13 @@ updateCalls(ModuleOp module,
189
196
190
197
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
191
198
ModuleOp module ,
192
- const bufferization::BufferResultsToOutParamsOptions &options) {
199
+ const bufferization::BufferResultsToOutParamsOpts &options) {
193
200
for (auto func : module .getOps <func::FuncOp>()) {
194
201
if (!options.filterFn (&func))
195
202
continue ;
196
203
SmallVector<BlockArgument, 6 > appendedEntryArgs;
197
- if (failed (updateFuncOp (func, appendedEntryArgs)))
204
+ if (failed (
205
+ updateFuncOp (func, appendedEntryArgs, options.addResultAttribute )))
198
206
return failure ();
199
207
if (func.isExternal ())
200
208
continue ;
@@ -218,21 +226,25 @@ struct BufferResultsToOutParamsPass
218
226
: bufferization::impl::BufferResultsToOutParamsBase<
219
227
BufferResultsToOutParamsPass> {
220
228
explicit BufferResultsToOutParamsPass (
221
- const bufferization::BufferResultsToOutParamsOptions &options)
229
+ const bufferization::BufferResultsToOutParamsOpts &options)
222
230
: options(options) {}
223
231
224
232
void runOnOperation () override {
233
+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
234
+ if (addResultAttribute)
235
+ options.addResultAttribute = true ;
236
+
225
237
if (failed (bufferization::promoteBufferResultsToOutParams (getOperation (),
226
238
options)))
227
239
return signalPassFailure ();
228
240
}
229
241
230
242
private:
231
- bufferization::BufferResultsToOutParamsOptions options;
243
+ bufferization::BufferResultsToOutParamsOpts options;
232
244
};
233
245
} // namespace
234
246
235
247
std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass (
236
- const bufferization::BufferResultsToOutParamsOptions &options) {
248
+ const bufferization::BufferResultsToOutParamsOpts &options) {
237
249
return std::make_unique<BufferResultsToOutParamsPass>(options);
238
250
}
0 commit comments