17
17
#include " flang/Semantics/expression.h"
18
18
#include " flang/Semantics/symbol.h"
19
19
#include " flang/Semantics/tools.h"
20
+ #include " llvm/ADT/StringSet.h"
20
21
21
22
// Once labeled DO constructs have been canonicalized and their parse subtrees
22
23
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -61,14 +62,19 @@ bool CanonicalizeCUDA(parser::Program &program) {
61
62
62
63
using MaybeMsg = std::optional<parser::MessageFormattedText>;
63
64
65
+ static const llvm::StringSet<> warpFunctions_ = {" match_all_syncjj" ,
66
+ " match_all_syncjx" , " match_all_syncjf" , " match_all_syncjd" ,
67
+ " match_any_syncjj" , " match_any_syncjx" , " match_any_syncjf" ,
68
+ " match_any_syncjd" };
69
+
64
70
// Traverses an evaluate::Expr<> in search of unsupported operations
65
71
// on the device.
66
72
67
73
struct DeviceExprChecker
68
74
: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
69
75
using Result = MaybeMsg;
70
76
using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
71
- DeviceExprChecker () : Base(*this ) {}
77
+ explicit DeviceExprChecker (SemanticsContext &c ) : Base(*this ), context_{c} {}
72
78
using Base::operator ();
73
79
Result operator ()(const evaluate::ProcedureDesignator &x) const {
74
80
if (const Symbol * sym{x.GetInterfaceSymbol ()}) {
@@ -78,10 +84,17 @@ struct DeviceExprChecker
78
84
if (auto attrs{subp->cudaSubprogramAttrs ()}) {
79
85
if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
80
86
*attrs == common::CUDASubprogramAttrs::Device) {
87
+ if (warpFunctions_.contains (sym->name ().ToString ()) &&
88
+ !context_.languageFeatures ().IsEnabled (
89
+ Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
90
+ return parser::MessageFormattedText (
91
+ " warp match function disabled" _err_en_US);
92
+ }
81
93
return {};
82
94
}
83
95
}
84
96
}
97
+
85
98
const Symbol &ultimate{sym->GetUltimate ()};
86
99
const Scope &scope{ultimate.owner ()};
87
100
const Symbol *mod{scope.IsModule () ? scope.symbol () : nullptr };
@@ -94,9 +107,12 @@ struct DeviceExprChecker
94
107
// TODO(CUDA): Check for unsupported intrinsics here
95
108
return {};
96
109
}
110
+
97
111
return parser::MessageFormattedText (
98
112
" '%s' may not be called in device code" _err_en_US, x.GetName ());
99
113
}
114
+
115
+ SemanticsContext &context_;
100
116
};
101
117
102
118
struct FindHostArray
@@ -133,9 +149,10 @@ struct FindHostArray
133
149
}
134
150
};
135
151
136
- template <typename A> static MaybeMsg CheckUnwrappedExpr (const A &x) {
152
+ template <typename A>
153
+ static MaybeMsg CheckUnwrappedExpr (SemanticsContext &context, const A &x) {
137
154
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
138
- return DeviceExprChecker{}(expr->typedExpr );
155
+ return DeviceExprChecker{context }(expr->typedExpr );
139
156
}
140
157
return {};
141
158
}
@@ -144,104 +161,124 @@ template <typename A>
144
161
static void CheckUnwrappedExpr (
145
162
SemanticsContext &context, SourceName at, const A &x) {
146
163
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
147
- if (auto msg{DeviceExprChecker{}(expr->typedExpr )}) {
164
+ if (auto msg{DeviceExprChecker{context }(expr->typedExpr )}) {
148
165
context.Say (at, std::move (*msg));
149
166
}
150
167
}
151
168
}
152
169
153
170
template <bool CUF_KERNEL> struct ActionStmtChecker {
154
- template <typename A> static MaybeMsg WhyNotOk (const A &x) {
171
+ template <typename A>
172
+ static MaybeMsg WhyNotOk (SemanticsContext &context, const A &x) {
155
173
if constexpr (ConstraintTrait<A>) {
156
- return WhyNotOk (x.thing );
174
+ return WhyNotOk (context, x.thing );
157
175
} else if constexpr (WrapperTrait<A>) {
158
- return WhyNotOk (x.v );
176
+ return WhyNotOk (context, x.v );
159
177
} else if constexpr (UnionTrait<A>) {
160
- return WhyNotOk (x.u );
178
+ return WhyNotOk (context, x.u );
161
179
} else if constexpr (TupleTrait<A>) {
162
- return WhyNotOk (x.t );
180
+ return WhyNotOk (context, x.t );
163
181
} else {
164
182
return parser::MessageFormattedText{
165
183
" Statement may not appear in device code" _err_en_US};
166
184
}
167
185
}
168
186
template <typename A>
169
- static MaybeMsg WhyNotOk (const common::Indirection<A> &x) {
170
- return WhyNotOk (x.value ());
187
+ static MaybeMsg WhyNotOk (
188
+ SemanticsContext &context, const common::Indirection<A> &x) {
189
+ return WhyNotOk (context, x.value ());
171
190
}
172
191
template <typename ... As>
173
- static MaybeMsg WhyNotOk (const std::variant<As...> &x) {
174
- return common::visit ([](const auto &x) { return WhyNotOk (x); }, x);
192
+ static MaybeMsg WhyNotOk (
193
+ SemanticsContext &context, const std::variant<As...> &x) {
194
+ return common::visit (
195
+ [&context](const auto &x) { return WhyNotOk (context, x); }, x);
175
196
}
176
197
template <std::size_t J = 0 , typename ... As>
177
- static MaybeMsg WhyNotOk (const std::tuple<As...> &x) {
198
+ static MaybeMsg WhyNotOk (
199
+ SemanticsContext &context, const std::tuple<As...> &x) {
178
200
if constexpr (J == sizeof ...(As)) {
179
201
return {};
180
- } else if (auto msg{WhyNotOk (std::get<J>(x))}) {
202
+ } else if (auto msg{WhyNotOk (context, std::get<J>(x))}) {
181
203
return msg;
182
204
} else {
183
- return WhyNotOk<(J + 1 )>(x);
205
+ return WhyNotOk<(J + 1 )>(context, x);
184
206
}
185
207
}
186
- template <typename A> static MaybeMsg WhyNotOk (const std::list<A> &x) {
208
+ template <typename A>
209
+ static MaybeMsg WhyNotOk (SemanticsContext &context, const std::list<A> &x) {
187
210
for (const auto &y : x) {
188
- if (MaybeMsg result{WhyNotOk (y)}) {
211
+ if (MaybeMsg result{WhyNotOk (context, y)}) {
189
212
return result;
190
213
}
191
214
}
192
215
return {};
193
216
}
194
- template <typename A> static MaybeMsg WhyNotOk (const std::optional<A> &x) {
217
+ template <typename A>
218
+ static MaybeMsg WhyNotOk (
219
+ SemanticsContext &context, const std::optional<A> &x) {
195
220
if (x) {
196
- return WhyNotOk (*x);
221
+ return WhyNotOk (context, *x);
197
222
} else {
198
223
return {};
199
224
}
200
225
}
201
226
template <typename A>
202
- static MaybeMsg WhyNotOk (const parser::UnlabeledStatement<A> &x) {
203
- return WhyNotOk (x.statement );
227
+ static MaybeMsg WhyNotOk (
228
+ SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
229
+ return WhyNotOk (context, x.statement );
204
230
}
205
231
template <typename A>
206
- static MaybeMsg WhyNotOk (const parser::Statement<A> &x) {
207
- return WhyNotOk (x.statement );
232
+ static MaybeMsg WhyNotOk (
233
+ SemanticsContext &context, const parser::Statement<A> &x) {
234
+ return WhyNotOk (context, x.statement );
208
235
}
209
- static MaybeMsg WhyNotOk (const parser::AllocateStmt &) {
236
+ static MaybeMsg WhyNotOk (
237
+ SemanticsContext &context, const parser::AllocateStmt &) {
210
238
return {}; // AllocateObjects are checked elsewhere
211
239
}
212
- static MaybeMsg WhyNotOk (const parser::AllocateCoarraySpec &) {
240
+ static MaybeMsg WhyNotOk (
241
+ SemanticsContext &context, const parser::AllocateCoarraySpec &) {
213
242
return parser::MessageFormattedText (
214
243
" A coarray may not be allocated on the device" _err_en_US);
215
244
}
216
- static MaybeMsg WhyNotOk (const parser::DeallocateStmt &) {
245
+ static MaybeMsg WhyNotOk (
246
+ SemanticsContext &context, const parser::DeallocateStmt &) {
217
247
return {}; // AllocateObjects are checked elsewhere
218
248
}
219
- static MaybeMsg WhyNotOk (const parser::AssignmentStmt &x) {
220
- return DeviceExprChecker{}(x.typedAssignment );
249
+ static MaybeMsg WhyNotOk (
250
+ SemanticsContext &context, const parser::AssignmentStmt &x) {
251
+ return DeviceExprChecker{context}(x.typedAssignment );
221
252
}
222
- static MaybeMsg WhyNotOk (const parser::CallStmt &x) {
223
- return DeviceExprChecker{}(x.typedCall );
253
+ static MaybeMsg WhyNotOk (
254
+ SemanticsContext &context, const parser::CallStmt &x) {
255
+ return DeviceExprChecker{context}(x.typedCall );
256
+ }
257
+ static MaybeMsg WhyNotOk (
258
+ SemanticsContext &context, const parser::ContinueStmt &) {
259
+ return {};
224
260
}
225
- static MaybeMsg WhyNotOk (const parser::ContinueStmt &) { return {}; }
226
- static MaybeMsg WhyNotOk (const parser::IfStmt &x) {
227
- if (auto result{
228
- CheckUnwrappedExpr (std::get<parser::ScalarLogicalExpr>(x.t ))}) {
261
+ static MaybeMsg WhyNotOk (SemanticsContext &context, const parser::IfStmt &x) {
262
+ if (auto result{CheckUnwrappedExpr (
263
+ context, std::get<parser::ScalarLogicalExpr>(x.t ))}) {
229
264
return result;
230
265
}
231
- return WhyNotOk (
266
+ return WhyNotOk (context,
232
267
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t )
233
268
.statement );
234
269
}
235
- static MaybeMsg WhyNotOk (const parser::NullifyStmt &x) {
270
+ static MaybeMsg WhyNotOk (
271
+ SemanticsContext &context, const parser::NullifyStmt &x) {
236
272
for (const auto &y : x.v ) {
237
- if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr )}) {
273
+ if (MaybeMsg result{DeviceExprChecker{context }(y.typedExpr )}) {
238
274
return result;
239
275
}
240
276
}
241
277
return {};
242
278
}
243
- static MaybeMsg WhyNotOk (const parser::PointerAssignmentStmt &x) {
244
- return DeviceExprChecker{}(x.typedAssignment );
279
+ static MaybeMsg WhyNotOk (
280
+ SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
281
+ return DeviceExprChecker{context}(x.typedAssignment );
245
282
}
246
283
};
247
284
@@ -435,12 +472,14 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
435
472
ErrorIfHostSymbol (assign->lhs , source);
436
473
ErrorIfHostSymbol (assign->rhs , source);
437
474
}
438
- if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk (x)}) {
475
+ if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk (
476
+ context_, x)}) {
439
477
context_.Say (source, std::move (*msg));
440
478
}
441
479
},
442
480
[&](const auto &x) {
443
- if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk (x)}) {
481
+ if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk (
482
+ context_, x)}) {
444
483
context_.Say (source, std::move (*msg));
445
484
}
446
485
},
@@ -504,7 +543,7 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
504
543
Check (DEREF (parser::Unwrap<parser::Expr>(x)));
505
544
}
506
545
void Check (const parser::Expr &expr) {
507
- if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr )}) {
546
+ if (MaybeMsg msg{DeviceExprChecker{context_ }(expr.typedExpr )}) {
508
547
context_.Say (expr.source , std::move (*msg));
509
548
}
510
549
}
0 commit comments