Skip to content

Commit 0aa982f

Browse files
authored
[flang][cuda] Add restriction on implicit data transfer (#87720)
Section 3.4.2 gives some examples of illegal data transfers using expressions. One of them is when multiple device objects are part of an expression on the rhs. The current implementation allows a single device object in such a case. This patch adds a similar restriction.
1 parent 60fc4ac commit 0aa982f

File tree

4 files changed

+41
-5
lines changed

4 files changed

+41
-5
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,18 +1227,24 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
12271227
const std::optional<ActualArgument> &, const std::string &procName,
12281228
const std::string &argName);
12291229

1230-
/// Check if any of the symbols part of the expression has a cuda data
1231-
/// attribute.
1232-
inline bool HasCUDAAttrs(const Expr<SomeType> &expr) {
1230+
// Get the number of distinct symbols with a CUDA data attribute in the
// expression. Each symbol yielded by CollectSymbols(expr) whose ultimate
// symbol has ObjectEntityDetails with cudaDataAttr() set is inserted into a
// symbol set, so a symbol that appears several times in the expression is
// counted once. The size of that set is returned.
1231+
template <typename A> inline int GetNbOfCUDASymbols(const A &expr) {
1232+
semantics::UnorderedSymbolSet symbols;
12331233
for (const Symbol &sym : CollectSymbols(expr)) {
12341234
if (const auto *details =
12351235
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
12361236
if (details->cudaDataAttr()) {
1237-
return true;
1237+
symbols.insert(sym);
12381238
}
12391239
}
12401240
}
1241-
return false;
1241+
return symbols.size();
1242+
}
1243+
1244+
// Check whether at least one symbol that is part of the expression has a CUDA
1245+
// data attribute, i.e. whether GetNbOfCUDASymbols(expr) is greater than zero.
1246+
template <typename A> inline bool HasCUDAAttrs(const A &expr) {
1247+
return GetNbOfCUDASymbols(expr) > 0;
12421248
}
12431249

12441250
/// Check if the expression is a mix of host and device variables that require

flang/lib/Semantics/check-cuda.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
#include "check-cuda.h"
1010
#include "flang/Common/template.h"
1111
#include "flang/Evaluate/fold.h"
12+
#include "flang/Evaluate/tools.h"
1213
#include "flang/Evaluate/traverse.h"
1314
#include "flang/Parser/parse-tree-visitor.h"
1415
#include "flang/Parser/parse-tree.h"
1516
#include "flang/Parser/tools.h"
1617
#include "flang/Semantics/expression.h"
1718
#include "flang/Semantics/symbol.h"
19+
#include "flang/Semantics/tools.h"
1820

1921
// Once labeled DO constructs have been canonicalized and their parse subtrees
2022
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -413,4 +415,18 @@ void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
413415
}
414416
}
415417

418+
// Check an assignment statement for an illegal implicit device-to-host data
// transfer: when the left-hand side references no CUDA object and the
// right-hand side references more than one distinct CUDA object (for example
// `ahost = adev + bdev`), an error is reported at the source location of the
// assigned variable.
void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
  const evaluate::Assignment *assign{semantics::GetAssignment(x)};
  if (!assign) {
    // GetAssignment hands back a pointer; do not dereference it when it is
    // null (presumably the case when no typed assignment was produced for
    // this statement, e.g. after an earlier error -- confirm against
    // GetAssignment's contract).
    return;
  }
  int nbLhs{evaluate::GetNbOfCUDASymbols(assign->lhs)};
  int nbRhs{evaluate::GetNbOfCUDASymbols(assign->rhs)};
  auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()};

  // device to host transfer with more than one device object on the rhs is not
  // legal.
  if (nbLhs == 0 && nbRhs > 1) {
    context_.Say(lhsLoc,
        "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
  }
}
431+
416432
} // namespace Fortran::semantics

flang/lib/Semantics/check-cuda.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ struct Program;
1717
class Messages;
1818
struct Name;
1919
class CharBlock;
20+
struct AssignmentStmt;
2021
struct ExecutionPartConstruct;
2122
struct ExecutableConstruct;
2223
struct ActionStmt;
@@ -38,6 +39,7 @@ class CUDAChecker : public virtual BaseChecker {
3839
void Enter(const parser::FunctionSubprogram &);
3940
void Enter(const parser::SeparateModuleSubprogram &);
4041
void Enter(const parser::CUFKernelDoConstruct &);
42+
void Enter(const parser::AssignmentStmt &);
4143

4244
private:
4345
SemanticsContext &context_;

flang/test/Semantics/cuf11.cuf

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
! RUN: %python %S/test_errors.py %s %flang_fc1
2+
3+
! Checks the restriction on implicit device-to-host transfers: an assignment
! to a host variable may not reference more than one distinct device object
! on its right-hand side.
subroutine sub1()
4+
real, device :: adev(10), bdev(10)
5+
real :: ahost(10)
6+
7+
!ERROR: More than one reference to a CUDA object on the right hand side of the assigment
8+
ahost = adev + bdev
9+
10+
! No error is expected here: adev appears twice but is a single distinct
! device object.
ahost = adev + adev
11+
12+
end subroutine

0 commit comments

Comments
 (0)