Skip to content

Commit 9d09c6f

Browse files
authored
[flang][cuda] Update device descriptor on data transfer (#114838)
When the destination of the data transfer is a global we might need to sync the descriptor after the data transfer is done. This is the case when the data transfer is from host/device to device as reallocation might have happened and the descriptor on the device needs to take the new values written on the host. A new entry point is added `CUFDataTransferGlobalDescDesc` with the sync when needed.
1 parent 3a5791e commit 9d09c6f

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

flang/include/flang/Runtime/CUDA/memory.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
4949
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
5050
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
5151

52+
/// Data transfer from a descriptor to a global descriptor.
53+
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
54+
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
55+
5256
} // extern "C"
5357
} // namespace Fortran::runtime::cuda
5458
#endif // FORTRAN_RUNTIME_CUDA_MEMORY_H_

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,16 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
429429
}
430430
};
431431

432+
static bool isDstGlobal(cuf::DataTransferOp op) {
433+
if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
434+
if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
435+
return true;
436+
if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
437+
if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
438+
return true;
439+
return false;
440+
}
441+
432442
struct CUFDataTransferOpConversion
433443
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
434444
using OpRewritePattern::OpRewritePattern;
@@ -522,8 +532,11 @@ struct CUFDataTransferOpConversion
522532
mlir::isa<fir::BaseBoxType>(dstTy)) {
523533
// Transfer between two descriptor.
524534
mlir::func::FuncOp func =
525-
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
526-
loc, builder);
535+
isDstGlobal(op)
536+
? fir::runtime::getRuntimeFunc<mkRTKey(
537+
CUFDataTransferGlobalDescDesc)>(loc, builder)
538+
: fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
539+
loc, builder);
527540

528541
auto fTy = func.getFunctionType();
529542
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);

flang/runtime/CUDA/memory.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "flang/Runtime/CUDA/memory.h"
1010
#include "../terminator.h"
1111
#include "flang/Runtime/CUDA/common.h"
12+
#include "flang/Runtime/CUDA/descriptor.h"
1213
#include "flang/Runtime/assign.h"
1314

1415
#include "cuda_runtime.h"
@@ -125,5 +126,18 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
125126
Fortran::runtime::Assign(
126127
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
127128
}
129+
130+
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
131+
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
132+
int sourceLine) {
133+
RTNAME(CUFDataTransferDescDesc)
134+
(dstDesc, srcDesc, mode, sourceFile, sourceLine);
135+
if ((mode == kHostToDevice) || (mode == kDeviceToDevice)) {
136+
void *deviceAddr{
137+
RTNAME(CUFGetDeviceAddress)((void *)dstDesc, sourceFile, sourceLine)};
138+
RTNAME(CUFDescriptorSync)
139+
((Descriptor *)deviceAddr, srcDesc, sourceFile, sourceLine);
140+
}
141+
}
128142
}
129143
} // namespace Fortran::runtime::cuda

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,29 @@ func.func @_QPsub9() {
224224
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
225225
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
226226
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
227+
228+
fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
229+
%c0 = arith.constant 0 : index
230+
%0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
231+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
232+
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
233+
fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
234+
}
235+
236+
func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
237+
%0 = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
238+
%1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
239+
%2 = fir.address_of(@_QFEahost) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
240+
%3:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
241+
cuf.data_transfer %3#0 to %1#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
242+
return
243+
}
244+
245+
// CHECK-LABEL: func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"}
246+
// CHECK: %[[GLOBAL_ADDRESS:.*]] = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
247+
// CHECK: %[[GLOBAL_DECL:.*]]:2 = hlfir.declare %[[GLOBAL_ADDRESS]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
248+
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
249+
// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
250+
251+
227252
} // end of module

0 commit comments

Comments
 (0)