Skip to content

Commit 156035e

Browse files
committed
[flang][cuda] Convert module allocation/deallocation to runtime calls
Convert `cuf.allocate` and `cuf.deallocate` to the runtime entry points added in #109213. This was reviewed in #109214, but the parent branch was closed for some reason.
1 parent 56015da commit 156035e

File tree

2 files changed

+74
-25
lines changed

2 files changed

+74
-25
lines changed

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "flang/Optimizer/Dialect/FIROps.h"
1515
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1616
#include "flang/Optimizer/Support/DataLayout.h"
17+
#include "flang/Runtime/CUDA/allocatable.h"
1718
#include "flang/Runtime/CUDA/common.h"
1819
#include "flang/Runtime/CUDA/descriptor.h"
1920
#include "flang/Runtime/CUDA/memory.h"
@@ -35,22 +36,27 @@ using namespace Fortran::runtime::cuda;
3536
namespace {
3637

3738
template <typename OpTy>
38-
static bool needDoubleDescriptor(OpTy op) {
39+
static bool isPinned(OpTy op) {
40+
if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
41+
return true;
42+
return false;
43+
}
44+
45+
template <typename OpTy>
46+
static bool hasDoubleDescriptors(OpTy op) {
3947
if (auto declareOp =
4048
mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
4149
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
4250
declareOp.getMemref().getDefiningOp())) {
43-
if (declareOp.getDataAttr() &&
44-
*declareOp.getDataAttr() == cuf::DataAttribute::Pinned)
51+
if (isPinned(declareOp))
4552
return false;
4653
return true;
4754
}
4855
} else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
4956
op.getBox().getDefiningOp())) {
5057
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
5158
declareOp.getMemref().getDefiningOp())) {
52-
if (declareOp.getDataAttr() &&
53-
*declareOp.getDataAttr() == cuf::DataAttribute::Pinned)
59+
if (isPinned(declareOp))
5460
return false;
5561
return true;
5662
}
@@ -108,17 +114,22 @@ struct CufAllocateOpConversion
108114
if (op.getPinned())
109115
return mlir::failure();
110116

111-
// TODO: Allocation of module variable will need more work as the descriptor
112-
// will be duplicated and needs to be synced after allocation.
113-
if (needDoubleDescriptor(op))
114-
return mlir::failure();
117+
auto mod = op->getParentOfType<mlir::ModuleOp>();
118+
fir::FirOpBuilder builder(rewriter, mod);
119+
mlir::Location loc = op.getLoc();
120+
121+
if (hasDoubleDescriptors(op)) {
122+
// Module variables are allocated with a custom runtime entry point
123+
// so the descriptors can be synchronized.
124+
mlir::func::FuncOp func =
125+
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
126+
loc, builder);
127+
return convertOpToCall(op, rewriter, func);
128+
}
115129

116130
// Allocation for local descriptor falls back on the standard runtime
117131
// AllocatableAllocate as the dedicated allocator is set in the descriptor
118132
// before the call.
119-
auto mod = op->template getParentOfType<mlir::ModuleOp>();
120-
fir::FirOpBuilder builder(rewriter, mod);
121-
mlir::Location loc = op.getLoc();
122133
mlir::func::FuncOp func =
123134
fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocate)>(loc,
124135
builder);
@@ -133,17 +144,23 @@ struct CufDeallocateOpConversion
133144
mlir::LogicalResult
134145
matchAndRewrite(cuf::DeallocateOp op,
135146
mlir::PatternRewriter &rewriter) const override {
136-
// TODO: Allocation of module variable will need more work as the descriptor
137-
// will be duplicated and needs to be synced after allocation.
138-
if (needDoubleDescriptor(op))
139-
return mlir::failure();
140147

141-
// Deallocation for local descriptor falls back on the standard runtime
142-
// AllocatableDeallocate as the dedicated deallocator is set in the
143-
// descriptor before the call.
144148
auto mod = op->getParentOfType<mlir::ModuleOp>();
145149
fir::FirOpBuilder builder(rewriter, mod);
146150
mlir::Location loc = op.getLoc();
151+
152+
if (hasDoubleDescriptors(op)) {
153+
// Deallocation of module variables is done with a custom runtime entry
154+
// point so the descriptors can be synchronized.
155+
mlir::func::FuncOp func =
156+
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
157+
loc, builder);
158+
return convertOpToCall(op, rewriter, func);
159+
}
160+
161+
// Deallocation for local descriptor falls back on the standard runtime
162+
// AllocatableDeallocate as the dedicated deallocator is set in the
163+
// descriptor before the call.
147164
mlir::func::FuncOp func =
148165
fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
149166
builder);
@@ -448,10 +465,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
448465
}
449466
return true;
450467
});
451-
target.addDynamicallyLegalOp<cuf::AllocateOp>(
452-
[](::cuf::AllocateOp op) { return needDoubleDescriptor(op); });
453-
target.addDynamicallyLegalOp<cuf::DeallocateOp>(
454-
[](::cuf::DeallocateOp op) { return needDoubleDescriptor(op); });
455468
target.addDynamicallyLegalOp<cuf::DataTransferOp>(
456469
[](::cuf::DataTransferOp op) {
457470
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,14 @@ func.func @_QPsub3() {
5454
}
5555

5656
// CHECK-LABEL: func.func @_QPsub3()
57-
// CHECK: cuf.allocate
58-
// CHECK: cuf.deallocate
57+
// CHECK: %[[A_ADDR:.*]] = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
58+
// CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_ADDR]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
59+
60+
// CHECK: %[[A_BOX:.*]] = fir.convert %[[A]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
61+
// CHECK: fir.call @_FortranACUFAllocatableAllocate(%[[A_BOX]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
62+
63+
// CHECK: %[[A_BOX:.*]] = fir.convert %[[A]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
64+
// CHECK: fir.call @_FortranACUFAllocatableDeallocate(%[[A_BOX]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
5965

6066
func.func @_QPsub4() attributes {cuf.proc_attr = #cuf.cuda_proc<device>} {
6167
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
@@ -95,4 +101,34 @@ func.func @_QPsub5() {
95101
// CHECK: fir.call @_FortranAAllocatableAllocate({{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
96102
// CHECK: fir.call @_FortranAAllocatableDeallocate({{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
97103

104+
105+
fir.global @_QMdataEb {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
106+
%c0 = arith.constant 0 : index
107+
%0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
108+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
109+
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
110+
fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
111+
}
112+
113+
func.func @_QQsub6() attributes {fir.bindc_name = "test"} {
114+
%c0_i32 = arith.constant 0 : i32
115+
%c10_i32 = arith.constant 10 : i32
116+
%c1 = arith.constant 1 : index
117+
%0 = fir.address_of(@_QMdataEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
118+
%1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMdataEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
119+
%2 = fir.convert %1#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
120+
%3 = fir.convert %c1 : (index) -> i64
121+
%4 = fir.convert %c10_i32 : (i32) -> i64
122+
%5 = fir.call @_FortranAAllocatableSetBounds(%2, %c0_i32, %3, %4) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
123+
%6 = cuf.allocate %1#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>} -> i32
124+
return
125+
}
126+
127+
// CHECK-LABEL: func.func @_QQsub6() attributes {fir.bindc_name = "test"}
128+
// CHECK: %[[B_ADDR:.*]] = fir.address_of(@_QMdataEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
129+
// CHECK: %[[B:.*]]:2 = hlfir.declare %[[B_ADDR]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMdataEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
130+
// CHECK: _FortranAAllocatableSetBounds
131+
// CHECK: %[[B_BOX:.*]] = fir.convert %[[B]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
132+
// CHECK: fir.call @_FortranACUFAllocatableAllocate(%[[B_BOX]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
133+
98134
} // end of module

0 commit comments

Comments
 (0)