Skip to content

Commit a3201ce

Browse files
authored
[flang][cuda] Add option to disable warp function in semantic (#143640)
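The CUDA warp match functions (match_all_sync and match_any_sync) are not available on some lower compute capabilities. Add a language feature, CudaWarpMatchFunction, so that when it is disabled the semantic check reports an error for calls to these functions in device code.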
1 parent 3ece9b0 commit a3201ce

File tree

4 files changed

+101
-44
lines changed

4 files changed

+101
-44
lines changed

flang/include/flang/Support/Fortran-features.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
5555
SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
5656
IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
5757
ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
58-
InaccessibleDeferredOverride)
58+
InaccessibleDeferredOverride, CudaWarpMatchFunction)
5959

6060
// Portability and suspicious usage warnings
6161
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,

flang/lib/Semantics/check-cuda.cpp

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "flang/Semantics/expression.h"
1818
#include "flang/Semantics/symbol.h"
1919
#include "flang/Semantics/tools.h"
20+
#include "llvm/ADT/StringSet.h"
2021

2122
// Once labeled DO constructs have been canonicalized and their parse subtrees
2223
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -61,14 +62,19 @@ bool CanonicalizeCUDA(parser::Program &program) {
6162

6263
using MaybeMsg = std::optional<parser::MessageFormattedText>;
6364

65+
static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
66+
"match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
67+
"match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
68+
"match_any_syncjd"};
69+
6470
// Traverses an evaluate::Expr<> in search of unsupported operations
6571
// on the device.
6672

6773
struct DeviceExprChecker
6874
: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
6975
using Result = MaybeMsg;
7076
using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
71-
DeviceExprChecker() : Base(*this) {}
77+
explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
7278
using Base::operator();
7379
Result operator()(const evaluate::ProcedureDesignator &x) const {
7480
if (const Symbol * sym{x.GetInterfaceSymbol()}) {
@@ -78,10 +84,17 @@ struct DeviceExprChecker
7884
if (auto attrs{subp->cudaSubprogramAttrs()}) {
7985
if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
8086
*attrs == common::CUDASubprogramAttrs::Device) {
87+
if (warpFunctions_.contains(sym->name().ToString()) &&
88+
!context_.languageFeatures().IsEnabled(
89+
Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
90+
return parser::MessageFormattedText(
91+
"warp match function disabled"_err_en_US);
92+
}
8193
return {};
8294
}
8395
}
8496
}
97+
8598
const Symbol &ultimate{sym->GetUltimate()};
8699
const Scope &scope{ultimate.owner()};
87100
const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
@@ -94,9 +107,12 @@ struct DeviceExprChecker
94107
// TODO(CUDA): Check for unsupported intrinsics here
95108
return {};
96109
}
110+
97111
return parser::MessageFormattedText(
98112
"'%s' may not be called in device code"_err_en_US, x.GetName());
99113
}
114+
115+
SemanticsContext &context_;
100116
};
101117

102118
struct FindHostArray
@@ -133,9 +149,10 @@ struct FindHostArray
133149
}
134150
};
135151

136-
template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
152+
template <typename A>
153+
static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
137154
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
138-
return DeviceExprChecker{}(expr->typedExpr);
155+
return DeviceExprChecker{context}(expr->typedExpr);
139156
}
140157
return {};
141158
}
@@ -144,104 +161,124 @@ template <typename A>
144161
static void CheckUnwrappedExpr(
145162
SemanticsContext &context, SourceName at, const A &x) {
146163
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
147-
if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
164+
if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
148165
context.Say(at, std::move(*msg));
149166
}
150167
}
151168
}
152169

153170
template <bool CUF_KERNEL> struct ActionStmtChecker {
154-
template <typename A> static MaybeMsg WhyNotOk(const A &x) {
171+
template <typename A>
172+
static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
155173
if constexpr (ConstraintTrait<A>) {
156-
return WhyNotOk(x.thing);
174+
return WhyNotOk(context, x.thing);
157175
} else if constexpr (WrapperTrait<A>) {
158-
return WhyNotOk(x.v);
176+
return WhyNotOk(context, x.v);
159177
} else if constexpr (UnionTrait<A>) {
160-
return WhyNotOk(x.u);
178+
return WhyNotOk(context, x.u);
161179
} else if constexpr (TupleTrait<A>) {
162-
return WhyNotOk(x.t);
180+
return WhyNotOk(context, x.t);
163181
} else {
164182
return parser::MessageFormattedText{
165183
"Statement may not appear in device code"_err_en_US};
166184
}
167185
}
168186
template <typename A>
169-
static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
170-
return WhyNotOk(x.value());
187+
static MaybeMsg WhyNotOk(
188+
SemanticsContext &context, const common::Indirection<A> &x) {
189+
return WhyNotOk(context, x.value());
171190
}
172191
template <typename... As>
173-
static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
174-
return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
192+
static MaybeMsg WhyNotOk(
193+
SemanticsContext &context, const std::variant<As...> &x) {
194+
return common::visit(
195+
[&context](const auto &x) { return WhyNotOk(context, x); }, x);
175196
}
176197
template <std::size_t J = 0, typename... As>
177-
static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
198+
static MaybeMsg WhyNotOk(
199+
SemanticsContext &context, const std::tuple<As...> &x) {
178200
if constexpr (J == sizeof...(As)) {
179201
return {};
180-
} else if (auto msg{WhyNotOk(std::get<J>(x))}) {
202+
} else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
181203
return msg;
182204
} else {
183-
return WhyNotOk<(J + 1)>(x);
205+
return WhyNotOk<(J + 1)>(context, x);
184206
}
185207
}
186-
template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
208+
template <typename A>
209+
static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
187210
for (const auto &y : x) {
188-
if (MaybeMsg result{WhyNotOk(y)}) {
211+
if (MaybeMsg result{WhyNotOk(context, y)}) {
189212
return result;
190213
}
191214
}
192215
return {};
193216
}
194-
template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
217+
template <typename A>
218+
static MaybeMsg WhyNotOk(
219+
SemanticsContext &context, const std::optional<A> &x) {
195220
if (x) {
196-
return WhyNotOk(*x);
221+
return WhyNotOk(context, *x);
197222
} else {
198223
return {};
199224
}
200225
}
201226
template <typename A>
202-
static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
203-
return WhyNotOk(x.statement);
227+
static MaybeMsg WhyNotOk(
228+
SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
229+
return WhyNotOk(context, x.statement);
204230
}
205231
template <typename A>
206-
static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
207-
return WhyNotOk(x.statement);
232+
static MaybeMsg WhyNotOk(
233+
SemanticsContext &context, const parser::Statement<A> &x) {
234+
return WhyNotOk(context, x.statement);
208235
}
209-
static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
236+
static MaybeMsg WhyNotOk(
237+
SemanticsContext &context, const parser::AllocateStmt &) {
210238
return {}; // AllocateObjects are checked elsewhere
211239
}
212-
static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
240+
static MaybeMsg WhyNotOk(
241+
SemanticsContext &context, const parser::AllocateCoarraySpec &) {
213242
return parser::MessageFormattedText(
214243
"A coarray may not be allocated on the device"_err_en_US);
215244
}
216-
static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
245+
static MaybeMsg WhyNotOk(
246+
SemanticsContext &context, const parser::DeallocateStmt &) {
217247
return {}; // AllocateObjects are checked elsewhere
218248
}
219-
static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
220-
return DeviceExprChecker{}(x.typedAssignment);
249+
static MaybeMsg WhyNotOk(
250+
SemanticsContext &context, const parser::AssignmentStmt &x) {
251+
return DeviceExprChecker{context}(x.typedAssignment);
221252
}
222-
static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
223-
return DeviceExprChecker{}(x.typedCall);
253+
static MaybeMsg WhyNotOk(
254+
SemanticsContext &context, const parser::CallStmt &x) {
255+
return DeviceExprChecker{context}(x.typedCall);
256+
}
257+
static MaybeMsg WhyNotOk(
258+
SemanticsContext &context, const parser::ContinueStmt &) {
259+
return {};
224260
}
225-
static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
226-
static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
227-
if (auto result{
228-
CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
261+
static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
262+
if (auto result{CheckUnwrappedExpr(
263+
context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
229264
return result;
230265
}
231-
return WhyNotOk(
266+
return WhyNotOk(context,
232267
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
233268
.statement);
234269
}
235-
static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
270+
static MaybeMsg WhyNotOk(
271+
SemanticsContext &context, const parser::NullifyStmt &x) {
236272
for (const auto &y : x.v) {
237-
if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
273+
if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
238274
return result;
239275
}
240276
}
241277
return {};
242278
}
243-
static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
244-
return DeviceExprChecker{}(x.typedAssignment);
279+
static MaybeMsg WhyNotOk(
280+
SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
281+
return DeviceExprChecker{context}(x.typedAssignment);
245282
}
246283
};
247284

@@ -435,12 +472,14 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
435472
ErrorIfHostSymbol(assign->lhs, source);
436473
ErrorIfHostSymbol(assign->rhs, source);
437474
}
438-
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
475+
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
476+
context_, x)}) {
439477
context_.Say(source, std::move(*msg));
440478
}
441479
},
442480
[&](const auto &x) {
443-
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
481+
if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
482+
context_, x)}) {
444483
context_.Say(source, std::move(*msg));
445484
}
446485
},
@@ -504,7 +543,7 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
504543
Check(DEREF(parser::Unwrap<parser::Expr>(x)));
505544
}
506545
void Check(const parser::Expr &expr) {
507-
if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
546+
if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
508547
context_.Say(expr.source, std::move(*msg));
509548
}
510549
}

flang/test/Semantics/cuf22.cuf

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s
2+
3+
attributes(device) subroutine testMatch()
4+
integer :: a, ipred, mask, v32
5+
a = match_all_sync(mask, v32, ipred)
6+
end subroutine
7+
8+
! CHECK: warp match function disabled

flang/tools/bbc/bbc.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
223223
llvm::cl::desc("enable CUDA Fortran"),
224224
llvm::cl::init(false));
225225

226+
static llvm::cl::opt<bool>
227+
disableCUDAWarpFunction("fcuda-disable-warp-function",
228+
llvm::cl::desc("Disable CUDA Warp Function"),
229+
llvm::cl::init(false));
230+
226231
static llvm::cl::opt<std::string>
227232
enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
228233
llvm::cl::init(""));
@@ -600,6 +605,11 @@ int main(int argc, char **argv) {
600605
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
601606
}
602607

608+
if (disableCUDAWarpFunction) {
609+
options.features.Enable(
610+
Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
611+
}
612+
603613
if (enableGPUMode == "managed") {
604614
options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
605615
} else if (enableGPUMode == "unified") {

0 commit comments

Comments
 (0)