Skip to content

Commit ecaef01

Browse files
authored
[flang][cuda] Support corner case of data transfer (#132451)
The flang runtime will complain when the number of elements in the two descriptors involved in the data transfer are not matching. In some cases, we can still perform the data transfer to match the behavior of the reference compiler. When the RHS elements count is bigger than the LHS elements count and both descriptors are contiguous, we can perform the data transfer with the bare pointers and the number of bytes from the LHS. We don't really have unit tests set up for data transfer, this is why I didn't include one here.
1 parent acdb0c1 commit ecaef01

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

flang-rt/lib/cuda/memory.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "flang/Runtime/CUDA/memory.h"
1010
#include "flang-rt/runtime/assign-impl.h"
11+
#include "flang-rt/runtime/descriptor.h"
1112
#include "flang-rt/runtime/terminator.h"
1213
#include "flang/Runtime/CUDA/common.h"
1314
#include "flang/Runtime/CUDA/descriptor.h"
@@ -98,8 +99,21 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
9899
} else {
99100
terminator.Crash("host to host copy not supported");
100101
}
101-
Fortran::runtime::Assign(
102-
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
102+
if ((srcDesc->rank() > 0) && (dstDesc->Elements() < srcDesc->Elements())) {
103+
// Special case when rhs is bigger than lhs and both are contiguous arrays.
104+
// In this case we do a simple ptr to ptr transfer with the size of lhs.
105+
// This is be allowed in the reference compiler and it avoids error
106+
// triggered in the Assign runtime function used for the main case below.
107+
if (!srcDesc->IsContiguous() || !dstDesc->IsContiguous())
108+
terminator.Crash("Unsupported data transfer: mismatching element counts "
109+
"with non-contiguous arrays");
110+
RTNAME(CUFDataTransferPtrPtr)(dstDesc->raw().base_addr,
111+
srcDesc->raw().base_addr, dstDesc->Elements() * dstDesc->ElementBytes(),
112+
mode, sourceFile, sourceLine);
113+
} else {
114+
Fortran::runtime::Assign(
115+
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
116+
}
103117
}
104118

105119
void RTDECL(CUFDataTransferCstDesc)(Descriptor *dstDesc, Descriptor *srcDesc,

0 commit comments

Comments
 (0)