Skip to content

[mlir][flang][openacc] Device type support on acc routine op #78375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 18, 2024

Conversation

clementval
Copy link
Contributor

This patch add support for device_type on the acc.routine operation. device_type can be specified on seq, worker, vector, gang and bind information.

The support is following the same design than the one for compute operations, data operation and the loop operation.

@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Jan 17, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2024

@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This patch add support for device_type on the acc.routine operation. device_type can be specified on seq, worker, vector, gang and bind information.

The support is following the same design than the one for compute operations, data operation and the loop operation.


Patch is 30.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78375.diff

6 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+117-21)
  • (modified) flang/test/Lower/OpenACC/acc-routine.f90 (+16-2)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+47-11)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+255-32)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+2-2)
  • (modified) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+58)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index db9ed72bc87257..fd89d27db74dc0 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   llvm_unreachable("unsupported declarative directive");
 }
 
+static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  for (auto attr : arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+  return false;
+}
+
+template <typename RetTy, typename AttrTy>
+static std::optional<RetTy>
+getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
+                              llvm::SmallVector<mlir::Attribute> &deviceTypes,
+                              mlir::acc::DeviceType deviceType) {
+  assert(attributes.size() == deviceTypes.size() &&
+         "expect same number of attributes");
+  for (auto it : llvm::enumerate(deviceTypes)) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
+    if (deviceTypeAttr.getValue() == deviceType) {
+      if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
+        auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
+        return strAttr.getValue();
+      } else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
+        auto intAttr =
+            mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]);
+        return intAttr.getInt();
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+static bool compareDeviceTypeInfo(
+    mlir::acc::RoutineOp op,
+    llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
+  for (uint32_t dtypeInt = 0;
+       dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
+    auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
+    if (op.getBindNameValue(dtype) !=
+        getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
+            bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
+      return false;
+    if (op.getGangDimValue(dtype) !=
+        getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
+            gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
+      return false;
+    if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
+      return false;
+    if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
+      return false;
+  }
+  return true;
+}
+
 static void attachRoutineInfo(mlir::func::FuncOp func,
                               mlir::SymbolRefAttr routineAttr) {
   llvm::SmallVector<mlir::SymbolRefAttr> routines;
@@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct(
       funcName = funcOp.getName();
     }
   }
-  bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
-       hasNohost = false;
-  std::optional<std::string> bindName = std::nullopt;
-  std::optional<int64_t> gangDim = std::nullopt;
+  bool hasNohost = false;
+
+  llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
+      workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
+      gangDimDeviceTypes, gangDimValues;
+
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
 
   for (const Fortran::parser::AccClause &clause : clauses.v) {
     if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      hasSeq = true;
+      seqDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *gangClause =
                    std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-      hasGang = true;
+
       if (gangClause->v) {
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct(
             if (!dimValue)
               mlir::emitError(loc,
                               "dim value must be a constant positive integer");
-            gangDim = *dimValue;
+            gangDimValues.push_back(
+                builder.getIntegerAttr(builder.getI64Type(), *dimValue));
+            gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
           }
         }
+      } else {
+        gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
-      hasVector = true;
+      vectorDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
-      hasWorker = true;
+      workerDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
       hasNohost = true;
     } else if (const auto *bindClause =
                    std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
       if (const auto *name =
               std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
-        bindName = converter.mangleName(*name->symbol);
+        bindNames.push_back(
+            builder.getStringAttr(converter.mangleName(*name->symbol)));
+        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
       } else if (const auto charExpr =
                      std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
                          &bindClause->v.u)) {
@@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
                                                           *charExpr);
         if (!name)
           mlir::emitError(loc, "Could not retrieve the bind name");
-        bindName = *name;
+        bindNames.push_back(builder.getStringAttr(*name));
+        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
       }
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
     }
   }
 
@@ -3575,12 +3663,11 @@ void Fortran::lower::genOpenACCRoutineConstruct(
     if (routineOp.getFuncName().str().compare(funcName) == 0) {
       // If the routine is already specified with the same clauses, just skip
       // the operation creation.
-      if (routineOp.getBindName() == bindName &&
-          routineOp.getGang() == hasGang &&
-          routineOp.getWorker() == hasWorker &&
-          routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
-          routineOp.getNohost() == hasNohost &&
-          routineOp.getGangDim() == gangDim)
+      if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
+                                gangDeviceTypes, gangDimValues,
+                                gangDimDeviceTypes, seqDeviceTypes,
+                                workerDeviceTypes, vectorDeviceTypes) &&
+          routineOp.getNohost() == hasNohost)
         return;
       mlir::emitError(loc, "Routine already specified with different clauses");
     }
@@ -3588,10 +3675,19 @@ void Fortran::lower::genOpenACCRoutineConstruct(
 
   modBuilder.create<mlir::acc::RoutineOp>(
       loc, routineOpName.str(), funcName,
-      bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
-      hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
-      gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
-              : mlir::IntegerAttr{});
+      bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
+      bindNameDeviceTypes.empty() ? nullptr
+                                  : builder.getArrayAttr(bindNameDeviceTypes),
+      workerDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(workerDeviceTypes),
+      vectorDeviceTypes.empty() ? nullptr
+                                : builder.getArrayAttr(vectorDeviceTypes),
+      seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
+      hasNohost, /*implicit=*/false,
+      gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
+      gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
+      gangDimDeviceTypes.empty() ? nullptr
+                                 : builder.getArrayAttr(gangDimDeviceTypes));
 
   if (funcOp)
     attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 8b94279503334a..8e9e65da32cd19 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -2,12 +2,14 @@
 
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
-
+! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
 ! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq
 ! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
 ! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
 ! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
-! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim = 1 : i32)
+! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(1 : i64)
 ! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
 ! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
 ! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
@@ -106,3 +108,15 @@ subroutine acc_routine14()
 subroutine acc_routine15()
   !$acc routine bind(acc_routine16)
 end subroutine
+
+subroutine acc_routine16()
+  !$acc routine device_type(host) seq dtype(nvidia) gang
+end subroutine
+
+subroutine acc_routine17()
+  !$acc routine device_type(host) worker dtype(multicore) vector 
+end subroutine
+
+subroutine acc_routine18()
+  !$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16) 
+end subroutine
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 24f129d92805c0..7344ab2852b9ce 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1994,27 +1994,63 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
 
   let arguments = (ins SymbolNameAttr:$sym_name,
                        SymbolNameAttr:$func_name,
-                       OptionalAttr<StrAttr>:$bind_name,
-                       UnitAttr:$gang,
-                       UnitAttr:$worker,
-                       UnitAttr:$vector,
-                       UnitAttr:$seq,
+                       OptionalAttr<StrArrayAttr>:$bindName,
+                       OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
+                       OptionalAttr<DeviceTypeArrayAttr>:$worker,
+                       OptionalAttr<DeviceTypeArrayAttr>:$vector,
+                       OptionalAttr<DeviceTypeArrayAttr>:$seq,
                        UnitAttr:$nohost,
                        UnitAttr:$implicit,
-                       OptionalAttr<APIntAttr>:$gangDim);
+                       OptionalAttr<DeviceTypeArrayAttr>:$gang,
+                       OptionalAttr<I64ArrayAttr>:$gangDim,
+                       OptionalAttr<DeviceTypeArrayAttr>:$gangDimDeviceType);
 
   let extraClassDeclaration = [{
     static StringRef getGangDimKeyword() { return "dim"; }
+
+    /// Return true if the op has the worker attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWorker();
+    /// Return true if the op has the worker attribute for the given
+    /// device_type.
+    bool hasWorker(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the vector attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasVector();
+    /// Return true if the op has the vector attribute for the given
+    /// device_type.
+    bool hasVector(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the seq attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasSeq();
+    /// Return true if the op has the seq attribute for the given
+    /// device_type.
+    bool hasSeq(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the gang attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasGang();
+    /// Return true if the op has the gang attribute for the given
+    /// device_type.
+    bool hasGang(mlir::acc::DeviceType deviceType);
+
+    std::optional<int64_t> getGangDimValue();
+    std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
+
+    std::optional<llvm::StringRef> getBindNameValue();
+    std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     $sym_name `func` `(` $func_name `)`
     oilist (
-        `bind` `(` $bind_name `)`
-      | `gang` `` custom<RoutineGangClause>($gang, $gangDim)
-      | `worker` $worker
-      | `vector` $vector
-      | `seq` $seq
+        `bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
+      | `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
+      | `worker` custom<DeviceTypeArrayAttr>($worker)
+      | `vector` custom<DeviceTypeArrayAttr>($vector)
+      | `seq` custom<DeviceTypeArrayAttr>($seq)
       | `nohost` $nohost
       | `implicit` $implicit
     ) attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bf3264b5da9802..82e614cb7572f6 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1046,7 +1046,7 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
   return success();
 }
 
-bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
   if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
     return true;
   return false;
@@ -2131,55 +2131,278 @@ LogicalResult acc::DeclareOp::verify() {
 // RoutineOp
 //===----------------------------------------------------------------------===//
 
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(arrayAttr))
+    return false;
+
+  for (auto attr : *arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+
+  return false;
+}
+
+static unsigned getParallelismForDeviceType(acc::RoutineOp op,
+                                            acc::DeviceType dtype) {
+  unsigned parallelism = 0;
+  parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
+  parallelism += op.hasWorker(dtype) ? 1 : 0;
+  parallelism += op.hasVector(dtype) ? 1 : 0;
+  parallelism += op.hasSeq(dtype) ? 1 : 0;
+  return parallelism;
+}
+
 LogicalResult acc::RoutineOp::verify() {
-  int parallelism = 0;
-  parallelism += getGang() ? 1 : 0;
-  parallelism += getWorker() ? 1 : 0;
-  parallelism += getVector() ? 1 : 0;
-  parallelism += getSeq() ? 1 : 0;
+  unsigned baseParallelism =
+      getParallelismForDeviceType(*this, acc::DeviceType::None);
 
-  if (parallelism > 1)
+  if (baseParallelism > 1)
     return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
                           "be present at the same time";
 
+  for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
+       ++dtypeInt) {
+    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
+    if (dtype == acc::DeviceType::None)
+      continue;
+    unsigned parallelism = getParallelismForDeviceType(*this, dtype);
+
+    if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
+      return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
+                            "be present at the same time";
+  }
+
   return success();
 }
 
-static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang,
-                                          IntegerAttr &gangDim) {
-  // Since gang clause exists, ensure that unit attribute is set.
-  gang = UnitAttr::get(parser.getBuilder().getContext());
+static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
+                                 mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
 
-  // Next, look for dim on gang. Don't initialize `gangDim` yet since
-  // we leave it without attribute if there is no `dim` specifier.
-  if (succeeded(parser.parseOptionalLParen())) {
-    // Look for syntax that looks like `dim = 1 : i32`.
-    // Thus first look for `dim =`
-    if (failed(parser.parseKeyword(RoutineOp::getGangDimKeyword())) ||
-        failed(parser.parseEqual()))
-      return failure();
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(bindNameAttrs.emplace_back()))
+          return failure();
+        if (failed(parser.parseOptionalLSquare())) {
+          deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        } else {
+          if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        }
+        return success();
+      })))
+    return failure();
 
-    int64_t dimValue;
-    Type valueType;
-    // Now look for `1 : i32`
-    if (failed(parser.parseInteger(dimValue)) ||
-        failed(parser.parseColonType(valueType)))
-      return failure();
+  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
+  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+
+  return success();
+}
+
+static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          std::optional<mlir::ArrayAttr> bindName,
+                          std::optional<mlir::ArrayAttr> deviceTypes) {
+  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
+                        [&](const auto &pair) {
+                          p << std::get<0>(pair);
+                          printSingleDeviceType(p, std::get<1>(pair));
+                        });
+}
+
+static ParseResult parseRoutineGangClause(OpAsmParser &parser,
+                                          mlir::ArrayAttr &gang,
+                                          mlir::ArrayAttr &gangDim,
+                                          mlir::ArrayAttr &gangDimDeviceTypes) {
 
-    gangDim = IntegerAttr::get(valueType, dimValue);
+  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
+      gangDimDeviceTypeAttrs;
+  bool needCommaBeforeOperands = false;
 
-    if (failed(parser.parseRParen()))
+  // Gang keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        pa...
[truncated]

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent work! I only have one concern but I could be convinced it is OK.

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@clementval clementval merged commit b06bc7c into llvm:main Jan 18, 2024
@clementval clementval deleted the acc_routine_dtype_mlir branch January 18, 2024 17:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants