@@ -86,20 +86,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
86
86
return state.addExtension <FuncAnalysisState>();
87
87
}
88
88
89
- // / Return the unique ReturnOp that terminates `funcOp`.
90
- // / Return nullptr if there is no such unique ReturnOp.
91
- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92
- func::ReturnOp returnOp;
93
- for (Block &b : funcOp.getBody ()) {
94
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
95
- if (returnOp)
96
- return nullptr ;
97
- returnOp = candidateOp;
98
- }
99
- }
100
- return returnOp;
101
- }
102
-
103
89
namespace {
104
90
105
91
// / Annotate IR with the results of the analysis. For testing purposes only.
@@ -146,24 +132,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
146
132
return success ();
147
133
}
148
134
149
- // Support only single return-terminated block in the function.
150
- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
151
- assert (returnOp && " expected func with single return op" );
152
-
153
- for (OpOperand &returnVal : returnOp->getOpOperands ())
154
- if (isa<RankedTensorType>(returnVal.get ().getType ()))
155
- for (BlockArgument bbArg : funcOp.getArguments ())
156
- if (isa<RankedTensorType>(bbArg.getType ())) {
157
- int64_t returnIdx = returnVal.getOperandNumber ();
158
- int64_t bbArgIdx = bbArg.getArgNumber ();
159
- if (state.areEquivalentBufferizedValues (returnVal.get (), bbArg)) {
160
- funcState.equivalentFuncArgs [funcOp][returnIdx] = bbArgIdx;
161
- if (state.getOptions ().testAnalysisOnly )
162
- annotateEquivalentReturnBbArg (returnVal, bbArg);
135
+ // Find all func.return ops.
136
+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
137
+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
138
+
139
+ // Build alias sets. Merge all aliases from all func.return ops.
140
+ for (BlockArgument bbArg : funcOp.getArguments ()) {
141
+ if (isa<RankedTensorType>(bbArg.getType ())) {
142
+ int64_t bbArgIdx = bbArg.getArgNumber ();
143
+ // Store aliases in a set, so that we don't add the same alias twice.
144
+ SetVector<int64_t > aliases;
145
+ for (func::ReturnOp returnOp : returnOps) {
146
+ for (OpOperand &returnVal : returnOp->getOpOperands ()) {
147
+ if (isa<RankedTensorType>(returnVal.get ().getType ())) {
148
+ int64_t returnIdx = returnVal.getOperandNumber ();
149
+ if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
150
+ aliases.insert (returnIdx);
163
151
}
164
- if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
165
- funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (returnIdx);
166
152
}
153
+ }
154
+ for (int64_t alias : aliases)
155
+ funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (alias);
156
+ }
157
+ }
158
+
159
+ // Build equivalence sets.
160
+ // Helper function that finds an equivalent block argument index for the
161
+ // given OpOperand. Return std::nullopt if no equivalent block argument could
162
+ // be found.
163
+ auto findEquivalentBlockArgIdx =
164
+ [&](OpOperand &opOperand) -> std::optional<int64_t > {
165
+ Value v = opOperand.get ();
166
+ if (!isa<TensorType>(v.getType ()))
167
+ return std::nullopt;
168
+ for (BlockArgument bbArg : funcOp.getArguments ()) {
169
+ if (isa<RankedTensorType>(bbArg.getType ())) {
170
+ if (state.areEquivalentBufferizedValues (v, bbArg)) {
171
+ if (state.getOptions ().testAnalysisOnly )
172
+ annotateEquivalentReturnBbArg (opOperand, bbArg);
173
+ return bbArg.getArgNumber ();
174
+ }
175
+ }
176
+ }
177
+ return std::nullopt;
178
+ };
179
+
180
+ int64_t numResults = returnOps.front ()->getNumOperands ();
181
+ for (int64_t i = 0 ; i < numResults; ++i) {
182
+ // Find the equivalent block argument index for the i-th operand of the
183
+ // first func.return op.
184
+ std::optional<int64_t > maybeEquiv =
185
+ findEquivalentBlockArgIdx (returnOps.front ()->getOpOperand (i));
186
+ if (!maybeEquiv.has_value ())
187
+ continue ;
188
+ int64_t bbArgIdx = *maybeEquiv;
189
+ bool allEquiv = true ;
190
+
191
+ // Check if all other func.return ops have the same equivalent block
192
+ // argument for the i-th operand. In contrast to aliasing information,
193
+ // which is just "merged", equivalence information must match across all
194
+ // func.return ops.
195
+ for (func::ReturnOp returnOp : ArrayRef (returnOps).drop_front ()) {
196
+ std::optional<int64_t > maybeEquiv =
197
+ findEquivalentBlockArgIdx (returnOp->getOpOperand (i));
198
+ if (maybeEquiv != bbArgIdx) {
199
+ allEquiv = false ;
200
+ break ;
201
+ }
202
+ }
203
+
204
+ // All func.return ops have the same equivalent block argument for the i-th
205
+ // operand.
206
+ if (allEquiv)
207
+ funcState.equivalentFuncArgs [funcOp][i] = bbArgIdx;
208
+ }
167
209
168
210
return success ();
169
211
}
@@ -302,14 +344,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
302
344
// For each FuncOp, the number of func::CallOp it contains.
303
345
DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
304
346
WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
305
- if (!funcOp.getBody ().empty ()) {
306
- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
307
- if (!returnOp)
308
- return funcOp->emitError ()
309
- << " cannot bufferize a FuncOp with tensors and "
310
- " without a unique ReturnOp" ;
311
- }
312
-
313
347
// Collect function calls and populate the caller map.
314
348
numberCallOpsContainedInFuncOp[funcOp] = 0 ;
315
349
return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +385,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
351
385
return success ();
352
386
}
353
387
388
+ // / Helper function that extracts the source from a memref.cast. If the given
389
+ // / value is not a memref.cast result, simply returns the given value.
390
+ static Value unpackCast (Value v) {
391
+ auto castOp = v.getDefiningOp <memref::CastOp>();
392
+ if (!castOp)
393
+ return v;
394
+ return castOp.getSource ();
395
+ }
396
+
397
+ // / Helper function that returns the return types (skipping casts) of the given
398
+ // / func.return ops. This function returns as many types as the return ops have
399
+ // / operands. If the i-th operand is not the same for all func.return ops, then
400
+ // / the i-th returned type is an "empty" type.
401
+ static SmallVector<Type> getReturnTypes (SmallVector<func::ReturnOp> returnOps) {
402
+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
403
+ int numOperands = returnOps.front ()->getNumOperands ();
404
+
405
+ // Helper function that unpacks memref.cast ops and returns the type.
406
+ auto getSourceType = [&](Value v) { return unpackCast (v).getType (); };
407
+
408
+ SmallVector<Type> result;
409
+ for (int i = 0 ; i < numOperands; ++i) {
410
+ // Get the type of the i-th operand of the first func.return ops.
411
+ Type t = getSourceType (returnOps.front ()->getOperand (i));
412
+
413
+ // Check if all other func.return ops have a matching operand type.
414
+ for (int j = 1 ; j < static_cast <int >(returnOps.size ()); ++j)
415
+ if (getSourceType (returnOps[j]->getOperand (i)) != t)
416
+ t = Type ();
417
+
418
+ result.push_back (t);
419
+ }
420
+
421
+ return result;
422
+ }
423
+
354
424
// / Fold return values that are memref casts and update function return types.
355
425
// /
356
426
// / During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -359,21 +429,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
359
429
// / entire function body, a more concise memref type can potentially be used for
360
430
// / the return type of the function.
361
431
static void foldMemRefCasts (func::FuncOp funcOp) {
432
+ // There is nothing to do for bodiless ops.
362
433
if (funcOp.getBody ().empty ())
363
434
return ;
364
435
365
- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
366
- SmallVector<Type> resultTypes;
436
+ // Compute the common result types of all return ops.
437
+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
438
+ SmallVector<Type> resultTypes = getReturnTypes (returnOps);
367
439
368
- for (OpOperand &operand : returnOp->getOpOperands ()) {
369
- if (auto castOp = operand.get ().getDefiningOp <memref::CastOp>()) {
370
- operand.set (castOp.getSource ());
371
- resultTypes.push_back (castOp.getSource ().getType ());
372
- } else {
373
- resultTypes.push_back (operand.get ().getType ());
440
+ // Remove direct casts.
441
+ for (func::ReturnOp returnOp : returnOps) {
442
+ for (OpOperand &operand : returnOp->getOpOperands ()) {
443
+ // Bail if no common result type was found.
444
+ if (resultTypes[operand.getOperandNumber ()]) {
445
+ operand.set (unpackCast (operand.get ()));
446
+ }
374
447
}
375
448
}
376
449
450
+ // Fill in the missing result types that were not the same among all
451
+ // func.return ops.
452
+ for (int i = 0 ; i < static_cast <int >(resultTypes.size ()); ++i) {
453
+ if (resultTypes[i])
454
+ continue ;
455
+ resultTypes[i] = funcOp.getFunctionType ().getResult (i);
456
+ }
457
+
458
+ // Update the function type.
377
459
auto newFuncType = FunctionType::get (
378
460
funcOp.getContext (), funcOp.getFunctionType ().getInputs (), resultTypes);
379
461
funcOp.setType (newFuncType);
0 commit comments