Skip to content

Commit 652db7e

Browse files
authored
[flang][cuda] Support data transfer from pointer to a descriptor (#114892)
When source is a pointer to an array or a scalar, embox it and use the `CUFDataTransferDescDesc` or `CUFDataTransferGlobalDescDesc` entry points. The runtime is already able to deal with all the corner cases like non contiguous arrays and so on so we exploit this. Memset might still be used for simple case where we want to initialize to 0 for example. This will come in a follow up patch.
1 parent 9a5e5a6 commit 652db7e

File tree

4 files changed

+90
-62
lines changed

4 files changed

+90
-62
lines changed

flang/include/flang/Runtime/CUDA/memory.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
3535
void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
3636
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
3737

38-
/// Data transfer from a pointer to a descriptor.
39-
void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
40-
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
41-
int sourceLine = 0);
42-
4338
/// Data transfer from a descriptor to a pointer.
4439
void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
4540
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "flang/Runtime/allocatable.h"
2424
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2525
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
26+
#include "mlir/IR/Matchers.h"
2627
#include "mlir/Pass/Pass.h"
2728
#include "mlir/Transforms/DialectConversion.h"
2829
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -439,6 +440,14 @@ static bool isDstGlobal(cuf::DataTransferOp op) {
439440
return false;
440441
}
441442

443+
static mlir::Value getShapeFromDecl(mlir::Value src) {
444+
if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
445+
return declareOp.getShape();
446+
if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
447+
return declareOp.getShape();
448+
return mlir::Value{};
449+
}
450+
442451
struct CUFDataTransferOpConversion
443452
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
444453
using OpRewritePattern::OpRewritePattern;
@@ -528,54 +537,54 @@ struct CUFDataTransferOpConversion
528537
}
529538

530539
// Conversion of data transfer involving at least one descriptor.
531-
if (mlir::isa<fir::BaseBoxType>(srcTy) &&
532-
mlir::isa<fir::BaseBoxType>(dstTy)) {
533-
// Transfer between two descriptor.
540+
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
541+
// Transfer to a descriptor.
534542
mlir::func::FuncOp func =
535543
isDstGlobal(op)
536544
? fir::runtime::getRuntimeFunc<mkRTKey(
537545
CUFDataTransferGlobalDescDesc)>(loc, builder)
538546
: fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
539547
loc, builder);
540-
541-
auto fTy = func.getFunctionType();
542-
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
543-
mlir::Value sourceLine =
544-
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
545548
mlir::Value dst = op.getDst();
546549
mlir::Value src = op.getSrc();
547-
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
548-
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
549-
builder.create<fir::CallOp>(loc, func, args);
550-
rewriter.eraseOp(op);
551-
} else if (mlir::isa<fir::BaseBoxType>(dstTy) && fir::isa_trivial(srcTy)) {
552-
// Scalar to descriptor transfer.
553-
mlir::Value val = op.getSrc();
554-
if (op.getSrc().getDefiningOp() &&
555-
mlir::isa<mlir::arith::ConstantOp>(op.getSrc().getDefiningOp())) {
556-
mlir::Value alloc = builder.createTemporary(loc, srcTy);
557-
builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
558-
val = alloc;
550+
551+
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
552+
// If src is not a descriptor, create one.
553+
mlir::Value addr;
554+
if (fir::isa_trivial(srcTy) &&
555+
mlir::matchPattern(op.getSrc().getDefiningOp(),
556+
mlir::m_Constant())) {
557+
// Put constant in memory if it is not.
558+
mlir::Value alloc = builder.createTemporary(loc, srcTy);
559+
builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
560+
addr = alloc;
561+
} else {
562+
addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
563+
}
564+
mlir::Type boxTy = fir::BoxType::get(srcTy);
565+
llvm::SmallVector<mlir::Value> lenParams;
566+
mlir::Value box =
567+
builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
568+
/*slice=*/nullptr, lenParams,
569+
/*tdesc=*/nullptr);
570+
mlir::Value memBox = builder.createTemporary(loc, box.getType());
571+
builder.create<fir::StoreOp>(loc, box, memBox);
572+
src = memBox;
559573
}
560574

561-
mlir::func::FuncOp func =
562-
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemsetDescriptor)>(loc,
563-
builder);
564575
auto fTy = func.getFunctionType();
565576
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
566577
mlir::Value sourceLine =
567-
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
578+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
568579
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
569-
builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
580+
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
570581
builder.create<fir::CallOp>(loc, func, args);
571582
rewriter.eraseOp(op);
572583
} else {
573584
// Type used to compute the width.
574585
mlir::Type computeType = dstTy;
575586
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
576-
bool dstIsDesc = false;
577587
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
578-
dstIsDesc = true;
579588
computeType = srcTy;
580589
seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
581590
}
@@ -606,11 +615,8 @@ struct CUFDataTransferOpConversion
606615
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
607616

608617
mlir::func::FuncOp func =
609-
dstIsDesc
610-
? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
611-
loc, builder)
612-
: fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
613-
loc, builder);
618+
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
619+
loc, builder);
614620
auto fTy = func.getFunctionType();
615621
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
616622
mlir::Value sourceLine =

flang/runtime/CUDA/memory.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,6 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
9696
CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
9797
}
9898

99-
void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
100-
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
101-
Terminator terminator{sourceFile, sourceLine};
102-
terminator.Crash(
103-
"not yet implemented: CUDA data transfer from a pointer to a descriptor");
104-
}
105-
10699
void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
107100
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
108101
Terminator terminator{sourceFile, sourceLine};

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@ func.func @_QPsub2() {
2929
}
3030

3131
// CHECK-LABEL: func.func @_QPsub2()
32+
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
3233
// CHECK: %[[TEMP:.*]] = fir.alloca i32
3334
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
3435
// CHECK: %[[C2:.*]] = arith.constant 2 : i32
3536
// CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
37+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[TEMP]] : (!fir.ref<i32>) -> !fir.box<i32>
38+
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
3639
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
37-
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
38-
// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
40+
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
41+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
3942

4043
func.func @_QPsub3() {
4144
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -48,12 +51,15 @@ func.func @_QPsub3() {
4851
}
4952

5053
// CHECK-LABEL: func.func @_QPsub3()
54+
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
5155
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
5256
// CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
57+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[V]]#0 : (!fir.ref<i32>) -> !fir.box<i32>
58+
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
5359
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
54-
// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
55-
// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
56-
60+
// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
61+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
62+
5763
func.func @_QPsub4() {
5864
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
5965
%4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
@@ -67,15 +73,14 @@ func.func @_QPsub4() {
6773
return
6874
}
6975
// CHECK-LABEL: func.func @_QPsub4()
76+
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
7077
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
71-
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
72-
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
73-
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
74-
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
78+
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
79+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
80+
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
7581
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
76-
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
77-
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
78-
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
82+
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
83+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
7984
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
8085
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
8186
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -110,16 +115,15 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
110115
}
111116

112117
// CHECK-LABEL: func.func @_QPsub5
118+
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
113119
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
114120
// CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
115121
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[SHAPE]]) {uniq_name = "_QFsub5Eahost"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
116-
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
117-
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
118-
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
122+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
123+
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
119124
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
120-
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
121-
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
122-
// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
125+
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
126+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
123127
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
124128
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
125129
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -248,5 +252,35 @@ func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
248252
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
249253
// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
250254

255+
fir.global @_QMmod2Eadev {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
256+
%c0 = arith.constant 0 : index
257+
%0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
258+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
259+
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
260+
fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
261+
}
262+
func.func @_QPdesc_global_ptr() {
263+
%c10 = arith.constant 10 : index
264+
%0 = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
265+
%1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
266+
%2 = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
267+
%3 = fir.shape %c10 : (index) -> !fir.shape<1>
268+
%4 = fir.declare %2(%3) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
269+
cuf.data_transfer %4 to %1 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
270+
return
271+
}
272+
273+
// CHECK-LABEL: func.func @_QPdesc_global_ptr()
274+
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
275+
// CHECK: %[[ADDR_ADEV:.*]] = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
276+
// CHECK: %[[DECL_ADEV:.*]] = fir.declare %[[ADDR_ADEV]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
277+
// CHECK: %[[AHOST:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
278+
// CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
279+
// CHECK: %[[DECL_AHOST:.*]] = fir.declare %[[AHOST]](%[[SHAPE]]) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
280+
// CHECK: %[[EMBOX:.*]] = fir.embox %[[DECL_AHOST]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
281+
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
282+
// CHECK: %[[ADEV_BOXNONE:.*]] = fir.convert %[[DECL_ADEV]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
283+
// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
284+
// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
251285

252286
} // end of module

0 commit comments

Comments
 (0)