Skip to content

Commit 1fc3ce1

Browse files
authored
[flang][cuda] Enable data transfer for descriptors (#92804)
Remove the TODO when data transfer is done with descriptor variables.
1 parent 3c3e71d commit 1fc3ce1

File tree

4 files changed

+58
-19
lines changed

4 files changed

+58
-19
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,21 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
152152
a = adev ! transfer device to host
153153
bdev = adev ! transfer device to device
154154
```
155+
156+
When the data transfer is done on data hold by descriptors, the LHS data
157+
hold by the descriptor are updated. When required, the LHS decriptor is also
158+
updated.
155159
}];
156160

157-
let arguments = (ins Arg<AnyReferenceLike, "", [MemRead]>:$src,
158-
Arg<AnyReferenceLike, "", [MemWrite]>:$dst,
161+
let arguments = (ins Arg<AnyRefOrBoxType, "", [MemRead]>:$src,
162+
Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
159163
cuf_DataTransferKindAttr:$transfer_kind);
160164

161165
let assemblyFormat = [{
162166
$src `to` $dst attr-dict `:` type(operands)
163167
}];
168+
169+
let hasVerifier = 1;
164170
}
165171

166172
def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,

flang/lib/Lower/Bridge.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3782,8 +3782,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
37823782
hlfir::Entity &lhs, hlfir::Entity &rhs) {
37833783
bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
37843784
bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
3785-
if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
3786-
TODO(loc, "CUDA data transfler with descriptors");
3785+
3786+
auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
3787+
if (auto loadOp =
3788+
mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
3789+
return loadOp.getMemref();
3790+
return val;
3791+
};
3792+
3793+
mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
3794+
mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
37873795

37883796
// device = host
37893797
if (lhsIsDevice && !rhsIsDevice) {
@@ -3792,11 +3800,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
37923800
if (!rhs.isVariable()) {
37933801
auto associate = hlfir::genAssociateExpr(
37943802
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
3795-
builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhs,
3803+
builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
37963804
transferKindAttr);
37973805
builder.create<hlfir::EndAssociateOp>(loc, associate);
37983806
} else {
3799-
builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
3807+
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
3808+
transferKindAttr);
38003809
}
38013810
return;
38023811
}
@@ -3805,26 +3814,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
38053814
if (!lhsIsDevice && rhsIsDevice) {
38063815
auto transferKindAttr = cuf::DataTransferKindAttr::get(
38073816
builder.getContext(), cuf::DataTransferKind::DeviceHost);
3808-
if (!rhs.isVariable()) {
3809-
// evaluateRhs loads scalar. Look for the memory reference to be used in
3810-
// the transfer.
3811-
if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
3812-
auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
3813-
builder.create<cuf::DataTransferOp>(loc, loadOp.getMemref(), lhs,
3814-
transferKindAttr);
3815-
return;
3816-
}
3817-
} else {
3818-
builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
3819-
}
3817+
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
3818+
transferKindAttr);
38203819
return;
38213820
}
38223821

3822+
// device = device
38233823
if (lhsIsDevice && rhsIsDevice) {
38243824
assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
38253825
auto transferKindAttr = cuf::DataTransferKindAttr::get(
38263826
builder.getContext(), cuf::DataTransferKind::DeviceDevice);
3827-
builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
3827+
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
3828+
transferKindAttr);
38283829
return;
38293830
}
38303831
llvm_unreachable("Unhandled CUDA data transfer");

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ mlir::LogicalResult cuf::AllocateOp::verify() {
8989
return mlir::success();
9090
}
9191

92+
//===----------------------------------------------------------------------===//
93+
// DataTransferOp
94+
//===----------------------------------------------------------------------===//
95+
96+
mlir::LogicalResult cuf::DataTransferOp::verify() {
97+
mlir::Type srcTy = getSrc().getType();
98+
mlir::Type dstTy = getDst().getType();
99+
if (fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy) ||
100+
fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy))
101+
return mlir::success();
102+
return emitOpError("expect src and dst to be both references or descriptors");
103+
}
104+
92105
//===----------------------------------------------------------------------===//
93106
// DeallocateOp
94107
//===----------------------------------------------------------------------===//

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,22 @@ end subroutine
159159

160160
! CHECK-LABEL: func.func @_QPsub6
161161
! CHECK: cuf.data_transfer
162+
163+
subroutine sub7(a, b, c)
164+
integer, device, allocatable :: a(:), c(:)
165+
integer, allocatable :: b(:)
166+
b = a
167+
168+
a = b
169+
170+
c = a
171+
end subroutine
172+
173+
! CHECK-LABEL: func.func @_QPsub7(
174+
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}) {
175+
! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
176+
! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
177+
! CHECK: %[[C:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ec"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
178+
! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
179+
! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
180+
! CHECK: cuf.data_transfer %[[A]]#0 to %[[C]]#0 {transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>

0 commit comments

Comments
 (0)