Skip to content

[mlir][openacc] Add device_type support for compute operations #75864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2023

Conversation

clementval
Copy link
Contributor

@clementval clementval commented Dec 18, 2023

This patch adds representation for device_type clause information on compute construct (parallel, kernels, serial).

The device_type clause on compute construct impacts clauses that appear after it. The values impacted by device_type are now tied with an attribute array that represent the device_type associated with them. DeviceType::None is used to represent the value produced by a clause before any device_type. The operands and the attribute information are parser/printed together.

This is an example with vector_length clause. The first value (64) is not impacted by device_type so it will be represented with DeviceType::None. None is not printed. The second value (128) is tied with the device_type(multicore) clause.

!$acc parallel vector_length(64) device_type(multicore) vector_length(256)
acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type<multicore>]) {
}

When multiple values can be produced for a single clause like num_gangs and wait, an extra attribute describe the number of values belonging to each device_type. Values and attributes are parsed/printed together.

acc.parallel num_gangs({%c2 : i32, %c4 : i32}, {%c4 : i32} [#acc.device_type<nvidia>])

While preparing this patch I noticed that the wait devnum is not part of the operations and is not lowered. It will be added in a follow up patch.

@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Dec 18, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This patch adds representation for device_type clause information on compute construct (parallel, kernels, serial).

The device_type clause on compute construct impacts clauses that appear after it. The values impacted by device_type are now tied with a attribute array that represent the device_type. DeviceType::None is used to represent the value produced by a clause before any device_type. The operands and the attribute information are parser/printed together.

This is an example with vector_length clause. The first value (64) is not impacted by device_type so it will be represented with DeviceType::None. None is not printed. The second value (128) is tied with the device_type(multicore) clause.

!$acc parallel vector_length(64) device_type(multicore) vector_length(256)
acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type&lt;multicore&gt;]) {
}

When multiple values can be produced for a single clause like num_gangs and wait, an extra attribute describe the number of values belonging to each device_type. Values and attributes are parsed/printed together.

acc.parallel num_gangs({%c2 : i32}, {%c4 : i32} [#acc.device_type&lt;nvidia&gt;])

While preparing this patch I noticed that the wait devnum is not part of the operations and is not lowered. It will be added in a follow up patch.


Patch is 84.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75864.diff

15 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+84-22)
  • (added) flang/test/Lower/OpenACC/acc-device-type.f90 (+44)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-kernels.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+8-8)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+5-5)
  • (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+5-5)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+221-65)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+504-11)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+38-38)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/OpenACC/CMakeLists.txt (+8)
  • (added) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+275)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 531685948bc843..57e14bf77e092c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1451,7 +1451,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
   case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
     return mlir::acc::DeviceType::Multicore;
   }
-  return mlir::acc::DeviceType::Default;
+  return mlir::acc::DeviceType::None;
 }
 
 static void gatherDeviceTypeAttrs(
@@ -1752,26 +1752,25 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                 bool outerCombined = false) {
 
   // Parallel operation operands
-  mlir::Value async;
-  mlir::Value numWorkers;
-  mlir::Value vectorLength;
   mlir::Value ifCond;
   mlir::Value selfCond;
   mlir::Value waitDevnum;
   llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
       copyEntryOperands, copyoutEntryOperands, createEntryOperands,
-      dataClauseOperands, numGangs;
+      dataClauseOperands, numGangs, numWorkers, vectorLength, async;
+  llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
+      vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
+      waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
 
-  // Async, wait and self clause have optional values but can be present with
+  // Self clause has optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
   // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
   bool addSelfAttr = false;
 
   bool hasDefaultNone = false;
@@ -1779,6 +1778,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
+
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separatly as clauses can appear
   // more than once.
@@ -1786,27 +1790,52 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *asyncClause =
             std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      const auto &asyncClauseValue = asyncClause->v;
+      if (asyncClauseValue) { // async has a value.
+        async.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
+        asyncDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      const auto &waitClauseValue = waitClause->v;
+      if (waitClauseValue) { // wait has a value.
+        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+        const auto &waitList =
+            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+        auto crtWaitOperands = waitOperands.size();
+        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+          waitOperands.push_back(fir::getBase(converter.genExprValue(
+              *Fortran::semantics::GetExpr(value), stmtCtx)));
+        }
+        waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+      } else {
+        waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
+      auto crtNumGangs = numGangs.size();
       for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
         numGangs.push_back(fir::getBase(converter.genExprValue(
             *Fortran::semantics::GetExpr(expr), stmtCtx)));
+      numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+      numGangsSegments.push_back(numGangs.size() - crtNumGangs);
     } else if (const auto *numWorkersClause =
                    std::get_if<Fortran::parser::AccClause::NumWorkers>(
                        &clause.u)) {
-      numWorkers = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+      numWorkers.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
+      numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *vectorLengthClause =
                    std::get_if<Fortran::parser::AccClause::VectorLength>(
                        &clause.u)) {
-      vectorLength = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+      vectorLength.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
+      vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *ifClause =
                    std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1957,18 +1986,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       else if ((defaultClause->v).v ==
                llvm::acc::DefaultValue::ACC_Default_present)
         hasDefaultPresent = true;
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value, 8> operands;
   llvm::SmallVector<int32_t, 8> operandSegments;
-  addOperand(operands, operandSegments, async);
+  addOperands(operands, operandSegments, async);
   addOperands(operands, operandSegments, waitOperands);
   if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
     addOperands(operands, operandSegments, numGangs);
-    addOperand(operands, operandSegments, numWorkers);
-    addOperand(operands, operandSegments, vectorLength);
+    addOperands(operands, operandSegments, numWorkers);
+    addOperands(operands, operandSegments, vectorLength);
   }
   addOperand(operands, operandSegments, ifCond);
   addOperand(operands, operandSegments, selfCond);
@@ -1989,10 +2027,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
         builder, currentLocation, eval, operands, operandSegments,
         outerCombined);
 
-  if (addAsyncAttr)
-    computeOp.setAsyncAttrAttr(builder.getUnitAttr());
-  if (addWaitAttr)
-    computeOp.setWaitAttrAttr(builder.getUnitAttr());
   if (addSelfAttr)
     computeOp.setSelfAttrAttr(builder.getUnitAttr());
 
@@ -2001,6 +2035,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   if (hasDefaultPresent)
     computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
 
+  if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
+    if (!numWorkersDeviceTypes.empty())
+      computeOp.setNumWorkersDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
+    if (!vectorLengthDeviceTypes.empty())
+      computeOp.setVectorLengthDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
+    if (!numGangsDeviceTypes.empty())
+      computeOp.setNumGangsDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
+    if (!numGangsSegments.empty())
+      computeOp.setNumGangsSegmentsAttr(
+          builder.getDenseI32ArrayAttr(numGangsSegments));
+  }
+  if (!asyncDeviceTypes.empty())
+    computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+  if (!asyncOnlyDeviceTypes.empty())
+    computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
+
+  if (!waitOperandsDeviceTypes.empty())
+    computeOp.setWaitOperandsDeviceTypeAttr(
+        builder.getArrayAttr(waitOperandsDeviceTypes));
+  if (!waitOperandsSegments.empty())
+    computeOp.setWaitOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!waitOnlyDeviceTypes.empty())
+    computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
+
   if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
     if (!privatizations.empty())
       computeOp.setPrivatizationsAttr(
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
new file mode 100644
index 00000000000000..871dbc95f60fcb
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -0,0 +1,44 @@
+! This test checks lowering of OpenACC device_type clause on directive where its
+! position and the clauses that follow have special semantic
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1()
+
+  !$acc parallel num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
+
+  !$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
+
+  !$acc parallel device_type(multicore) vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+
+end subroutine
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 34e72326972417..93bc699031d550 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels loop async(1)
   DO i = 1, n
@@ -103,7 +103,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels loop wait(1)
   DO i = 1, n
@@ -111,7 +111,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 1f882c6df51061..99629bb8351723 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -40,7 +40,7 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels async(1)
   !$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels wait(1)
   !$acc end kernels
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels  wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -78,7 +78,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels  wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -87,7 +87,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -95,7 +95,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -103,7 +103,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 1856215ce59d13..deee7089033ead 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -64,7 +64,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop async(1)
   DO i = 1, n
@@ -105,7 +105,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop wait(1)
   DO i = 1, n
@@ -113,7 +113,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -128,7 +128,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -143,7 +143,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -157,7 +157,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -171,7 +171,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index bbf51ba36a7dea..a369bf01f25995 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -62,7 +62,7 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel async(1)
   !$acc end parallel
@@ -85,13 +85,13 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel wait(1)
   !$acc end parallel
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -100,7 +100,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -109,7 +109,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -117,7 +117,7 @@ subroutine acc_parallel
   !$acc end parallel
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This patch adds representation for device_type clause information on compute construct (parallel, kernels, serial).

The device_type clause on compute construct impacts clauses that appear after it. The values impacted by device_type are now tied with a attribute array that represent the device_type. DeviceType::None is used to represent the value produced by a clause before any device_type. The operands and the attribute information are parser/printed together.

This is an example with vector_length clause. The first value (64) is not impacted by device_type so it will be represented with DeviceType::None. None is not printed. The second value (128) is tied with the device_type(multicore) clause.

!$acc parallel vector_length(64) device_type(multicore) vector_length(256)
acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type&lt;multicore&gt;]) {
}

When multiple values can be produced for a single clause like num_gangs and wait, an extra attribute describe the number of values belonging to each device_type. Values and attributes are parsed/printed together.

acc.parallel num_gangs({%c2 : i32}, {%c4 : i32} [#acc.device_type&lt;nvidia&gt;])

While preparing this patch I noticed that the wait devnum is not part of the operations and is not lowered. It will be added in a follow up patch.


Patch is 84.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75864.diff

15 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+84-22)
  • (added) flang/test/Lower/OpenACC/acc-device-type.f90 (+44)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-kernels.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+8-8)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+5-5)
  • (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+5-5)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+221-65)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+504-11)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+38-38)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/OpenACC/CMakeLists.txt (+8)
  • (added) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+275)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 531685948bc843..57e14bf77e092c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1451,7 +1451,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
   case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
     return mlir::acc::DeviceType::Multicore;
   }
-  return mlir::acc::DeviceType::Default;
+  return mlir::acc::DeviceType::None;
 }
 
 static void gatherDeviceTypeAttrs(
@@ -1752,26 +1752,25 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                 bool outerCombined = false) {
 
   // Parallel operation operands
-  mlir::Value async;
-  mlir::Value numWorkers;
-  mlir::Value vectorLength;
   mlir::Value ifCond;
   mlir::Value selfCond;
   mlir::Value waitDevnum;
   llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
       copyEntryOperands, copyoutEntryOperands, createEntryOperands,
-      dataClauseOperands, numGangs;
+      dataClauseOperands, numGangs, numWorkers, vectorLength, async;
+  llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
+      vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
+      waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
 
-  // Async, wait and self clause have optional values but can be present with
+  // Self clause has optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
   // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
   bool addSelfAttr = false;
 
   bool hasDefaultNone = false;
@@ -1779,6 +1778,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
+
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separatly as clauses can appear
   // more than once.
@@ -1786,27 +1790,52 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *asyncClause =
             std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      const auto &asyncClauseValue = asyncClause->v;
+      if (asyncClauseValue) { // async has a value.
+        async.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
+        asyncDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      const auto &waitClauseValue = waitClause->v;
+      if (waitClauseValue) { // wait has a value.
+        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+        const auto &waitList =
+            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+        auto crtWaitOperands = waitOperands.size();
+        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+          waitOperands.push_back(fir::getBase(converter.genExprValue(
+              *Fortran::semantics::GetExpr(value), stmtCtx)));
+        }
+        waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+      } else {
+        waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
+      auto crtNumGangs = numGangs.size();
       for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
         numGangs.push_back(fir::getBase(converter.genExprValue(
             *Fortran::semantics::GetExpr(expr), stmtCtx)));
+      numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+      numGangsSegments.push_back(numGangs.size() - crtNumGangs);
     } else if (const auto *numWorkersClause =
                    std::get_if<Fortran::parser::AccClause::NumWorkers>(
                        &clause.u)) {
-      numWorkers = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+      numWorkers.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
+      numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *vectorLengthClause =
                    std::get_if<Fortran::parser::AccClause::VectorLength>(
                        &clause.u)) {
-      vectorLength = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+      vectorLength.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
+      vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *ifClause =
                    std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1957,18 +1986,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       else if ((defaultClause->v).v ==
                llvm::acc::DefaultValue::ACC_Default_present)
         hasDefaultPresent = true;
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value, 8> operands;
   llvm::SmallVector<int32_t, 8> operandSegments;
-  addOperand(operands, operandSegments, async);
+  addOperands(operands, operandSegments, async);
   addOperands(operands, operandSegments, waitOperands);
   if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
     addOperands(operands, operandSegments, numGangs);
-    addOperand(operands, operandSegments, numWorkers);
-    addOperand(operands, operandSegments, vectorLength);
+    addOperands(operands, operandSegments, numWorkers);
+    addOperands(operands, operandSegments, vectorLength);
   }
   addOperand(operands, operandSegments, ifCond);
   addOperand(operands, operandSegments, selfCond);
@@ -1989,10 +2027,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
         builder, currentLocation, eval, operands, operandSegments,
         outerCombined);
 
-  if (addAsyncAttr)
-    computeOp.setAsyncAttrAttr(builder.getUnitAttr());
-  if (addWaitAttr)
-    computeOp.setWaitAttrAttr(builder.getUnitAttr());
   if (addSelfAttr)
     computeOp.setSelfAttrAttr(builder.getUnitAttr());
 
@@ -2001,6 +2035,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   if (hasDefaultPresent)
     computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
 
+  if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
+    if (!numWorkersDeviceTypes.empty())
+      computeOp.setNumWorkersDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
+    if (!vectorLengthDeviceTypes.empty())
+      computeOp.setVectorLengthDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
+    if (!numGangsDeviceTypes.empty())
+      computeOp.setNumGangsDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
+    if (!numGangsSegments.empty())
+      computeOp.setNumGangsSegmentsAttr(
+          builder.getDenseI32ArrayAttr(numGangsSegments));
+  }
+  if (!asyncDeviceTypes.empty())
+    computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+  if (!asyncOnlyDeviceTypes.empty())
+    computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
+
+  if (!waitOperandsDeviceTypes.empty())
+    computeOp.setWaitOperandsDeviceTypeAttr(
+        builder.getArrayAttr(waitOperandsDeviceTypes));
+  if (!waitOperandsSegments.empty())
+    computeOp.setWaitOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!waitOnlyDeviceTypes.empty())
+    computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
+
   if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
     if (!privatizations.empty())
       computeOp.setPrivatizationsAttr(
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
new file mode 100644
index 00000000000000..871dbc95f60fcb
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -0,0 +1,44 @@
+! This test checks lowering of OpenACC device_type clause on directive where its
+! position and the clauses that follow have special semantic
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1()
+
+  !$acc parallel num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
+
+  !$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
+
+  !$acc parallel device_type(multicore) vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+
+end subroutine
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 34e72326972417..93bc699031d550 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels loop async(1)
   DO i = 1, n
@@ -103,7 +103,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels loop wait(1)
   DO i = 1, n
@@ -111,7 +111,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 1f882c6df51061..99629bb8351723 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -40,7 +40,7 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels async(1)
   !$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels wait(1)
   !$acc end kernels
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels  wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -78,7 +78,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels  wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -87,7 +87,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -95,7 +95,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -103,7 +103,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 1856215ce59d13..deee7089033ead 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -64,7 +64,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop async(1)
   DO i = 1, n
@@ -105,7 +105,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop wait(1)
   DO i = 1, n
@@ -113,7 +113,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -128,7 +128,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -143,7 +143,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -157,7 +157,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -171,7 +171,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index bbf51ba36a7dea..a369bf01f25995 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -62,7 +62,7 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel async(1)
   !$acc end parallel
@@ -85,13 +85,13 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel wait(1)
   !$acc end parallel
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -100,7 +100,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -109,7 +109,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -117,7 +117,7 @@ subroutine acc_parallel
   !$acc end parallel
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This patch adds representation for device_type clause information on compute construct (parallel, kernels, serial).

The device_type clause on compute construct impacts clauses that appear after it. The values impacted by device_type are now tied with a attribute array that represent the device_type. DeviceType::None is used to represent the value produced by a clause before any device_type. The operands and the attribute information are parser/printed together.

This is an example with vector_length clause. The first value (64) is not impacted by device_type so it will be represented with DeviceType::None. None is not printed. The second value (128) is tied with the device_type(multicore) clause.

!$acc parallel vector_length(64) device_type(multicore) vector_length(256)
acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type&lt;multicore&gt;]) {
}

When multiple values can be produced for a single clause like num_gangs and wait, an extra attribute describe the number of values belonging to each device_type. Values and attributes are parsed/printed together.

acc.parallel num_gangs({%c2 : i32}, {%c4 : i32} [#acc.device_type&lt;nvidia&gt;])

While preparing this patch I noticed that the wait devnum is not part of the operations and is not lowered. It will be added in a follow up patch.


Patch is 84.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75864.diff

15 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+84-22)
  • (added) flang/test/Lower/OpenACC/acc-device-type.f90 (+44)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-kernels.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+8-8)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+5-5)
  • (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+5-5)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+221-65)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+504-11)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+38-38)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/OpenACC/CMakeLists.txt (+8)
  • (added) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+275)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 531685948bc843..57e14bf77e092c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1451,7 +1451,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
   case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
     return mlir::acc::DeviceType::Multicore;
   }
-  return mlir::acc::DeviceType::Default;
+  return mlir::acc::DeviceType::None;
 }
 
 static void gatherDeviceTypeAttrs(
@@ -1752,26 +1752,25 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                 bool outerCombined = false) {
 
   // Parallel operation operands
-  mlir::Value async;
-  mlir::Value numWorkers;
-  mlir::Value vectorLength;
   mlir::Value ifCond;
   mlir::Value selfCond;
   mlir::Value waitDevnum;
   llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
       copyEntryOperands, copyoutEntryOperands, createEntryOperands,
-      dataClauseOperands, numGangs;
+      dataClauseOperands, numGangs, numWorkers, vectorLength, async;
+  llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
+      vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
+      waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
 
-  // Async, wait and self clause have optional values but can be present with
+  // Self clause has optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
   // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
   bool addSelfAttr = false;
 
   bool hasDefaultNone = false;
@@ -1779,6 +1778,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
+
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separatly as clauses can appear
   // more than once.
@@ -1786,27 +1790,52 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *asyncClause =
             std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      const auto &asyncClauseValue = asyncClause->v;
+      if (asyncClauseValue) { // async has a value.
+        async.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
+        asyncDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      const auto &waitClauseValue = waitClause->v;
+      if (waitClauseValue) { // wait has a value.
+        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+        const auto &waitList =
+            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+        auto crtWaitOperands = waitOperands.size();
+        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+          waitOperands.push_back(fir::getBase(converter.genExprValue(
+              *Fortran::semantics::GetExpr(value), stmtCtx)));
+        }
+        waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+      } else {
+        waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
+      auto crtNumGangs = numGangs.size();
       for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
         numGangs.push_back(fir::getBase(converter.genExprValue(
             *Fortran::semantics::GetExpr(expr), stmtCtx)));
+      numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+      numGangsSegments.push_back(numGangs.size() - crtNumGangs);
     } else if (const auto *numWorkersClause =
                    std::get_if<Fortran::parser::AccClause::NumWorkers>(
                        &clause.u)) {
-      numWorkers = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+      numWorkers.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
+      numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *vectorLengthClause =
                    std::get_if<Fortran::parser::AccClause::VectorLength>(
                        &clause.u)) {
-      vectorLength = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+      vectorLength.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
+      vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *ifClause =
                    std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1957,18 +1986,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       else if ((defaultClause->v).v ==
                llvm::acc::DefaultValue::ACC_Default_present)
         hasDefaultPresent = true;
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value, 8> operands;
   llvm::SmallVector<int32_t, 8> operandSegments;
-  addOperand(operands, operandSegments, async);
+  addOperands(operands, operandSegments, async);
   addOperands(operands, operandSegments, waitOperands);
   if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
     addOperands(operands, operandSegments, numGangs);
-    addOperand(operands, operandSegments, numWorkers);
-    addOperand(operands, operandSegments, vectorLength);
+    addOperands(operands, operandSegments, numWorkers);
+    addOperands(operands, operandSegments, vectorLength);
   }
   addOperand(operands, operandSegments, ifCond);
   addOperand(operands, operandSegments, selfCond);
@@ -1989,10 +2027,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
         builder, currentLocation, eval, operands, operandSegments,
         outerCombined);
 
-  if (addAsyncAttr)
-    computeOp.setAsyncAttrAttr(builder.getUnitAttr());
-  if (addWaitAttr)
-    computeOp.setWaitAttrAttr(builder.getUnitAttr());
   if (addSelfAttr)
     computeOp.setSelfAttrAttr(builder.getUnitAttr());
 
@@ -2001,6 +2035,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   if (hasDefaultPresent)
     computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
 
+  if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
+    if (!numWorkersDeviceTypes.empty())
+      computeOp.setNumWorkersDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
+    if (!vectorLengthDeviceTypes.empty())
+      computeOp.setVectorLengthDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
+    if (!numGangsDeviceTypes.empty())
+      computeOp.setNumGangsDeviceTypeAttr(
+          mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
+    if (!numGangsSegments.empty())
+      computeOp.setNumGangsSegmentsAttr(
+          builder.getDenseI32ArrayAttr(numGangsSegments));
+  }
+  if (!asyncDeviceTypes.empty())
+    computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+  if (!asyncOnlyDeviceTypes.empty())
+    computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
+
+  if (!waitOperandsDeviceTypes.empty())
+    computeOp.setWaitOperandsDeviceTypeAttr(
+        builder.getArrayAttr(waitOperandsDeviceTypes));
+  if (!waitOperandsSegments.empty())
+    computeOp.setWaitOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!waitOnlyDeviceTypes.empty())
+    computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
+
   if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
     if (!privatizations.empty())
       computeOp.setPrivatizationsAttr(
diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
new file mode 100644
index 00000000000000..871dbc95f60fcb
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -0,0 +1,44 @@
+! This test checks lowering of OpenACC device_type clause on directive where its
+! position and the clauses that follow have special semantic
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1()
+
+  !$acc parallel num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
+
+  !$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
+
+  !$acc parallel vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
+
+  !$acc parallel device_type(multicore) vector_length(1)
+  !$acc end parallel
+
+! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
+
+  !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+
+end subroutine
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 34e72326972417..93bc699031d550 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels loop async(1)
   DO i = 1, n
@@ -103,7 +103,7 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels loop wait(1)
   DO i = 1, n
@@ -111,7 +111,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 1f882c6df51061..99629bb8351723 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -40,7 +40,7 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]} 
 
   !$acc kernels async(1)
   !$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
 
 ! CHECK:      acc.kernels  {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc kernels wait(1)
   !$acc end kernels
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  wait([[WAIT1]] : i32) {
+! CHECK:      acc.kernels  wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -78,7 +78,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.kernels  wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -87,7 +87,7 @@ subroutine acc_kernels
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.kernels  wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -95,7 +95,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -103,7 +103,7 @@ subroutine acc_kernels
   !$acc end kernels
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.kernels  num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.kernels  num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 1856215ce59d13..deee7089033ead 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -64,7 +64,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop async(1)
   DO i = 1, n
@@ -105,7 +105,7 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel loop wait(1)
   DO i = 1, n
@@ -113,7 +113,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -128,7 +128,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -143,7 +143,7 @@ subroutine acc_parallel_loop
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -157,7 +157,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel num_gangs([[NUMGANGS1]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS1]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
@@ -171,7 +171,7 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel num_gangs([[NUMGANGS2]] : i32) {
+! CHECK:      acc.parallel num_gangs({[[NUMGANGS2]] : i32}) {
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index bbf51ba36a7dea..a369bf01f25995 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -62,7 +62,7 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {asyncAttr}
+! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
 
   !$acc parallel async(1)
   !$acc end parallel
@@ -85,13 +85,13 @@ subroutine acc_parallel
 
 ! CHECK:      acc.parallel {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitAttr}
+! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
 
   !$acc parallel wait(1)
   !$acc end parallel
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK:      acc.parallel wait([[WAIT1]] : i32) {
+! CHECK:      acc.parallel wait({[[WAIT1]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -100,7 +100,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK:      acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -109,7 +109,7 @@ subroutine acc_parallel
 
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK:      acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) {
+! CHECK:      acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -117,7 +117,7 @@ subroutine acc_parallel
   !$acc end parallel
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1...
[truncated]

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow! This is outstanding. The design and implementation make sense to me. I prefer this over other techniques (such as duplicating and specializing regions per device_type at this level). I like the getters that simplify access to the information. Thank you for this!

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@clementval clementval merged commit 8b885eb into llvm:main Dec 20, 2023
@vitalybuka
Copy link
Collaborator

Here is a broken bot https://lab.llvm.org/buildbot/#/builders/5/builds/39439/steps/9/logs/stdio
Could you fix or revert?

@clementval
Copy link
Contributor Author

Here is a broken bot https://lab.llvm.org/buildbot/#/builders/5/builds/39439/steps/9/logs/stdio Could you fix or revert?

Sorry about that. I'll revert it for now since I don't have time to fix it in the next 1-2 hours.

clementval added a commit that referenced this pull request Dec 21, 2023
clementval added a commit that referenced this pull request Dec 21, 2023
@clementval
Copy link
Contributor Author

Here is a broken bot https://lab.llvm.org/buildbot/#/builders/5/builds/39439/steps/9/logs/stdio Could you fix or revert?

By the way, there is no email send for this failure. Is that normal?

@vitalybuka
Copy link
Collaborator

Here is a broken bot https://lab.llvm.org/buildbot/#/builders/5/builds/39439/steps/9/logs/stdio Could you fix or revert?

By the way, there is no email send for this failure. Is that normal?

Search email for "sanitizer-x86_64-linux-fast", I've got email for a followup build for the same leak because I am in blame list.

@vitalybuka
Copy link
Collaborator

Here is a broken bot https://lab.llvm.org/buildbot/#/builders/5/builds/39439/steps/9/logs/stdio Could you fix or revert?

By the way, there is no email send for this failure. Is that normal?

Search email for "sanitizer-x86_64-linux-fast", I've got email for a followup build for the same leak because I am in blame list.

cmake/lit or whatever still continues to execute the test binary for incremental builds even after revert :(

@clementval
Copy link
Contributor Author

Here is a broken bot https://lab.llvm.org/buildbot/#/builders/5/builds/39439/steps/9/logs/stdio Could you fix or revert?

By the way, there is no email send for this failure. Is that normal?

Search email for "sanitizer-x86_64-linux-fast", I've got email for a followup build for the same leak because I am in blame list.

cmake/lit or whatever still continues to execute the test binary for incremental builds even after revert :(

I saw that but I don't know what to do besides reverting my change. I asked on discourse in the sanitizer channel but not reply so far.

@vitalybuka
Copy link
Collaborator

I saw that but I don't know what to do besides reverting my change. I asked on discourse in the sanitizer channel but not reply so far.

I will just reset bots. Thanks for revert.

@clementval
Copy link
Contributor Author

I will just reset bots. Thanks for revert.

Thanks!

clementval added a commit that referenced this pull request Dec 21, 2023
Re-land PR after being reverted because of buildbot failures.

This patch adds representation for `device_type` clause information on
compute construct (parallel, kernels, serial).

The `device_type` clause on compute construct impacts clauses that
appear after it. The values impacted by `device_type` are now tied with
an attribute array that represent the device_type associated with them.
`DeviceType::None` is used to represent the value produced by a clause
before any `device_type`. The operands and the attribute information are
parser/printed together.

This is an example with `vector_length` clause. The first value (64) is
not impacted by `device_type` so it will be represented with
DeviceType::None. None is not printed. The second value (128) is tied
with the `device_type(multicore)` clause.
```
!$acc parallel vector_length(64) device_type(multicore) vector_length(256)
```
```
acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type<multicore>]) {
}
```

When multiple values can be produced for a single clause like
`num_gangs` and `wait`, an extra attribute describe the number of values
belonging to each `device_type`. Values and attributes are
parsed/printed together.

```
acc.parallel num_gangs({%c2 : i32, %c4 : i32}, {%c4 : i32} [#acc.device_type<nvidia>])
```

While preparing this patch I noticed that the wait devnum is not part of
the operations and is not lowered. It will be added in a follow up
patch.
clementval added a commit that referenced this pull request Jan 5, 2024
Following #75864, this patch adds device_type support to the data
operation on the async and wait operands and attributes.
clementval added a commit that referenced this pull request Jan 9, 2024
These tests were initially pushed together with
#75864 but they were triggering
some buildbot failure (sanitizers). They now make use of the
`OwningOpRef` so all the resources are correctly destroyed at the end of
each tests.
They will be extended to includes all the extra getter functions added
with device_type support.
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
)

These tests were initially pushed together with
llvm#75864 but they were triggering
some buildbot failure (sanitizers). They now make use of the
`OwningOpRef` so all the resources are correctly destroyed at the end of
each tests.
They will be extended to includes all the extra getter functions added
with device_type support.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants