-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][cuda] Add option to disable warp function in semantic #143640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-semantics Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThese functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these. Full diff: https://github.com/llvm/llvm-project/pull/143640.diff 3 Files Affected:
diff --git a/flang/include/flang/Support/Fortran-features.h b/flang/include/flang/Support/Fortran-features.h
index 3f6d825e2b66c..ea0845b7d605f 100644
--- a/flang/include/flang/Support/Fortran-features.h
+++ b/flang/include/flang/Support/Fortran-features.h
@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
- InaccessibleDeferredOverride)
+ InaccessibleDeferredOverride, CudaWarpMatchFunction)
// Portability and suspicious usage warnings
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index c024640af1220..8decfb0149829 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -17,6 +17,7 @@
#include "flang/Semantics/expression.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSet.h"
// Once labeled DO constructs have been canonicalized and their parse subtrees
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -61,6 +62,11 @@ bool CanonicalizeCUDA(parser::Program &program) {
using MaybeMsg = std::optional<parser::MessageFormattedText>;
+static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
+ "match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
+ "match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
+ "match_any_syncjd"};
+
// Traverses an evaluate::Expr<> in search of unsupported operations
// on the device.
@@ -68,7 +74,7 @@ struct DeviceExprChecker
: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
using Result = MaybeMsg;
using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
- DeviceExprChecker() : Base(*this) {}
+ explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
using Base::operator();
Result operator()(const evaluate::ProcedureDesignator &x) const {
if (const Symbol * sym{x.GetInterfaceSymbol()}) {
@@ -78,10 +84,17 @@ struct DeviceExprChecker
if (auto attrs{subp->cudaSubprogramAttrs()}) {
if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
*attrs == common::CUDASubprogramAttrs::Device) {
+ if (warpFunctions_.contains(sym->name().ToString()) &&
+ !context_.languageFeatures().IsEnabled(
+ Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
+ return parser::MessageFormattedText(
+ "warp match function disabled"_err_en_US);
+ }
return {};
}
}
}
+
const Symbol &ultimate{sym->GetUltimate()};
const Scope &scope{ultimate.owner()};
const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
@@ -94,9 +107,12 @@ struct DeviceExprChecker
// TODO(CUDA): Check for unsupported intrinsics here
return {};
}
+
return parser::MessageFormattedText(
"'%s' may not be called in device code"_err_en_US, x.GetName());
}
+
+ SemanticsContext &context_;
};
struct FindHostArray
@@ -133,9 +149,10 @@ struct FindHostArray
}
};
-template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
+template <typename A>
+static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
- return DeviceExprChecker{}(expr->typedExpr);
+ return DeviceExprChecker{context}(expr->typedExpr);
}
return {};
}
@@ -144,104 +161,124 @@ template <typename A>
static void CheckUnwrappedExpr(
SemanticsContext &context, SourceName at, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
- if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
+ if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
context.Say(at, std::move(*msg));
}
}
}
template <bool CUF_KERNEL> struct ActionStmtChecker {
- template <typename A> static MaybeMsg WhyNotOk(const A &x) {
+ template <typename A>
+ static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
if constexpr (ConstraintTrait<A>) {
- return WhyNotOk(x.thing);
+ return WhyNotOk(context, x.thing);
} else if constexpr (WrapperTrait<A>) {
- return WhyNotOk(x.v);
+ return WhyNotOk(context, x.v);
} else if constexpr (UnionTrait<A>) {
- return WhyNotOk(x.u);
+ return WhyNotOk(context, x.u);
} else if constexpr (TupleTrait<A>) {
- return WhyNotOk(x.t);
+ return WhyNotOk(context, x.t);
} else {
return parser::MessageFormattedText{
"Statement may not appear in device code"_err_en_US};
}
}
template <typename A>
- static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
- return WhyNotOk(x.value());
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const common::Indirection<A> &x) {
+ return WhyNotOk(context, x.value());
}
template <typename... As>
- static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
- return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const std::variant<As...> &x) {
+ return common::visit(
+ [&context](const auto &x) { return WhyNotOk(context, x); }, x);
}
template <std::size_t J = 0, typename... As>
- static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const std::tuple<As...> &x) {
if constexpr (J == sizeof...(As)) {
return {};
- } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
+ } else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
return msg;
} else {
- return WhyNotOk<(J + 1)>(x);
+ return WhyNotOk<(J + 1)>(context, x);
}
}
- template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
+ template <typename A>
+ static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
for (const auto &y : x) {
- if (MaybeMsg result{WhyNotOk(y)}) {
+ if (MaybeMsg result{WhyNotOk(context, y)}) {
return result;
}
}
return {};
}
- template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
+ template <typename A>
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const std::optional<A> &x) {
if (x) {
- return WhyNotOk(*x);
+ return WhyNotOk(context, *x);
} else {
return {};
}
}
template <typename A>
- static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
- return WhyNotOk(x.statement);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
+ return WhyNotOk(context, x.statement);
}
template <typename A>
- static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
- return WhyNotOk(x.statement);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::Statement<A> &x) {
+ return WhyNotOk(context, x.statement);
}
- static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::AllocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
- static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::AllocateCoarraySpec &) {
return parser::MessageFormattedText(
"A coarray may not be allocated on the device"_err_en_US);
}
- static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::DeallocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
- static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
- return DeviceExprChecker{}(x.typedAssignment);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::AssignmentStmt &x) {
+ return DeviceExprChecker{context}(x.typedAssignment);
}
- static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
- return DeviceExprChecker{}(x.typedCall);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::CallStmt &x) {
+ return DeviceExprChecker{context}(x.typedCall);
+ }
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::ContinueStmt &) {
+ return {};
}
- static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
- static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
- if (auto result{
- CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
+ static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
+ if (auto result{CheckUnwrappedExpr(
+ context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
return result;
}
- return WhyNotOk(
+ return WhyNotOk(context,
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
.statement);
}
- static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::NullifyStmt &x) {
for (const auto &y : x.v) {
- if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
+ if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
return result;
}
}
return {};
}
- static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
- return DeviceExprChecker{}(x.typedAssignment);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
+ return DeviceExprChecker{context}(x.typedAssignment);
}
};
@@ -435,12 +472,14 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
ErrorIfHostSymbol(assign->lhs, source);
ErrorIfHostSymbol(assign->rhs, source);
}
- if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+ if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+ context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
[&](const auto &x) {
- if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+ if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+ context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
@@ -504,7 +543,7 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
Check(DEREF(parser::Unwrap<parser::Expr>(x)));
}
void Check(const parser::Expr &expr) {
- if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
+ if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
context_.Say(expr.source, std::move(*msg));
}
}
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index c544008a24d56..c80872108ac8f 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
llvm::cl::desc("enable CUDA Fortran"),
llvm::cl::init(false));
+static llvm::cl::opt<bool>
+ disableCUDAWarpFunction("fcuda-disable-warp-function",
+ llvm::cl::desc("Disable CUDA Warp Function"),
+ llvm::cl::init(false));
+
static llvm::cl::opt<std::string>
enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
llvm::cl::init(""));
@@ -600,6 +605,11 @@ int main(int argc, char **argv) {
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
}
+ if (disableCUDAWarpFunction) {
+ options.features.Enable(
+ Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
+ }
+
if (enableGPUMode == "managed") {
options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
} else if (enableGPUMode == "unified") {
|
Does it make sense to add a test? |
Yes. I had it ready. I forgot to git add it! |
@@ -0,0 +1,8 @@ | |||
! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does not bbc
do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bbc command will fail since I want to check the error. Adding the not makes the test pass because the return value of the command is checked.
…43640) These functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these.
…43640) These functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these.
…43640) These functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these.
These functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these.