@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
86
86
return state.addExtension <FuncAnalysisState>();
87
87
}
88
88
89
- // / Return the unique ReturnOp that terminates `funcOp`.
90
- // / Return nullptr if there is no such unique ReturnOp.
91
- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92
- func::ReturnOp returnOp;
93
- for (Block &b : funcOp.getBody ()) {
94
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
95
- if (returnOp)
96
- return nullptr ;
97
- returnOp = candidateOp;
98
- }
99
- }
100
- return returnOp;
89
+ // / Return all top-level func.return ops in the given function.
90
+ static SmallVector<func::ReturnOp> getReturnOps (FuncOp funcOp) {
91
+ SmallVector<func::ReturnOp> result;
92
+ for (Block &b : funcOp.getBody ())
93
+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator ()))
94
+ result.push_back (returnOp);
95
+ return result;
101
96
}
102
97
103
98
namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
146
141
return success ();
147
142
}
148
143
149
- // Support only single return-terminated block in the function.
150
- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
151
- assert (returnOp && " expected func with single return op" );
152
-
153
- for (OpOperand &returnVal : returnOp->getOpOperands ())
154
- if (isa<RankedTensorType>(returnVal.get ().getType ()))
155
- for (BlockArgument bbArg : funcOp.getArguments ())
156
- if (isa<RankedTensorType>(bbArg.getType ())) {
157
- int64_t returnIdx = returnVal.getOperandNumber ();
158
- int64_t bbArgIdx = bbArg.getArgNumber ();
159
- if (state.areEquivalentBufferizedValues (returnVal.get (), bbArg)) {
160
- funcState.equivalentFuncArgs [funcOp][returnIdx] = bbArgIdx;
161
- if (state.getOptions ().testAnalysisOnly )
162
- annotateEquivalentReturnBbArg (returnVal, bbArg);
144
+ // Find all func.return ops.
145
+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
146
+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
147
+
148
+ // Build alias sets. Merge all aliases from all func.return ops.
149
+ for (BlockArgument bbArg : funcOp.getArguments ()) {
150
+ if (isa<RankedTensorType>(bbArg.getType ())) {
151
+ int64_t bbArgIdx = bbArg.getArgNumber ();
152
+ // Store aliases in a set, so that we don't add the same alias twice.
153
+ SetVector<int64_t > aliases;
154
+ for (func::ReturnOp returnOp : returnOps) {
155
+ for (OpOperand &returnVal : returnOp->getOpOperands ()) {
156
+ if (isa<RankedTensorType>(returnVal.get ().getType ())) {
157
+ int64_t returnIdx = returnVal.getOperandNumber ();
158
+ if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
159
+ aliases.insert (returnIdx);
163
160
}
164
- if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
165
- funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (returnIdx);
166
161
}
162
+ }
163
+ for (int64_t alias : aliases)
164
+ funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (alias);
165
+ }
166
+ }
167
+
168
+ // Build equivalence sets.
169
+ // Helper function that finds an equivalent block argument index for the
170
+ // given OpOperand. Return std::nullopt if no equivalent block argument could
171
+ // be found.
172
+ auto findEquivalentBlockArgIdx =
173
+ [&](OpOperand &opOperand) -> std::optional<int64_t > {
174
+ Value v = opOperand.get ();
175
+ if (!isa<TensorType>(v.getType ()))
176
+ return std::nullopt;
177
+ for (BlockArgument bbArg : funcOp.getArguments ()) {
178
+ if (isa<RankedTensorType>(bbArg.getType ())) {
179
+ if (state.areEquivalentBufferizedValues (v, bbArg)) {
180
+ if (state.getOptions ().testAnalysisOnly )
181
+ annotateEquivalentReturnBbArg (opOperand, bbArg);
182
+ return bbArg.getArgNumber ();
183
+ }
184
+ }
185
+ }
186
+ return std::nullopt;
187
+ };
188
+
189
+ int64_t numResults = returnOps.front ()->getNumOperands ();
190
+ for (int64_t i = 0 ; i < numResults; ++i) {
191
+ // Find the equivalent block argument index for the i-th operand of the
192
+ // first func.return op.
193
+ std::optional<int64_t > maybeEquiv =
194
+ findEquivalentBlockArgIdx (returnOps.front ()->getOpOperand (i));
195
+ if (!maybeEquiv.has_value ())
196
+ continue ;
197
+ int64_t bbArgIdx = *maybeEquiv;
198
+ bool allEquiv = true ;
199
+
200
+ // Check if all other func.return ops have the same equivalent block
201
+ // argument for the i-th operand. In contrast to aliasing information,
202
+ // which is just "merged", equivalence information must match across all
203
+ // func.return ops.
204
+ for (func::ReturnOp returnOp : ArrayRef (returnOps).drop_front ()) {
205
+ std::optional<int64_t > maybeEquiv =
206
+ findEquivalentBlockArgIdx (returnOp->getOpOperand (i));
207
+ if (maybeEquiv != bbArgIdx) {
208
+ allEquiv = false ;
209
+ break ;
210
+ }
211
+ }
212
+
213
+ // All func.return ops have the same equivalent block argument for the i-th
214
+ // operand.
215
+ if (allEquiv)
216
+ funcState.equivalentFuncArgs [funcOp][i] = bbArgIdx;
217
+ }
167
218
168
219
return success ();
169
220
}
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
299
350
// For each FuncOp, the number of func::CallOp it contains.
300
351
DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
301
352
WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
302
- if (!funcOp.getBody ().empty ()) {
303
- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
304
- if (!returnOp)
305
- return funcOp->emitError ()
306
- << " cannot bufferize a FuncOp with tensors and "
307
- " without a unique ReturnOp" ;
308
- }
309
-
310
353
// Collect function calls and populate the caller map.
311
354
numberCallOpsContainedInFuncOp[funcOp] = 0 ;
312
355
return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
342
385
return success ();
343
386
}
344
387
388
+ // / Helper function that extracts the source from a memref.cast. If the given
389
+ // / value is not a memref.cast result, simply returns the given value.
390
+ static Value unpackCast (Value v) {
391
+ auto castOp = v.getDefiningOp <memref::CastOp>();
392
+ if (!castOp)
393
+ return v;
394
+ return castOp.getSource ();
395
+ }
396
+
397
+ // / Helper function that returns the return types (skipping casts) of the given
398
+ // / func.return ops. This function returns as many types as the return ops have
399
+ // / operands. If the i-th operand is not the same for all func.return ops, then
400
+ // / the i-th returned type is an "empty" type.
401
+ static SmallVector<Type> getReturnTypes (SmallVector<func::ReturnOp> returnOps) {
402
+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
403
+ int numOperands = returnOps.front ()->getNumOperands ();
404
+
405
+ // Helper function that unpacks memref.cast ops and returns the type.
406
+ auto getSourceType = [&](Value v) { return unpackCast (v).getType (); };
407
+
408
+ SmallVector<Type> result;
409
+ for (int i = 0 ; i < numOperands; ++i) {
410
+ // Get the type of the i-th operand of the first func.return ops.
411
+ Type t = getSourceType (returnOps.front ()->getOperand (i));
412
+
413
+ // Check if all other func.return ops have a matching operand type.
414
+ for (int j = 1 ; j < static_cast <int >(returnOps.size ()); ++j)
415
+ if (getSourceType (returnOps[j]->getOperand (i)) != t)
416
+ t = Type ();
417
+
418
+ result.push_back (t);
419
+ }
420
+
421
+ return result;
422
+ }
423
+
345
424
// / Fold return values that are memref casts and update function return types.
346
425
// /
347
426
// / During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
350
429
// / entire function body, a more concise memref type can potentially be used for
351
430
// / the return type of the function.
352
431
static void foldMemRefCasts (func::FuncOp funcOp) {
432
+ // There is nothing to do for bodiless ops.
353
433
if (funcOp.getBody ().empty ())
354
434
return ;
355
435
356
- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
357
- SmallVector<Type> resultTypes;
436
+ // Compute the common result types of all return ops.
437
+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
438
+ SmallVector<Type> resultTypes = getReturnTypes (returnOps);
358
439
359
- for (OpOperand &operand : returnOp->getOpOperands ()) {
360
- if (auto castOp = operand.get ().getDefiningOp <memref::CastOp>()) {
361
- operand.set (castOp.getSource ());
362
- resultTypes.push_back (castOp.getSource ().getType ());
363
- } else {
364
- resultTypes.push_back (operand.get ().getType ());
440
+ // Remove direct casts.
441
+ for (func::ReturnOp returnOp : returnOps) {
442
+ for (OpOperand &operand : returnOp->getOpOperands ()) {
443
+ // Bail if no common result type was found.
444
+ if (resultTypes[operand.getOperandNumber ()]) {
445
+ operand.set (unpackCast (operand.get ()));
446
+ }
365
447
}
366
448
}
367
449
450
+ // Fill in the missing result types that were not the same among all
451
+ // func.return ops.
452
+ for (int i = 0 ; i < static_cast <int >(resultTypes.size ()); ++i) {
453
+ if (resultTypes[i])
454
+ continue ;
455
+ resultTypes[i] = funcOp.getFunctionType ().getResult (i);
456
+ }
457
+
458
+ // Update the function type.
368
459
auto newFuncType = FunctionType::get (
369
460
funcOp.getContext (), funcOp.getFunctionType ().getInputs (), resultTypes);
370
461
funcOp.setType (newFuncType);
0 commit comments