Skip to content

Commit 5cfd5d1

Browse files
authored
[flang][cuda] Do not generate data transfer within cuf kernel (#89973)
CUDA data transfers with intrinsic assignment are not meant to be generated within a cuf kernel. This patch fixes this issue. @ImanHosseini
1 parent 60bbe57 commit 5cfd5d1

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3810,12 +3810,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
38103810
mlir::Location loc = getCurrentLocation();
38113811
fir::FirOpBuilder &builder = getFirOpBuilder();
38123812

3813+
bool isInDeviceContext =
3814+
builder.getRegion().getParentOfType<fir::CUDAKernelOp>();
38133815
bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs(assign.lhs) ||
38143816
Fortran::evaluate::HasCUDAAttrs(assign.rhs);
38153817
bool hasCUDAImplicitTransfer =
38163818
Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
38173819
llvm::SmallVector<mlir::Value> implicitTemps;
3818-
if (hasCUDAImplicitTransfer)
3820+
if (hasCUDAImplicitTransfer && !isInDeviceContext)
38193821
implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign);
38203822

38213823
// Gather some information about the assignment that will impact how it is
@@ -3874,13 +3876,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
38743876
Fortran::lower::StatementContext localStmtCtx;
38753877
hlfir::Entity rhs = evaluateRhs(localStmtCtx);
38763878
hlfir::Entity lhs = evaluateLhs(localStmtCtx);
3877-
if (isCUDATransfer && !hasCUDAImplicitTransfer)
3879+
if (isCUDATransfer && !hasCUDAImplicitTransfer && !isInDeviceContext)
38783880
genCUDADataTransfer(builder, loc, assign, lhs, rhs);
38793881
else
38803882
builder.create<hlfir::AssignOp>(loc, rhs, lhs,
38813883
isWholeAllocatableAssignment,
38823884
keepLhsLengthInAllocatableAssignment);
3883-
if (hasCUDAImplicitTransfer) {
3885+
if (hasCUDAImplicitTransfer && !isInDeviceContext) {
38843886
localSymbols.popScope();
38853887
for (mlir::Value temp : implicitTemps)
38863888
builder.create<fir::FreeMemOp>(loc, temp);

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,25 @@ end
119119
! CHECK: %[[T:.*]]:2 = hlfir.declare %7 {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub3Et"} : (!fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>) -> (!fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>, !fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>)
120120
! CHECK: %[[TMP_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = ".tmp"} : (!fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>) -> (!fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>, !fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>)
121121
! CHECK: fir.cuda_data_transfer %[[T]]#1 to %[[TMP_DECL]]#0 {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>, !fir.ref<!fir.type<_QMmod1Tt1{i:i32}>>
122+
123+
124+
! Check that fir.cuda_data_transfer ops are not generated within a cuf kernel
125+
subroutine sub4()
126+
integer, parameter :: n = 10
127+
real, device :: adev(n)
128+
real :: ahost(n)
129+
real :: b
130+
integer :: i
131+
132+
adev = ahost
133+
!$cuf kernel do <<<*,*>>>
134+
do i = 1, n
135+
adev(i) = adev(i) + b
136+
enddo
137+
end subroutine
138+
139+
! CHECK-LABEL: func.func @_QPsub4()
140+
! CHECK: fir.cuda_data_transfer
141+
! CHECK: fir.cuda_kernel<<<*, *>>>
142+
! CHECK-NOT: fir.cuda_data_transfer
143+
! CHECK: hlfir.assign

0 commit comments

Comments
 (0)