14
14
#include " flang/Optimizer/Dialect/FIROps.h"
15
15
#include " flang/Optimizer/HLFIR/HLFIROps.h"
16
16
#include " flang/Optimizer/Support/DataLayout.h"
17
+ #include " flang/Runtime/CUDA/allocatable.h"
17
18
#include " flang/Runtime/CUDA/common.h"
18
19
#include " flang/Runtime/CUDA/descriptor.h"
19
20
#include " flang/Runtime/CUDA/memory.h"
@@ -35,22 +36,27 @@ using namespace Fortran::runtime::cuda;
35
36
namespace {
36
37
37
38
template <typename OpTy>
38
- static bool needDoubleDescriptor (OpTy op) {
39
+ static bool isPinned (OpTy op) {
40
+ if (op.getDataAttr () && *op.getDataAttr () == cuf::DataAttribute::Pinned)
41
+ return true ;
42
+ return false ;
43
+ }
44
+
45
+ template <typename OpTy>
46
+ static bool hasDoubleDescriptors (OpTy op) {
39
47
if (auto declareOp =
40
48
mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox ().getDefiningOp ())) {
41
49
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
42
50
declareOp.getMemref ().getDefiningOp ())) {
43
- if (declareOp.getDataAttr () &&
44
- *declareOp.getDataAttr () == cuf::DataAttribute::Pinned)
51
+ if (isPinned (declareOp))
45
52
return false ;
46
53
return true ;
47
54
}
48
55
} else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
49
56
op.getBox ().getDefiningOp ())) {
50
57
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
51
58
declareOp.getMemref ().getDefiningOp ())) {
52
- if (declareOp.getDataAttr () &&
53
- *declareOp.getDataAttr () == cuf::DataAttribute::Pinned)
59
+ if (isPinned (declareOp))
54
60
return false ;
55
61
return true ;
56
62
}
@@ -108,17 +114,22 @@ struct CufAllocateOpConversion
108
114
if (op.getPinned ())
109
115
return mlir::failure ();
110
116
111
- // TODO: Allocation of module variable will need more work as the descriptor
112
- // will be duplicated and needs to be synced after allocation.
113
- if (needDoubleDescriptor (op))
114
- return mlir::failure ();
117
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
118
+ fir::FirOpBuilder builder (rewriter, mod);
119
+ mlir::Location loc = op.getLoc ();
120
+
121
+ if (hasDoubleDescriptors (op)) {
122
+ // Allocation for module variable are done with custom runtime entry point
123
+ // so the descriptors can be synchronized.
124
+ mlir::func::FuncOp func =
125
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFAllocatableAllocate)>(
126
+ loc, builder);
127
+ return convertOpToCall (op, rewriter, func);
128
+ }
115
129
116
130
// Allocation for local descriptor falls back on the standard runtime
117
131
// AllocatableAllocate as the dedicated allocator is set in the descriptor
118
132
// before the call.
119
- auto mod = op->template getParentOfType <mlir::ModuleOp>();
120
- fir::FirOpBuilder builder (rewriter, mod);
121
- mlir::Location loc = op.getLoc ();
122
133
mlir::func::FuncOp func =
123
134
fir::runtime::getRuntimeFunc<mkRTKey (AllocatableAllocate)>(loc,
124
135
builder);
@@ -133,17 +144,23 @@ struct CufDeallocateOpConversion
133
144
mlir::LogicalResult
134
145
matchAndRewrite (cuf::DeallocateOp op,
135
146
mlir::PatternRewriter &rewriter) const override {
136
- // TODO: Allocation of module variable will need more work as the descriptor
137
- // will be duplicated and needs to be synced after allocation.
138
- if (needDoubleDescriptor (op))
139
- return mlir::failure ();
140
147
141
- // Deallocation for local descriptor falls back on the standard runtime
142
- // AllocatableDeallocate as the dedicated deallocator is set in the
143
- // descriptor before the call.
144
148
auto mod = op->getParentOfType <mlir::ModuleOp>();
145
149
fir::FirOpBuilder builder (rewriter, mod);
146
150
mlir::Location loc = op.getLoc ();
151
+
152
+ if (hasDoubleDescriptors (op)) {
153
+ // Deallocation for module variable are done with custom runtime entry
154
+ // point so the descriptors can be synchronized.
155
+ mlir::func::FuncOp func =
156
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFAllocatableDeallocate)>(
157
+ loc, builder);
158
+ return convertOpToCall (op, rewriter, func);
159
+ }
160
+
161
+ // Deallocation for local descriptor falls back on the standard runtime
162
+ // AllocatableDeallocate as the dedicated deallocator is set in the
163
+ // descriptor before the call.
147
164
mlir::func::FuncOp func =
148
165
fir::runtime::getRuntimeFunc<mkRTKey (AllocatableDeallocate)>(loc,
149
166
builder);
@@ -448,10 +465,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
448
465
}
449
466
return true ;
450
467
});
451
- target.addDynamicallyLegalOp <cuf::AllocateOp>(
452
- [](::cuf::AllocateOp op) { return needDoubleDescriptor (op); });
453
- target.addDynamicallyLegalOp <cuf::DeallocateOp>(
454
- [](::cuf::DeallocateOp op) { return needDoubleDescriptor (op); });
455
468
target.addDynamicallyLegalOp <cuf::DataTransferOp>(
456
469
[](::cuf::DataTransferOp op) {
457
470
mlir::Type srcTy = fir::unwrapRefType (op.getSrc ().getType ());
0 commit comments