Skip to content

[flang][cuda] Add option to disable warp function in semantic #143640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2025

Conversation

clementval
Copy link
Contributor

These functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these.

@clementval clementval requested a review from wangzpgi June 11, 2025 00:28
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:semantics labels Jun 11, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2025

@llvm/pr-subscribers-flang-semantics

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

These functions are not available in some lower compute capabilities. Add option in the language feature to enforce the semantic check on these.


Full diff: https://github.com/llvm/llvm-project/pull/143640.diff

3 Files Affected:

  • (modified) flang/include/flang/Support/Fortran-features.h (+1-1)
  • (modified) flang/lib/Semantics/check-cuda.cpp (+82-43)
  • (modified) flang/tools/bbc/bbc.cpp (+10)
diff --git a/flang/include/flang/Support/Fortran-features.h b/flang/include/flang/Support/Fortran-features.h
index 3f6d825e2b66c..ea0845b7d605f 100644
--- a/flang/include/flang/Support/Fortran-features.h
+++ b/flang/include/flang/Support/Fortran-features.h
@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
     SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
     IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
     ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
-    InaccessibleDeferredOverride)
+    InaccessibleDeferredOverride, CudaWarpMatchFunction)
 
 // Portability and suspicious usage warnings
 ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index c024640af1220..8decfb0149829 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -17,6 +17,7 @@
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSet.h"
 
 // Once labeled DO constructs have been canonicalized and their parse subtrees
 // transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -61,6 +62,11 @@ bool CanonicalizeCUDA(parser::Program &program) {
 
 using MaybeMsg = std::optional<parser::MessageFormattedText>;
 
+static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
+    "match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
+    "match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
+    "match_any_syncjd"};
+
 // Traverses an evaluate::Expr<> in search of unsupported operations
 // on the device.
 
@@ -68,7 +74,7 @@ struct DeviceExprChecker
     : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
   using Result = MaybeMsg;
   using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
-  DeviceExprChecker() : Base(*this) {}
+  explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
   using Base::operator();
   Result operator()(const evaluate::ProcedureDesignator &x) const {
     if (const Symbol * sym{x.GetInterfaceSymbol()}) {
@@ -78,10 +84,17 @@ struct DeviceExprChecker
         if (auto attrs{subp->cudaSubprogramAttrs()}) {
           if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
               *attrs == common::CUDASubprogramAttrs::Device) {
+            if (warpFunctions_.contains(sym->name().ToString()) &&
+                !context_.languageFeatures().IsEnabled(
+                    Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
+              return parser::MessageFormattedText(
+                  "warp match function disabled"_err_en_US);
+            }
             return {};
           }
         }
       }
+
       const Symbol &ultimate{sym->GetUltimate()};
       const Scope &scope{ultimate.owner()};
       const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
@@ -94,9 +107,12 @@ struct DeviceExprChecker
       // TODO(CUDA): Check for unsupported intrinsics here
       return {};
     }
+
     return parser::MessageFormattedText(
         "'%s' may not be called in device code"_err_en_US, x.GetName());
   }
+
+  SemanticsContext &context_;
 };
 
 struct FindHostArray
@@ -133,9 +149,10 @@ struct FindHostArray
   }
 };
 
-template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
+template <typename A>
+static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
-    return DeviceExprChecker{}(expr->typedExpr);
+    return DeviceExprChecker{context}(expr->typedExpr);
   }
   return {};
 }
@@ -144,104 +161,124 @@ template <typename A>
 static void CheckUnwrappedExpr(
     SemanticsContext &context, SourceName at, const A &x) {
   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
-    if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
+    if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
       context.Say(at, std::move(*msg));
     }
   }
 }
 
 template <bool CUF_KERNEL> struct ActionStmtChecker {
-  template <typename A> static MaybeMsg WhyNotOk(const A &x) {
+  template <typename A>
+  static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
     if constexpr (ConstraintTrait<A>) {
-      return WhyNotOk(x.thing);
+      return WhyNotOk(context, x.thing);
     } else if constexpr (WrapperTrait<A>) {
-      return WhyNotOk(x.v);
+      return WhyNotOk(context, x.v);
     } else if constexpr (UnionTrait<A>) {
-      return WhyNotOk(x.u);
+      return WhyNotOk(context, x.u);
     } else if constexpr (TupleTrait<A>) {
-      return WhyNotOk(x.t);
+      return WhyNotOk(context, x.t);
     } else {
       return parser::MessageFormattedText{
           "Statement may not appear in device code"_err_en_US};
     }
   }
   template <typename A>
-  static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
-    return WhyNotOk(x.value());
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const common::Indirection<A> &x) {
+    return WhyNotOk(context, x.value());
   }
   template <typename... As>
-  static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
-    return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const std::variant<As...> &x) {
+    return common::visit(
+        [&context](const auto &x) { return WhyNotOk(context, x); }, x);
   }
   template <std::size_t J = 0, typename... As>
-  static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const std::tuple<As...> &x) {
     if constexpr (J == sizeof...(As)) {
       return {};
-    } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
+    } else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
       return msg;
     } else {
-      return WhyNotOk<(J + 1)>(x);
+      return WhyNotOk<(J + 1)>(context, x);
     }
   }
-  template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
+  template <typename A>
+  static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
     for (const auto &y : x) {
-      if (MaybeMsg result{WhyNotOk(y)}) {
+      if (MaybeMsg result{WhyNotOk(context, y)}) {
         return result;
       }
     }
     return {};
   }
-  template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
+  template <typename A>
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const std::optional<A> &x) {
     if (x) {
-      return WhyNotOk(*x);
+      return WhyNotOk(context, *x);
     } else {
       return {};
     }
   }
   template <typename A>
-  static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
-    return WhyNotOk(x.statement);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
+    return WhyNotOk(context, x.statement);
   }
   template <typename A>
-  static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
-    return WhyNotOk(x.statement);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::Statement<A> &x) {
+    return WhyNotOk(context, x.statement);
   }
-  static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::AllocateStmt &) {
     return {}; // AllocateObjects are checked elsewhere
   }
-  static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::AllocateCoarraySpec &) {
     return parser::MessageFormattedText(
         "A coarray may not be allocated on the device"_err_en_US);
   }
-  static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::DeallocateStmt &) {
     return {}; // AllocateObjects are checked elsewhere
   }
-  static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
-    return DeviceExprChecker{}(x.typedAssignment);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::AssignmentStmt &x) {
+    return DeviceExprChecker{context}(x.typedAssignment);
   }
-  static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
-    return DeviceExprChecker{}(x.typedCall);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::CallStmt &x) {
+    return DeviceExprChecker{context}(x.typedCall);
+  }
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::ContinueStmt &) {
+    return {};
   }
-  static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
-  static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
-    if (auto result{
-            CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
+  static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
+    if (auto result{CheckUnwrappedExpr(
+            context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
       return result;
     }
-    return WhyNotOk(
+    return WhyNotOk(context,
         std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
             .statement);
   }
-  static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::NullifyStmt &x) {
     for (const auto &y : x.v) {
-      if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
+      if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
         return result;
       }
     }
     return {};
   }
-  static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
-    return DeviceExprChecker{}(x.typedAssignment);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
+    return DeviceExprChecker{context}(x.typedAssignment);
   }
 };
 
@@ -435,12 +472,14 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
                 ErrorIfHostSymbol(assign->lhs, source);
                 ErrorIfHostSymbol(assign->rhs, source);
               }
-              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+                      context_, x)}) {
                 context_.Say(source, std::move(*msg));
               }
             },
             [&](const auto &x) {
-              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+                      context_, x)}) {
                 context_.Say(source, std::move(*msg));
               }
             },
@@ -504,7 +543,7 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
     Check(DEREF(parser::Unwrap<parser::Expr>(x)));
   }
   void Check(const parser::Expr &expr) {
-    if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
+    if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
       context_.Say(expr.source, std::move(*msg));
     }
   }
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index c544008a24d56..c80872108ac8f 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
                                       llvm::cl::desc("enable CUDA Fortran"),
                                       llvm::cl::init(false));
 
+static llvm::cl::opt<bool>
+    disableCUDAWarpFunction("fcuda-disable-warp-function",
+                            llvm::cl::desc("Disable CUDA Warp Function"),
+                            llvm::cl::init(false));
+
 static llvm::cl::opt<std::string>
     enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
                   llvm::cl::init(""));
@@ -600,6 +605,11 @@ int main(int argc, char **argv) {
     options.features.Enable(Fortran::common::LanguageFeature::CUDA);
   }
 
+  if (disableCUDAWarpFunction) {
+    options.features.Enable(
+        Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
+  }
+
   if (enableGPUMode == "managed") {
     options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
   } else if (enableGPUMode == "unified") {

@wangzpgi
Copy link
Contributor

Does it make sense to add a test?

@clementval
Copy link
Contributor Author

Does it make sense to add a test?

Yes. I had it ready. I forgot to git add it!

@@ -0,0 +1,8 @@
! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does not bbc do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bbc command will fail since I want to check the error. Adding the not makes the test pass because the return value of the command is checked.

@clementval clementval merged commit a3201ce into llvm:main Jun 11, 2025
7 checks passed
rorth pushed a commit to rorth/llvm-project that referenced this pull request Jun 11, 2025
…43640)

These functions are not available in some lower compute capabilities.
Add option in the language feature to enforce the semantic check on
these.
@clementval clementval deleted the cuf_warp branch June 11, 2025 19:03
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…43640)

These functions are not available in some lower compute capabilities.
Add option in the language feature to enforce the semantic check on
these.
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
…43640)

These functions are not available in some lower compute capabilities.
Add option in the language feature to enforce the semantic check on
these.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants