Skip to content

[mlir][openacc][flang] Simplify gang, vector and worker representation #77667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 11, 2024

Conversation

clementval
Copy link
Contributor

The IR representation for gang, vector and worker has grown with the support for device_type. This patch simplify the IR representation for gang, vector and worker information on the acc.loop operation.

When the only the keyword is present without any values, the information is printed at the same place than when there is values. The device_type is omitted if there is no values and it is equal to None. Otherwise the full information is displayed. First the keyword only device_type information and then the values with their device_type.

This patch simplify the IR representation for gang, vector and worker
information on the acc.loop operation.

The device_type is omitted when possible when it is equal to `None`.
@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Jan 10, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 10, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir-openacc

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

The IR representation for gang, vector and worker has grown with the support for device_type. This patch simplify the IR representation for gang, vector and worker information on the acc.loop operation.

When the only the keyword is present without any values, the information is printed at the same place than when there is values. The device_type is omitted if there is no values and it is equal to None. Otherwise the full information is displayed. First the keyword only device_type information and then the values with their device_type.


Patch is 29.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77667.diff

7 Files Affected:

  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+6-6)
  • (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+12-6)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+6-6)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+6-6)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+9-11)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+228-47)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+20-20)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index b17f2e2c80b20f..755111a69467bf 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -497,10 +497,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -550,10 +550,10 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -591,10 +591,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index e7f65770498fe2..42e14afb35f522 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -67,10 +67,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop gang {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {gang = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop gang(num: 8)
   DO i = 1, n
@@ -109,10 +109,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop vector {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {vector = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop vector(128)
   DO i = 1, n
@@ -141,10 +141,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop worker {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {worker = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop worker(128)
   DO i = 1, n
@@ -320,4 +320,10 @@ program acc_loop
 ! CHECK: acc.loop
 ! CHECK: fir.do_loop
 
+  !$acc loop gang device_type(nvidia) gang(8)
+  DO i = 1, n
+  END DO
+
+! CHECK: acc.loop gang([#acc.device_type<none>], {num=%c8{{.*}} : i32} [#acc.device_type<nvidia>])
+
 end program
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index e9150a71f3826b..faef8517850e0d 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -512,10 +512,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -565,10 +565,10 @@ subroutine acc_parallel_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -606,10 +606,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 6041e7fb1b4906..9333761e4c2962 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -447,10 +447,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -500,10 +500,10 @@ subroutine acc_serial_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -541,10 +541,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index e6954062a50e0c..24f129d92805c0 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1483,7 +1483,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
     Example:
 
     ```mlir
-    acc.loop {
+    acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -1492,10 +1492,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
         }
       }
       acc.yield
-    } attributes {
-      collapse = [3], gang = [#acc.device_type<none>],
-      vector = [#acc.device_type<none>]
-    }
+    } attributes { collapse = [3] }
     ```
   }];
 
@@ -1613,13 +1610,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     oilist(
-        `gang` `` `(` custom<GangClause>($gangOperands, type($gangOperands),
+        `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,
-            $gangOperandsSegments) `)`
-      | `worker` `` `(` custom<DeviceTypeOperands>($workerNumOperands,
-            type($workerNumOperands), $workerNumOperandsDeviceType) `)`
-      | `vector` `` `(` custom<DeviceTypeOperands>($vectorOperands,
-            type($vectorOperands), $vectorOperandsDeviceType) `)`
+            $gangOperandsSegments, $gang)
+      | `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+            $workerNumOperands, type($workerNumOperands),
+            $workerNumOperandsDeviceType, $worker)
+      | `vector` `` custom<DeviceTypeOperandsWithKeywordOnly>($vectorOperands,
+            type($vectorOperands), $vectorOperandsDeviceType, $vector)
       | `private` `(` custom<SymOperandList>(
             $privateOperands, type($privateOperands), $privatizations) `)`
       | `tile` `(` custom<DeviceTypeOperandsWithSegment>($tileOperands,
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c53673fa426038..bf3264b5da9802 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -921,6 +921,12 @@ static ParseResult parseDeviceTypeOperandsWithSegment(
   return success();
 }
 
+static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
+  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+    p << " [" << attr << "]";
+}
+
 static void printDeviceTypeOperandsWithSegment(
     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -937,10 +943,7 @@ static void printDeviceTypeOperandsWithSegment(
       ++opIdx;
     }
     p << "}";
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
 }
 
@@ -978,11 +981,120 @@ printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
     if (i != 0)
       p << ", ";
     p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
+  }
+}
+
+static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::ArrayAttr &keywordOnlyDeviceType) {
+
+  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
+  bool needCommaBeforeOperands = false;
+
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  keywordOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseOperand(operands.emplace_back()) ||
+            parser.parseColonType(types.emplace_back()))
+          return failure();
+        if (succeeded(parser.parseOptionalLSquare())) {
+          if (parser.parseAttribute(attributes.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        } else {
+          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        }
+        return success();
+      })))
+    return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  return success();
+}
+
+bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+  p << "[";
+  for (unsigned i = 0; i < deviceTypes.value().size(); ++i) {
+    if (i != 0)
+      p << ", ";
     auto deviceTypeAttr =
         mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    p << deviceTypeAttr;
+  }
+  p << "]";
+}
+
+static void printDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
+    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
+    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && keywordOnlyDeviceTypes &&
+      keywordOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*keywordOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+
+  printDeviceTypes(p, keywordOnlyDeviceTypes);
+
+  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
+      hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
+  p << ")";
 }
 
 //===----------------------------------------------------------------------===//
@@ -1215,7 +1327,7 @@ static ParseResult parseGangValue(
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
     llvm::SmallVectorImpl<Type> &types,
     llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
-    bool &needComa, bool &newValue) {
+    bool &needCommaBetweenValues, bool &newValue) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
     if (parser.parseEqual())
       return failure();
@@ -1223,7 +1335,7 @@ static ParseResult parseGangValue(
         parser.parseColonType(types.emplace_back()))
       return failure();
     attributes.push_back(gangArgType);
-    needComa = true;
+    needCommaBetweenValues = true;
     newValue = true;
   }
   return success();
@@ -1233,11 +1345,37 @@ static ParseResult parseGangClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
     llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
-    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments) {
-  llvm::SmallVector<GangArgTypeAttr> attributes;
-  llvm::SmallVector<DeviceTypeAttr> deviceTypeAttributes;
+    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
+    mlir::ArrayAttr &gangOnlyDeviceType) {
+  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
   llvm::SmallVector<int32_t> seg;
-  bool needComa = false;
+  bool needCommaBetweenValues = false;
+  bool needCommaBeforeOperands = false;
+
+  // Gang only keyword
+  if (failed(parser.parseOptionalLParen())) {
+    gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    gangOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse gang only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  gangOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
 
   auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                 mlir::acc::GangArgType::Num);
@@ -1247,6 +1385,11 @@ static ParseResult parseGangClause(
       parser.getContext(), mlir::acc::GangArgType::Static);
 
   do {
+    if (needCommaBeforeOperands) {
+      needCommaBeforeOperands = false;
+      continue;
+    }
+
     if (failed(parser.parseLBrace()))
       return failure();
 
@@ -1254,7 +1397,7 @@ static ParseResult parseGangClause(
     while (true) {
       bool newValue = false;
       bool needValue = false;
-      if (needComa) {
+      if (needCommaBetweenValues) {
         if (succeeded(parser.parseOptionalComma()))
           needValue = true; // expect a new value after comma.
         else
@@ -1262,16 +1405,19 @@ static ParseResult parseGangClause(
       }
 
       if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argNum, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argNum,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argDim, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argDim,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argStatic, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argStatic,
+                                needCommaBetweenValues, newValue)))
         return failure();
 
       if (!newValue && needValue) {
@@ -1305,13 +1451,18 @@ static ParseResult parseGangClause(
 
   } while (succeeded(parser.parseOptionalComma()));
 
-  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
-                                               attributes.end());
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
+                                               gangArgTypeAttributes.end());
   gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
+  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
+
+  llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
+      gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
+  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
 
-  llvm::SmallVector<mlir::Attribute> deviceTypeAttr(
-      deviceTypeAttributes.begin(), deviceTypeAttributes.end());
-  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttr);
   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
   return success();
 }
@@ -1320,33 +1471,63 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
                      mlir::OperandRange operands, mlir::TypeRange types,
                      std::optional<mlir::ArrayAttr> gangArgTypes,
                      std::optional<mlir::ArrayAttr> deviceTypes,
-                     std::optional<mlir::DenseI32ArrayAttr> segments) {
-  unsigned opIdx = 0;
-  for (unsigned i = 0; i < deviceTypes->size(); ++i) {
-    if (i != 0)
-      p << ", ";
-    p << "{";
-    for (int32_t j = 0; j < (*segments)[i]; ++j) {
-      if (j != 0)
+                     std::optional<mlir::DenseI32ArrayAttr> segments,
+                     std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && gangOnlyDeviceTypes &&
+      gangOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+  if (hasDeviceTypeValues(gangOnlyDeviceTypes)) {
+    p << "[";
+    for (unsigned i = 0; i < gangOnlyDeviceTypes.value().size(); ++i) {
+      if (i != 0)
         p << ", ";
-      auto gangArgTypeAttr =
-          mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
-      if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
-        p << LoopOp::getGangNumKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
-        p << LoopOp::getGangDimKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static)
-        p << LoopOp::getGangStaticKeyword();
-      p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
-      ++opIdx;
+      auto deviceTypeAttr =
+          mlir::dyn...
[truncated]

@clementval clementval requested a review from jeanPerier January 11, 2024 19:36
Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for implementing this. This really improves readability.

@clementval clementval merged commit 40f5f90 into llvm:main Jan 11, 2024
@clementval clementval deleted the acc_device_type_pretty branch January 11, 2024 21:02
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
llvm#77667)

The IR representation for gang, vector and worker has grown with the
support for device_type. This patch simplify the IR representation for
gang, vector and worker information on the acc.loop operation.

When the only the keyword is present without any values, the information
is printed at the same place than when there is values. The device_type
is omitted if there is no values and it is equal to None. Otherwise the
full information is displayed. First the keyword only device_type
information and then the values with their device_type.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants