-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[acc] Add attribute for combined constructs #80319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Combined constructs are decomposed into separate operations. However, this does not adhere to `acc` dialect's goal to be able to regenerate semantically equivalent clauses as user's intent. Thus, add an attribute to keep track of the combined constructs.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-openacc Author: Razvan Lupusoru (razvanlupusoru) ChangesCombined constructs are decomposed into separate operations. However, this does not adhere to Full diff: https://github.com/llvm/llvm-project/pull/80319.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index bb3b9617c24ed..941682e6840a0 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -122,15 +122,19 @@ mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp);
/// Used to obtain the attribute name for declare.
static constexpr StringLiteral getDeclareAttrName() {
- return StringLiteral("acc.declare");
+ return DeclareAttr::name;
}
static constexpr StringLiteral getDeclareActionAttrName() {
- return StringLiteral("acc.declare_action");
+ return DeclareActionAttr::name;
}
static constexpr StringLiteral getRoutineInfoAttrName() {
- return StringLiteral("acc.routine_info");
+ return RoutineInfoAttr::name;
+}
+
+static constexpr StringLiteral getCombinedConstructsAttrName() {
+ return CombinedConstructsTypeAttr::name;
}
struct RuntimeCounters
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 9398cbfdacee4..24acc66bf9497 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -218,6 +218,24 @@ def GangArgTypeArrayAttr :
let constBuilderCall = ?;
}
+// Combined constructs enumerations
+def OpenACC_KernelsLoop : I32EnumAttrCase<"KernelsLoop", 1, "kernels_loop">;
+def OpenACC_ParallelLoop : I32EnumAttrCase<"ParallelLoop", 2, "parallel_loop">;
+def OpenACC_SerialLoop : I32EnumAttrCase<"SerialLoop", 3, "serial_loop">;
+
+def OpenACC_CombinedConstructsType : I32EnumAttr<"CombinedConstructsType",
+ "Differentiate between combined constructs",
+ [OpenACC_KernelsLoop, OpenACC_ParallelLoop, OpenACC_SerialLoop]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::acc";
+}
+
+def OpenACC_CombinedConstructsAttr : EnumAttr<OpenACC_Dialect,
+ OpenACC_CombinedConstructsType,
+ "combined_constructs"> {
+ let assemblyFormat = [{ ```<` $value `>` }];
+}
+
// Define a resource for the OpenACC runtime counters.
def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;
@@ -928,7 +946,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ OptionalAttr<DefaultValueAttr>:$defaultAttr,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
let regions = (region AnyRegion:$region);
@@ -989,7 +1008,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
let assemblyFormat = [{
oilist(
- `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1059,7 +1079,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ OptionalAttr<DefaultValueAttr>:$defaultAttr,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
let regions = (region AnyRegion:$region);
@@ -1101,7 +1122,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
let assemblyFormat = [{
oilist(
- `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1168,7 +1190,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
Optional<I1>:$selfCond,
UnitAttr:$selfAttr,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ OptionalAttr<DefaultValueAttr>:$defaultAttr,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
let regions = (region AnyRegion:$region);
@@ -1229,7 +1252,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
let assemblyFormat = [{
oilist(
- `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `num_gangs` `(` custom<NumGangs>($numGangs,
@@ -1550,7 +1574,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<AnyType>:$reductionOperands,
- OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes
+ OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined
);
let results = (outs Variadic<AnyType>:$results);
@@ -1642,7 +1667,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
oilist(
- `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
$gangOperandsArgType, $gangOperandsDeviceType,
$gangOperandsSegments, $gang)
| `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index ae5da686f8595..a020e6d34aba9 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -815,6 +815,11 @@ LogicalResult acc::ParallelOp::verify() {
if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
return failure();
+ if (getCombined().has_value() &&
+ getCombined().value() != acc::CombinedConstructsType::ParallelLoop) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
@@ -1285,6 +1290,45 @@ static void printDeviceTypeOperandsWithKeywordOnly(
p << ")";
}
+static ParseResult
+parseCombinedConstructs(mlir::OpAsmParser &parser,
+ mlir::acc::CombinedConstructsTypeAttr &attr) {
+ // Just parsing first keyword we know which type of combined construct it is.
+ if (succeeded(parser.parseOptionalKeyword("kernels"))) {
+ attr = mlir::acc::CombinedConstructsTypeAttr::get(
+ parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
+ } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
+ attr = mlir::acc::CombinedConstructsTypeAttr::get(
+ parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
+ } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
+ attr = mlir::acc::CombinedConstructsTypeAttr::get(
+ parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
+ } else {
+ parser.emitError(parser.getCurrentLocation(),
+ "expected compute construct name for combined constructs");
+ return failure();
+ }
+
+ // Ensure that the `loop` wording follows the compute construct.
+ return parser.parseKeyword("loop");
+}
+
+static void
+printCombinedConstructs(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::acc::CombinedConstructsTypeAttr attr) {
+ switch (attr.getValue()) {
+ case mlir::acc::CombinedConstructsType::KernelsLoop:
+ p << "kernels loop";
+ break;
+ case mlir::acc::CombinedConstructsType::ParallelLoop:
+ p << "parallel loop";
+ break;
+ case mlir::acc::CombinedConstructsType::SerialLoop:
+ p << "serial loop";
+ break;
+ };
+}
+
//===----------------------------------------------------------------------===//
// SerialOp
//===----------------------------------------------------------------------===//
@@ -1370,6 +1414,11 @@ LogicalResult acc::SerialOp::verify() {
if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
return failure();
+ if (getCombined().has_value() &&
+ getCombined().value() != acc::CombinedConstructsType::SerialLoop) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
@@ -1497,6 +1546,11 @@ LogicalResult acc::KernelsOp::verify() {
if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
return failure();
+ if (getCombined().has_value() &&
+ getCombined().value() != acc::CombinedConstructsType::KernelsLoop) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
@@ -1854,6 +1908,13 @@ LogicalResult acc::LoopOp::verify() {
"reductions", false)))
return failure();
+ if (getCombined().has_value() &&
+ (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
+ getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
+ getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
// Check non-empty body().
if (getRegion().empty())
return emitError("expected non-empty body.");
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 16df33eec642c..48cbddae071ba 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -738,3 +738,33 @@ func.func @acc_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>,
acc.terminator
}
}
+
+// -----
+
+func.func @acc_combined() {
+ // expected-error @below {{expected compute construct name for combined constructs}}
+ acc.parallel combined() {
+ }
+
+ return
+}
+
+// -----
+
+func.func @acc_combined() {
+ // expected-error @below {{expected 'loop'}}
+ acc.parallel combined(parallel) {
+ }
+
+ return
+}
+
+// -----
+
+func.func @acc_combined() {
+ // expected-error @below {{unexpected combined constructs attribute}}
+ acc.parallel combined(kernels loop) {
+ }
+
+ return
+}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 4e6ed8645cdbc..a10b603e8a07b 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1846,9 +1846,49 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
// -----
-%c2 = arith.constant 2 : i32
-%c1 = arith.constant 1 : i32
-acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+// CHECK-LABEL: func.func @acc_num_gangs
+func.func @acc_num_gangs() {
+ %c2 = arith.constant 2 : i32
+ %c1 = arith.constant 1 : i32
+ acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+ }
+
+ return
}
// CHECK: acc.parallel num_gangs({%c2{{.*}} : i32} [#acc.device_type<default>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+// -----
+
+// CHECK-LABEL: func.func @acc_combined
+func.func @acc_combined() {
+ acc.parallel combined(parallel loop) {
+ acc.loop combined(parallel loop) {
+ acc.yield
+ }
+ acc.terminator
+ }
+
+ acc.kernels combined(kernels loop) {
+ acc.loop combined(kernels loop) {
+ acc.yield
+ }
+ acc.terminator
+ }
+
+ acc.serial combined(serial loop) {
+ acc.loop combined(serial loop) {
+ acc.yield
+ }
+ acc.terminator
+ }
+
+ return
+}
+
+// CHECK: acc.parallel combined(parallel loop)
+// CHECK: acc.loop combined(parallel loop)
+// CHECK: acc.kernels combined(kernels loop)
+// CHECK: acc.loop combined(kernels loop)
+// CHECK: acc.serial combined(serial loop)
+// CHECK: acc.loop combined(serial loop)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this information. I have couple of comments and suggestions.
…nd reduce printing verbosity
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you, Razvan!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing my comments.
Combined constructs are decomposed into separate operations. However, this does not adhere to
acc
dialect's goal to be able to regenerate semantically equivalent clauses as user's intent. Thus, add an attribute to keep track of the combined constructs.