
Commit 75623bf

[flang][cuda] Handle gpu.return in AbstractResult pass (#119035)
1 parent 953838d commit 75623bf

2 files changed: +111, -49 lines changed

flang/lib/Optimizer/Transforms/AbstractResult.cpp

Lines changed: 74 additions & 49 deletions
@@ -234,6 +234,60 @@ class SaveResultOpConversion
   }
 };
 
+template <typename OpTy>
+static mlir::LogicalResult
+processReturnLikeOp(OpTy ret, mlir::Value newArg,
+                    mlir::PatternRewriter &rewriter) {
+  auto loc = ret.getLoc();
+  rewriter.setInsertionPoint(ret);
+  mlir::Value resultValue = ret.getOperand(0);
+  fir::LoadOp resultLoad;
+  mlir::Value resultStorage;
+  // Identify result local storage.
+  if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
+    resultLoad = load;
+    resultStorage = load.getMemref();
+    // The result alloca may be behind a fir.declare, if any.
+    if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
+      resultStorage = declare.getMemref();
+  }
+  // Replace old local storage with new storage argument, unless
+  // the derived type is C_PTR/C_FUN_PTR, in which case the return
+  // type is updated to return void* (no new argument is passed).
+  if (fir::isa_builtin_cptr_type(resultValue.getType())) {
+    auto module = ret->template getParentOfType<mlir::ModuleOp>();
+    FirOpBuilder builder(rewriter, module);
+    mlir::Value cptr = resultValue;
+    if (resultLoad) {
+      // Replace whole derived type load by component load.
+      cptr = resultLoad.getMemref();
+      rewriter.setInsertionPoint(resultLoad);
+    }
+    mlir::Value newResultValue =
+        fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
+    newResultValue = builder.createConvert(
+        loc, getVoidPtrType(ret.getContext()), newResultValue);
+    rewriter.setInsertionPoint(ret);
+    rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
+  } else if (resultStorage) {
+    resultStorage.replaceAllUsesWith(newArg);
+    rewriter.replaceOpWithNewOp<OpTy>(ret);
+  } else {
+    // The result storage may have been optimized out by a memory to
+    // register pass; this is possible for fir.box results, or fir.record
+    // with no length parameters. Simply store the result in the result
+    // storage at the return point.
+    rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
+    rewriter.replaceOpWithNewOp<OpTy>(ret);
+  }
+  // Delete the old local result storage if unused.
+  if (resultStorage)
+    if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
+      if (alloc->use_empty())
+        rewriter.eraseOp(alloc);
+  return mlir::success();
+}
+
 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
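A note on the `->template` spelling in the helper above: because `ret` has the dependent type OpTy inside the function template, C++ requires the `template` keyword so that `getParentOfType<...>` parses as a member-template call rather than a less-than comparison. A minimal illustrative sketch (not part of the commit; `getEnclosingModule` is a made-up name):

template <typename OpTy>
static mlir::ModuleOp getEnclosingModule(OpTy op) {
  // `op` has a dependent type here, so `template` is mandatory; in the
  // non-template code this helper replaces, the plain spelling
  // `ret->getParentOfType<mlir::ModuleOp>()` was enough.
  return op->template getParentOfType<mlir::ModuleOp>();
}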
@@ -242,55 +296,23 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
   llvm::LogicalResult
   matchAndRewrite(mlir::func::ReturnOp ret,
                   mlir::PatternRewriter &rewriter) const override {
-    auto loc = ret.getLoc();
-    rewriter.setInsertionPoint(ret);
-    mlir::Value resultValue = ret.getOperand(0);
-    fir::LoadOp resultLoad;
-    mlir::Value resultStorage;
-    // Identify result local storage.
-    if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
-      resultLoad = load;
-      resultStorage = load.getMemref();
-      // The result alloca may be behind a fir.declare, if any.
-      if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
-        resultStorage = declare.getMemref();
-    }
-    // Replace old local storage with new storage argument, unless
-    // the derived type is C_PTR/C_FUN_PTR, in which case the return
-    // type is updated to return void* (no new argument is passed).
-    if (fir::isa_builtin_cptr_type(resultValue.getType())) {
-      auto module = ret->getParentOfType<mlir::ModuleOp>();
-      FirOpBuilder builder(rewriter, module);
-      mlir::Value cptr = resultValue;
-      if (resultLoad) {
-        // Replace whole derived type load by component load.
-        cptr = resultLoad.getMemref();
-        rewriter.setInsertionPoint(resultLoad);
-      }
-      mlir::Value newResultValue =
-          fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
-      newResultValue = builder.createConvert(
-          loc, getVoidPtrType(ret.getContext()), newResultValue);
-      rewriter.setInsertionPoint(ret);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
-          ret, mlir::ValueRange{newResultValue});
-    } else if (resultStorage) {
-      resultStorage.replaceAllUsesWith(newArg);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
-    } else {
-      // The result storage may have been optimized out by a memory to
-      // register pass, this is possible for fir.box results, or fir.record
-      // with no length parameters. Simply store the result in the result
-      // storage. at the return point.
-      rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
-    }
-    // Delete result old local storage if unused.
-    if (resultStorage)
-      if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
-        if (alloc->use_empty())
-          rewriter.eraseOp(alloc);
-    return mlir::success();
+    return processReturnLikeOp(ret, newArg, rewriter);
+  }
+
+private:
+  mlir::Value newArg;
+};
+
+class GPUReturnOpConversion
+    : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
+      : OpRewritePattern(context), newArg{newArg} {}
+  llvm::LogicalResult
+  matchAndRewrite(mlir::gpu::ReturnOp ret,
+                  mlir::PatternRewriter &rewriter) const override {
+    return processReturnLikeOp(ret, newArg, rewriter);
   }
 
 private:
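With the shared helper factored out, each conversion reduces to a thin wrapper that only fixes the concrete op type, as ReturnOpConversion and GPUReturnOpConversion above show. The same idiom would extend to any other single-operand, return-like terminator; a hedged sketch, where `other::ReturnOp` stands in for a hypothetical dialect terminator that this commit does not add:

class OtherReturnOpConversion
    : public mlir::OpRewritePattern<other::ReturnOp> {
public:
  OtherReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
      : OpRewritePattern(context), newArg{newArg} {}
  llvm::LogicalResult
  matchAndRewrite(other::ReturnOp ret,
                  mlir::PatternRewriter &rewriter) const override {
    // OpTy is deduced as other::ReturnOp; the helper recreates the
    // terminator via rewriter.replaceOpWithNewOp<OpTy>.
    return processReturnLikeOp(ret, newArg, rewriter);
  }

private:
  mlir::Value newArg;
};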
@@ -373,6 +395,9 @@ class AbstractResultOpt
       patterns.insert<ReturnOpConversion>(context, newArg);
       target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
           [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
+      patterns.insert<GPUReturnOpConversion>(context, newArg);
+      target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
+          [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
       assert(func.getFunctionType() ==
              getNewFunctionType(funcTy, shouldBoxResult));
     } else {
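The dynamic-legality callbacks declare a return op legal only once it carries no operands, so the conversion driver rewrites exactly the value-returning terminators and nothing else. The driver itself sits outside this hunk; a minimal sketch of how such patterns and targets are typically consumed inside an MLIR pass (assuming the usual applyPartialConversion flow, which this diff does not show):

mlir::RewritePatternSet patterns(context);
patterns.insert<ReturnOpConversion>(context, newArg);
patterns.insert<GPUReturnOpConversion>(context, newArg);
mlir::ConversionTarget target(*context);
target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
    [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
    [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
// Rewrite every operand-carrying return; report failure otherwise.
if (mlir::failed(
        mlir::applyPartialConversion(func, target, std::move(patterns))))
  signalPassFailure();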
New file: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+// RUN: fir-opt -pass-pipeline='builtin.module(gpu.module(gpu.func(abstract-result)))' %s | FileCheck %s
+
+gpu.module @test {
+  gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg1: !fir.ref<f32>) -> !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> {
+    %c1_i32 = arith.constant 1 : i32
+    %18 = fir.dummy_scope : !fir.dscope
+    %19 = fir.declare %arg0 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Ea"} : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.dscope) -> !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    %20 = fir.declare %arg1 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Eb"} : (!fir.ref<f32>, !fir.dscope) -> !fir.ref<f32>
+    %21 = fir.alloca !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> {bindc_name = "c", uniq_name = "_QMinterval_mFtest1Ec"}
+    %22 = fir.declare %21 {uniq_name = "_QMinterval_mFtest1Ec"} : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>) -> !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    %23 = fir.alloca i32 {bindc_name = "warpsize", uniq_name = "_QMcudadeviceECwarpsize"}
+    %24 = fir.declare %23 {uniq_name = "_QMcudadeviceECwarpsize"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %25 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %26 = fir.coordinate_of %19, %25 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    %27 = fir.load %20 : !fir.ref<f32>
+    %28 = arith.negf %27 fastmath<contract> : f32
+    %29 = fir.load %26 : !fir.ref<f32>
+    %30 = fir.call @__fadd_rd(%29, %28) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
+    %31 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %32 = fir.coordinate_of %22, %31 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    fir.store %30 to %32 : !fir.ref<f32>
+    %33 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %34 = fir.coordinate_of %19, %33 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    %35 = fir.load %20 : !fir.ref<f32>
+    %36 = arith.negf %35 fastmath<contract> : f32
+    %37 = fir.load %34 : !fir.ref<f32>
+    %38 = fir.call @__fadd_ru(%37, %36) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
+    %39 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %40 = fir.coordinate_of %22, %39 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    fir.store %38 to %40 : !fir.ref<f32>
+    %41 = fir.load %22 : !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    gpu.return %41 : !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+  }
+}
+
+// CHECK: gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg1: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg2: !fir.ref<f32>) {
+// CHECK: gpu.return{{$}}
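Read together, the two CHECK lines capture the pass's effect: the gpu.func no longer returns the derived type, the result instead travels through a new !fir.ref argument (%arg1), and the final gpu.return carries no operand (the {{$}} regex rejects any trailing operand).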
