Skip to content

Commit a309c07

Browse files
authored
[flang][cuda] Allow if stmt in device subroutine (#89347)
1 parent 7d8616e commit a309c07

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

flang/lib/Semantics/check-cuda.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,9 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
344344
[&](const common::Indirection<parser::BackspaceStmt> &x) {
345345
WarnOnIoStmt(source);
346346
},
347+
[&](const common::Indirection<parser::IfStmt> &x) {
348+
Check(x.value());
349+
},
347350
[&](const auto &x) {
348351
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
349352
context_.Say(source, std::move(*msg));
@@ -369,6 +372,13 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
369372
Check(std::get<parser::Block>(eb->t));
370373
}
371374
}
375+
void Check(const parser::IfStmt &is) {
376+
const auto &uS{
377+
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)};
378+
CheckUnwrappedExpr(
379+
context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t));
380+
Check(uS.statement, uS.source);
381+
}
372382
void Check(const parser::LoopControl::Bounds &bounds) {
373383
Check(bounds.lower);
374384
Check(bounds.upper);

flang/test/Semantics/cuf11.cuf

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ logical function compare_h(a,b)
3030
!ERROR: 'b' is not an object of derived type; it is implicitly typed
3131
compare_h = (a%h .eq. b%h)
3232
end
33+
34+
attributes(global) subroutine sub2()
35+
if (threadIdx%x == 1) print *, "I'm number one"
36+
end subroutine

0 commit comments

Comments
 (0)