-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr #76886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…RefAttr Analogous to func.call use FlatSymbolRefAttr to reference the corresponding mesh.
@llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesAnalogous to func.call use FlatSymbolRefAttr to reference the corresponding mesh. Full diff: https://github.com/llvm/llvm-project/pull/76886.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 060d54b82efa63..bda6467e9c5d4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let mnemonic = "shard";
let parameters = (ins
- AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
+ AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
The MeshSharding attribute could be used in the encoding of a
`RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
- 1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
+ 1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
cluster where the distributed tensor is placed. The symbol must resolve to a
`mesh.cluster` operation.
@@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}];
let builders = [
- AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
@@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
partial_type);
}]>,
- AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
}]>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1934bdfb427059..f459077ea12022 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -196,12 +196,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
```
}];
let arguments = (ins
- Builtin_RankedTensor:$src,
+ AnyRankedTensor:$src,
MeshSharding:$shard,
UnitAttr:$annotate_for_users
);
let results = (outs
- Builtin_RankedTensor:$result
+ AnyRankedTensor:$result
);
let assemblyFormat = [{
$src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 201c0151754eba..a32274d857f15d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,13 +25,13 @@ struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
// mesh axes the i-th loop will be sharded on.
ShardingArray shardingArray = {};
- SymbolRefAttr cluster = nullptr;
+ FlatSymbolRefAttr cluster = nullptr;
// `empty` being true indicates that no sharding information can be inferred
// at present. Note that it is different from the case where an operation is
// not sharded.
bool empty = false;
ShardingOption() = default;
- ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
+ ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
};
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c3d8f1d456106d..6667d409df8b78 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,7 +266,7 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+ FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
ArrayRef<MeshAxis> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index ee885ab16b7b06..dca7e86e6f07f5 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,7 +215,7 @@ namespace {
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
- SymbolRefAttr cluster,
+ FlatSymbolRefAttr cluster,
ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) {
if ((shardingOption.cluster && cluster &&
|
@yaochengji, can you take a look at this PR? |
Do you have a motivation for this? Also can you add a test rejecting the nested structure? |
I cargo culted from |
Just added a test. |
LGTM, thanks |
…be57544a8 Local branch amd-gfx 599be57 Merged main:f7f7574afe4cfc11ebe5d8cb811d5cd28dc862f6 into amd-gfx:1656d3862359 Remote branch main fc18b13 [mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (llvm#76886)
Analogous to func.call use FlatSymbolRefAttr to reference the corresponding mesh.