-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][openacc][NFC] Cleanup hasOnly functions for device_type support #78800
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-openacc @llvm/pr-subscribers-mlir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesJust a cleanup for all the Full diff: https://github.com/llvm/llvm-project/pull/78800.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..a63d6afa0e8532 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -69,6 +69,41 @@ void OpenACCDialect::initialize() {
*getContext());
}
+//===----------------------------------------------------------------------===//
+// device_type support helpers
+//===----------------------------------------------------------------------===//
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+ if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+ return true;
+ return false;
+}
+
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+ mlir::acc::DeviceType deviceType) {
+ if (!hasDeviceTypeValues(arrayAttr))
+ return false;
+
+ for (auto attr : *arrayAttr) {
+ auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+ if (deviceTypeAttr.getValue() == deviceType)
+ return true;
+ }
+
+ return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+ std::optional<mlir::ArrayAttr> deviceTypes) {
+ if (!hasDeviceTypeValues(deviceTypes))
+ return;
+
+ p << "[";
+ llvm::interleaveComma(*deviceTypes, p,
+ [&](mlir::Attribute attr) { p << attr; });
+ p << "]";
+}
+
//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
@@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() {
}
bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value acc::ParallelOp::getAsyncValue() {
@@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() {
}
bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range ParallelOp::getWaitValues() {
@@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
return success();
}
-static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
- if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
- return true;
- return false;
-}
-
-static void printDeviceTypes(mlir::OpAsmPrinter &p,
- std::optional<mlir::ArrayAttr> deviceTypes) {
- if (!hasDeviceTypeValues(deviceTypes))
- return;
-
- p << "[";
- llvm::interleaveComma(*deviceTypes, p,
- [&](mlir::Attribute attr) { p << attr; });
- p << "]";
-}
-
static void printDeviceTypeOperandsWithKeywordOnly(
mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() {
}
bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value acc::SerialOp::getAsyncValue() {
@@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() {
}
bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range SerialOp::getWaitValues() {
@@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() {
}
bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value acc::KernelsOp::getAsyncValue() {
@@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() {
}
bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range KernelsOp::getWaitValues() {
@@ -1646,11 +1640,7 @@ Value LoopOp::getDataOperand(unsigned i) {
bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAuto_()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAuto_(), deviceType);
}
bool LoopOp::hasIndependent() {
@@ -1658,21 +1648,13 @@ bool LoopOp::hasIndependent() {
}
bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getIndependent()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getIndependent(), deviceType);
}
bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getSeq()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getSeq(), deviceType);
}
mlir::Value LoopOp::getVectorValue() {
@@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getVector()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getVector(), deviceType);
}
mlir::Value LoopOp::getWorkerValue() {
@@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWorker()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWorker(), deviceType);
}
mlir::Operation::operand_range LoopOp::getTileValues() {
@@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getGang()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getGang(), deviceType);
}
//===----------------------------------------------------------------------===//
@@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() {
}
bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getAsyncOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getAsyncOnly(), deviceType);
}
mlir::Value DataOp::getAsyncValue() {
@@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
- if (auto arrayAttr = getWaitOnly()) {
- if (findSegment(*arrayAttr, deviceType))
- return true;
- }
- return false;
+ return hasDeviceType(getWaitOnly(), deviceType);
}
mlir::Operation::operand_range DataOp::getWaitValues() {
@@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() {
// RoutineOp
//===----------------------------------------------------------------------===//
-static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
- mlir::acc::DeviceType deviceType) {
- if (!hasDeviceTypeValues(arrayAttr))
- return false;
-
- for (auto attr : *arrayAttr) {
- auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
- if (deviceTypeAttr.getValue() == deviceType)
- return true;
- }
-
- return false;
-}
-
static unsigned getParallelismForDeviceType(acc::RoutineOp op,
acc::DeviceType dtype) {
unsigned parallelism = 0;
|
razvanlupusoru
approved these changes
Jan 22, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Just a cleanup for all the
has.*Only()
function to avoid code duplication