Skip to content

[mlir][openacc][NFC] Cleanup hasOnly functions for device_type support #78800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2024

Conversation

clementval
Copy link
Contributor

Just a cleanup for all the has.*Only() function to avoid code duplication

@llvmbot
Copy link
Member

llvmbot commented Jan 19, 2024

@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Just a cleanup for all the has.*Only() function to avoid code duplication


Full diff: https://github.com/llvm/llvm-project/pull/78800.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+49-101)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..a63d6afa0e8532 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -69,6 +69,41 @@ void OpenACCDialect::initialize() {
       *getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// device_type support helpers
+//===----------------------------------------------------------------------===//
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(arrayAttr))
+    return false;
+
+  for (auto attr : *arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+
+  p << "[";
+  llvm::interleaveComma(*deviceTypes, p,
+                        [&](mlir::Attribute attr) { p << attr; });
+  p << "]";
+}
+
 //===----------------------------------------------------------------------===//
 // DataBoundsOp
 //===----------------------------------------------------------------------===//
@@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() {
 }
 
 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value acc::ParallelOp::getAsyncValue() {
@@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() {
 }
 
 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range ParallelOp::getWaitValues() {
@@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
   return success();
 }
 
-static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
-  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
-    return true;
-  return false;
-}
-
-static void printDeviceTypes(mlir::OpAsmPrinter &p,
-                             std::optional<mlir::ArrayAttr> deviceTypes) {
-  if (!hasDeviceTypeValues(deviceTypes))
-    return;
-
-  p << "[";
-  llvm::interleaveComma(*deviceTypes, p,
-                        [&](mlir::Attribute attr) { p << attr; });
-  p << "]";
-}
-
 static void printDeviceTypeOperandsWithKeywordOnly(
     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() {
 }
 
 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value acc::SerialOp::getAsyncValue() {
@@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() {
 }
 
 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range SerialOp::getWaitValues() {
@@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() {
 }
 
 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value acc::KernelsOp::getAsyncValue() {
@@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() {
 }
 
 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range KernelsOp::getWaitValues() {
@@ -1646,11 +1640,7 @@ Value LoopOp::getDataOperand(unsigned i) {
 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAuto_()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAuto_(), deviceType);
 }
 
 bool LoopOp::hasIndependent() {
@@ -1658,21 +1648,13 @@ bool LoopOp::hasIndependent() {
 }
 
 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getIndependent()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getIndependent(), deviceType);
 }
 
 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getSeq()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getSeq(), deviceType);
 }
 
 mlir::Value LoopOp::getVectorValue() {
@@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getVector()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getVector(), deviceType);
 }
 
 mlir::Value LoopOp::getWorkerValue() {
@@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWorker()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWorker(), deviceType);
 }
 
 mlir::Operation::operand_range LoopOp::getTileValues() {
@@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getGang()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getGang(), deviceType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() {
 }
 
 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value DataOp::getAsyncValue() {
@@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
 
 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range DataOp::getWaitValues() {
@@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() {
 // RoutineOp
 //===----------------------------------------------------------------------===//
 
-static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
-                          mlir::acc::DeviceType deviceType) {
-  if (!hasDeviceTypeValues(arrayAttr))
-    return false;
-
-  for (auto attr : *arrayAttr) {
-    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
-    if (deviceTypeAttr.getValue() == deviceType)
-      return true;
-  }
-
-  return false;
-}
-
 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
                                             acc::DeviceType dtype) {
   unsigned parallelism = 0;

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great!

@clementval clementval merged commit ee6199c into llvm:main Jan 22, 2024
@clementval clementval deleted the acc_hasOnly_cleanup branch January 22, 2024 16:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants