Skip to content

Commit 67a8857

Browse files
authored
[flang][cuda] Handle pointer allocation with double descriptors (#124183)
1 parent 13dae34 commit 67a8857

File tree

4 files changed

+108
-14
lines changed

4 files changed

+108
-14
lines changed

flang/include/flang/Runtime/CUDA/pointer.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,26 @@ int RTDECL(CUFPointerAllocate)(Descriptor &, int64_t stream = -1,
2121
bool hasStat = false, const Descriptor *errMsg = nullptr,
2222
const char *sourceFile = nullptr, int sourceLine = 0);
2323

24+
/// Perform allocation of the descriptor with synchronization of it when
25+
/// necessary.
26+
int RTDECL(CUFPointerAllocateSync)(Descriptor &, int64_t stream = -1,
27+
bool hasStat = false, const Descriptor *errMsg = nullptr,
28+
const char *sourceFile = nullptr, int sourceLine = 0);
29+
2430
/// Perform allocation of the descriptor without synchronization. Assign data
2531
/// from source.
2632
int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
2733
const Descriptor &source, int64_t stream = -1, bool hasStat = false,
2834
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
2935
int sourceLine = 0);
3036

37+
/// Perform allocation of the descriptor with synchronization of it when
38+
/// necessary. Assign data from source.
39+
int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer,
40+
const Descriptor &source, int64_t stream = -1, bool hasStat = false,
41+
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
42+
int sourceLine = 0);
43+
3144
} // extern "C"
3245

3346
} // namespace Fortran::runtime::cuda

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,22 @@ struct CUFAllocateOpConversion
172172
isPointer = true;
173173

174174
if (hasDoubleDescriptors(op)) {
175-
if (isPointer)
176-
TODO(loc, "pointer allocation with double descriptors");
177175
// Allocation for module variable are done with custom runtime entry point
178176
// so the descriptors can be synchronized.
179177
mlir::func::FuncOp func;
180-
if (op.getSource())
181-
func = fir::runtime::getRuntimeFunc<mkRTKey(
182-
CUFAllocatableAllocateSourceSync)>(loc, builder);
183-
else
178+
if (op.getSource()) {
179+
func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
180+
CUFPointerAllocateSourceSync)>(loc, builder)
181+
: fir::runtime::getRuntimeFunc<mkRTKey(
182+
CUFAllocatableAllocateSourceSync)>(loc, builder);
183+
} else {
184184
func =
185-
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSync)>(
186-
loc, builder);
185+
isPointer
186+
? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
187+
loc, builder)
188+
: fir::runtime::getRuntimeFunc<mkRTKey(
189+
CUFAllocatableAllocateSync)>(loc, builder);
190+
}
187191
return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
188192
}
189193

flang/runtime/CUDA/pointer.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "../assign-impl.h"
1111
#include "../stat.h"
1212
#include "../terminator.h"
13+
#include "flang/Runtime/CUDA/descriptor.h"
1314
#include "flang/Runtime/CUDA/memmove-function.h"
1415
#include "flang/Runtime/pointer.h"
1516

@@ -35,6 +36,24 @@ int RTDEF(CUFPointerAllocate)(Descriptor &desc, int64_t stream, bool hasStat,
3536
return stat;
3637
}
3738

39+
int RTDEF(CUFPointerAllocateSync)(Descriptor &desc, int64_t stream,
40+
bool hasStat, const Descriptor *errMsg, const char *sourceFile,
41+
int sourceLine) {
42+
int stat{RTNAME(CUFPointerAllocate)(
43+
desc, stream, hasStat, errMsg, sourceFile, sourceLine)};
44+
#ifndef RT_DEVICE_COMPILATION
45+
// Descriptor synchronization is only done when the allocation is done
46+
// from the host.
47+
if (stat == StatOk) {
48+
void *deviceAddr{
49+
RTNAME(CUFGetDeviceAddress)((void *)&desc, sourceFile, sourceLine)};
50+
RTNAME(CUFDescriptorSync)
51+
((Descriptor *)deviceAddr, &desc, sourceFile, sourceLine);
52+
}
53+
#endif
54+
return stat;
55+
}
56+
3857
int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
3958
const Descriptor &source, int64_t stream, bool hasStat,
4059
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
@@ -48,6 +67,19 @@ int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
4867
return stat;
4968
}
5069

70+
int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer,
71+
const Descriptor &source, int64_t stream, bool hasStat,
72+
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
73+
int stat{RTNAME(CUFPointerAllocateSync)(
74+
pointer, stream, hasStat, errMsg, sourceFile, sourceLine)};
75+
if (stat == StatOk) {
76+
Terminator terminator{sourceFile, sourceLine};
77+
Fortran::runtime::DoFromSourceAssign(
78+
pointer, source, terminator, &MemmoveHostToDevice);
79+
}
80+
return stat;
81+
}
82+
5183
RT_EXT_API_GROUP_END
5284

5385
} // extern "C"

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,61 @@ func.func @_QPpointer_source() {
198198
%c0_i32 = arith.constant 0 : i32
199199
%c1 = arith.constant 1 : index
200200
%c0 = arith.constant 0 : index
201-
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QFpointer_sourceEa"}
202-
%4 = fir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFpointer_sourceEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
203-
%5 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a_d", data_attr = #cuf.cuda<device>, uniq_name = "_QFpointer_sourceEa_d"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
204-
%7 = fir.declare %5 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFpointer_sourceEa_d"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
205-
%8 = fir.load %4 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
206-
%22 = cuf.allocate %7 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>> source(%8 : !fir.box<!fir.heap<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
201+
%0 = fir.alloca !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QFpointer_sourceEa"}
202+
%4 = fir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFpointer_sourceEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
203+
%5 = cuf.alloc !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {bindc_name = "a_d", data_attr = #cuf.cuda<device>, uniq_name = "_QFpointer_sourceEa_d"} -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
204+
%7 = fir.declare %5 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFpointer_sourceEa_d"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
205+
%8 = fir.load %4 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
206+
%22 = cuf.allocate %7 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> source(%8 : !fir.box<!fir.ptr<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
207207
return
208208
}
209209

210210
// CHECK-LABEL: func.func @_QPpointer_source()
211211
// CHECK: _FortranACUFPointerAllocateSource
212212

213+
fir.global @_QMdataEb2 {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xi32>>> {
214+
%c0 = arith.constant 0 : index
215+
%0 = fir.zero_bits !fir.ptr<!fir.array<?xi32>>
216+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
217+
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xi32>>>
218+
fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xi32>>>
219+
}
220+
221+
func.func @_QQpointer_sync() attributes {fir.bindc_name = "test"} {
222+
%c0_i32 = arith.constant 0 : i32
223+
%c10_i32 = arith.constant 10 : i32
224+
%c1 = arith.constant 1 : index
225+
%0 = fir.address_of(@_QMdataEb2) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
226+
%1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdataEb"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>)
227+
%2 = fir.convert %1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
228+
%3 = fir.convert %c1 : (index) -> i64
229+
%4 = fir.convert %c10_i32 : (i32) -> i64
230+
fir.call @_FortranAAllocatableSetBounds(%2, %c0_i32, %3, %4) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
231+
%6 = cuf.allocate %1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>} -> i32
232+
return
233+
}
234+
235+
// CHECK-LABEL: func.func @_QQpointer_sync()
236+
// CHECK: _FortranACUFPointerAllocateSync
237+
238+
fir.global @_QMmod1Ea_d2 {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {
239+
%c0 = arith.constant 0 : index
240+
%0 = fir.zero_bits !fir.ptr<!fir.array<?x?xf32>>
241+
%1 = fir.shape %c0, %c0 : (index, index) -> !fir.shape<2>
242+
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
243+
fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
244+
}
245+
func.func @_QMmod1Ppointer_source_global() {
246+
%0 = fir.address_of(@_QMmod1Ea_d2) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
247+
%1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMmod1Ea_d"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
248+
%2 = fir.alloca !fir.box<!fir.ptr<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QMmod1Fallocate_source_globalEa"}
249+
%6 = fir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Fallocate_source_globalEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
250+
%7 = fir.load %6 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
251+
%21 = cuf.allocate %1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> source(%7 : !fir.box<!fir.ptr<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
252+
return
253+
}
254+
255+
// CHECK-LABEL: func.func @_QMmod1Ppointer_source_global()
256+
// CHECK: fir.call @_FortranACUFPointerAllocateSourceSync
257+
213258
} // end of module

0 commit comments

Comments
 (0)