@@ -234,6 +234,60 @@ class SaveResultOpConversion
234
234
}
235
235
};
236
236
237
+ template <typename OpTy>
238
+ static mlir::LogicalResult
239
+ processReturnLikeOp (OpTy ret, mlir::Value newArg,
240
+ mlir::PatternRewriter &rewriter) {
241
+ auto loc = ret.getLoc ();
242
+ rewriter.setInsertionPoint (ret);
243
+ mlir::Value resultValue = ret.getOperand (0 );
244
+ fir::LoadOp resultLoad;
245
+ mlir::Value resultStorage;
246
+ // Identify result local storage.
247
+ if (auto load = resultValue.getDefiningOp <fir::LoadOp>()) {
248
+ resultLoad = load;
249
+ resultStorage = load.getMemref ();
250
+ // The result alloca may be behind a fir.declare, if any.
251
+ if (auto declare = resultStorage.getDefiningOp <fir::DeclareOp>())
252
+ resultStorage = declare.getMemref ();
253
+ }
254
+ // Replace old local storage with new storage argument, unless
255
+ // the derived type is C_PTR/C_FUN_PTR, in which case the return
256
+ // type is updated to return void* (no new argument is passed).
257
+ if (fir::isa_builtin_cptr_type (resultValue.getType ())) {
258
+ auto module = ret->template getParentOfType <mlir::ModuleOp>();
259
+ FirOpBuilder builder (rewriter, module );
260
+ mlir::Value cptr = resultValue;
261
+ if (resultLoad) {
262
+ // Replace whole derived type load by component load.
263
+ cptr = resultLoad.getMemref ();
264
+ rewriter.setInsertionPoint (resultLoad);
265
+ }
266
+ mlir::Value newResultValue =
267
+ fir::factory::genCPtrOrCFunptrValue (builder, loc, cptr);
268
+ newResultValue = builder.createConvert (
269
+ loc, getVoidPtrType (ret.getContext ()), newResultValue);
270
+ rewriter.setInsertionPoint (ret);
271
+ rewriter.replaceOpWithNewOp <OpTy>(ret, mlir::ValueRange{newResultValue});
272
+ } else if (resultStorage) {
273
+ resultStorage.replaceAllUsesWith (newArg);
274
+ rewriter.replaceOpWithNewOp <OpTy>(ret);
275
+ } else {
276
+ // The result storage may have been optimized out by a memory to
277
+ // register pass, this is possible for fir.box results, or fir.record
278
+ // with no length parameters. Simply store the result in the result
279
+ // storage. at the return point.
280
+ rewriter.create <fir::StoreOp>(loc, resultValue, newArg);
281
+ rewriter.replaceOpWithNewOp <OpTy>(ret);
282
+ }
283
+ // Delete result old local storage if unused.
284
+ if (resultStorage)
285
+ if (auto alloc = resultStorage.getDefiningOp <fir::AllocaOp>())
286
+ if (alloc->use_empty ())
287
+ rewriter.eraseOp (alloc);
288
+ return mlir::success ();
289
+ }
290
+
237
291
class ReturnOpConversion : public mlir ::OpRewritePattern<mlir::func::ReturnOp> {
238
292
public:
239
293
using OpRewritePattern::OpRewritePattern;
@@ -242,55 +296,23 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
242
296
llvm::LogicalResult
243
297
matchAndRewrite (mlir::func::ReturnOp ret,
244
298
mlir::PatternRewriter &rewriter) const override {
245
- auto loc = ret.getLoc ();
246
- rewriter.setInsertionPoint (ret);
247
- mlir::Value resultValue = ret.getOperand (0 );
248
- fir::LoadOp resultLoad;
249
- mlir::Value resultStorage;
250
- // Identify result local storage.
251
- if (auto load = resultValue.getDefiningOp <fir::LoadOp>()) {
252
- resultLoad = load;
253
- resultStorage = load.getMemref ();
254
- // The result alloca may be behind a fir.declare, if any.
255
- if (auto declare = resultStorage.getDefiningOp <fir::DeclareOp>())
256
- resultStorage = declare.getMemref ();
257
- }
258
- // Replace old local storage with new storage argument, unless
259
- // the derived type is C_PTR/C_FUN_PTR, in which case the return
260
- // type is updated to return void* (no new argument is passed).
261
- if (fir::isa_builtin_cptr_type (resultValue.getType ())) {
262
- auto module = ret->getParentOfType <mlir::ModuleOp>();
263
- FirOpBuilder builder (rewriter, module );
264
- mlir::Value cptr = resultValue;
265
- if (resultLoad) {
266
- // Replace whole derived type load by component load.
267
- cptr = resultLoad.getMemref ();
268
- rewriter.setInsertionPoint (resultLoad);
269
- }
270
- mlir::Value newResultValue =
271
- fir::factory::genCPtrOrCFunptrValue (builder, loc, cptr);
272
- newResultValue = builder.createConvert (
273
- loc, getVoidPtrType (ret.getContext ()), newResultValue);
274
- rewriter.setInsertionPoint (ret);
275
- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(
276
- ret, mlir::ValueRange{newResultValue});
277
- } else if (resultStorage) {
278
- resultStorage.replaceAllUsesWith (newArg);
279
- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
280
- } else {
281
- // The result storage may have been optimized out by a memory to
282
- // register pass, this is possible for fir.box results, or fir.record
283
- // with no length parameters. Simply store the result in the result
284
- // storage. at the return point.
285
- rewriter.create <fir::StoreOp>(loc, resultValue, newArg);
286
- rewriter.replaceOpWithNewOp <mlir::func::ReturnOp>(ret);
287
- }
288
- // Delete result old local storage if unused.
289
- if (resultStorage)
290
- if (auto alloc = resultStorage.getDefiningOp <fir::AllocaOp>())
291
- if (alloc->use_empty ())
292
- rewriter.eraseOp (alloc);
293
- return mlir::success ();
299
+ return processReturnLikeOp (ret, newArg, rewriter);
300
+ }
301
+
302
+ private:
303
+ mlir::Value newArg;
304
+ };
305
+
306
+ class GPUReturnOpConversion
307
+ : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
308
+ public:
309
+ using OpRewritePattern::OpRewritePattern;
310
+ GPUReturnOpConversion (mlir::MLIRContext *context, mlir::Value newArg)
311
+ : OpRewritePattern(context), newArg{newArg} {}
312
+ llvm::LogicalResult
313
+ matchAndRewrite (mlir::gpu::ReturnOp ret,
314
+ mlir::PatternRewriter &rewriter) const override {
315
+ return processReturnLikeOp (ret, newArg, rewriter);
294
316
}
295
317
296
318
private:
@@ -373,6 +395,9 @@ class AbstractResultOpt
373
395
patterns.insert <ReturnOpConversion>(context, newArg);
374
396
target.addDynamicallyLegalOp <mlir::func::ReturnOp>(
375
397
[](mlir::func::ReturnOp ret) { return ret.getOperands ().empty (); });
398
+ patterns.insert <GPUReturnOpConversion>(context, newArg);
399
+ target.addDynamicallyLegalOp <mlir::gpu::ReturnOp>(
400
+ [](mlir::gpu::ReturnOp ret) { return ret.getOperands ().empty (); });
376
401
assert (func.getFunctionType () ==
377
402
getNewFunctionType (funcTy, shouldBoxResult));
378
403
} else {
0 commit comments