[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr #76886

Merged 2 commits on Jan 5, 2024

sogartar (Contributor) commented Jan 4, 2024

Analogous to func.call, use FlatSymbolRefAttr to reference the corresponding mesh.
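
To illustrate the distinction, here is a minimal sketch that uses a toy stand-in rather than MLIR's real attribute classes. `ToySymbolRef` and `parseToySymbolRef` are hypothetical names introduced only for this example. The one fact it relies on is that a flat reference is a symbol reference with no nested references, written `@mesh0` as opposed to `@outer::@mesh0`.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Toy stand-in for a symbol reference: a root name plus optional nested names.
// "@mesh0" has no nested names; "@outer::@mesh0" has one.
struct ToySymbolRef {
  std::string root;
  std::vector<std::string> nested;

  // A reference is "flat" when it names a symbol directly, with no nesting.
  bool isFlat() const { return nested.empty(); }
};

// Parse text of the form "@a" or "@a::@b::@c".
// Returns std::nullopt on malformed input (missing '@' or an empty name).
std::optional<ToySymbolRef> parseToySymbolRef(std::string_view text) {
  std::vector<std::string> parts;
  while (true) {
    if (text.empty() || text.front() != '@')
      return std::nullopt;
    text.remove_prefix(1);
    std::size_t sep = text.find("::");
    std::string_view name = text.substr(0, sep);
    if (name.empty())
      return std::nullopt;
    parts.emplace_back(name);
    if (sep == std::string_view::npos)
      break;
    text.remove_prefix(sep + 2);
  }
  ToySymbolRef ref;
  ref.root = parts.front();
  ref.nested.assign(parts.begin() + 1, parts.end());
  return ref;
}

// Example use:
//   assert(parseToySymbolRef("@mesh0")->isFlat());
```

Under this model, restricting the sharding attribute to flat references means only the first form is accepted for the mesh.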

sogartar requested a review from joker-eph January 4, 2024 01:45
llvmbot added the mlir label Jan 4, 2024

llvmbot (Member) commented Jan 4, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Analogous to func.call, use FlatSymbolRefAttr to reference the corresponding mesh.


Full diff: https://github.com/llvm/llvm-project/pull/76886.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+4-4)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+2-2)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 060d54b82efa63..bda6467e9c5d4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   let mnemonic = "shard";
 
   let parameters = (ins
-    AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
+    AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
     ArrayRefParameter<"MeshAxesAttr">:$split_axes,
     OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     The MeshSharding attribute could be used in the encoding of a
     `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
 
-    1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
+    1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
     cluster where the distributed tensor is placed. The symbol must resolve to a
     `mesh.cluster` operation.
 
@@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   }];
 
   let builders = [
-    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
                      "ArrayRef<MeshAxis>": $partial_axes,
                      "mesh::Partial": $partial_type), [{
@@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
       return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
                    partial_type);
     }]>,
-    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
       return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
     }]>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1934bdfb427059..f459077ea12022 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -196,12 +196,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
     ```
   }];
   let arguments = (ins
-    Builtin_RankedTensor:$src,
+    AnyRankedTensor:$src,
     MeshSharding:$shard,
     UnitAttr:$annotate_for_users
   );
   let results = (outs
-    Builtin_RankedTensor:$result
+    AnyRankedTensor:$result
   );
   let assemblyFormat = [{
     $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 201c0151754eba..a32274d857f15d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,13 +25,13 @@ struct ShardingOption {
   // An array of int array. The sub-array at the i-th position signifies the
   // mesh axes the i-th loop will be sharded on.
   ShardingArray shardingArray = {};
-  SymbolRefAttr cluster = nullptr;
+  FlatSymbolRefAttr cluster = nullptr;
   // `empty` being true indicates that no sharding information can be inferred
   // at present. Note that it is different from the case where an operation is
   // not sharded.
   bool empty = false;
   ShardingOption() = default;
-  ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
+  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
       : shardingArray(std::move(shardingArray)), cluster(cluster) {}
 };
 
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c3d8f1d456106d..6667d409df8b78 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,7 +266,7 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+                         FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
                          ArrayRef<MeshAxis> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index ee885ab16b7b06..dca7e86e6f07f5 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,7 +215,7 @@ namespace {
 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
 static LogicalResult fillShardingOption(Operation *op,
                                         ShardingOption &shardingOption,
-                                        SymbolRefAttr cluster,
+                                        FlatSymbolRefAttr cluster,
                                         ArrayRef<MeshAxis> meshAxes,
                                         unsigned loopIdx) {
   if ((shardingOption.cluster && cluster &&

sogartar (Contributor, Author) commented Jan 4, 2024

@yaochengji, can you take a look at this PR?

joker-eph (Collaborator) commented

Do you have a motivation for this?

Also can you add a test rejecting the nested structure?

sogartar (Contributor, Author) commented Jan 4, 2024

> Do you have a motivation for this?

I cargo-culted the flat symbol refs from func.call for the other collective ops, and I wanted the Mesh dialect to be consistent. I thought that flat symbols are less complicated, and I could not think of a good reason for nested symbols right now.

sogartar (Contributor, Author) commented Jan 5, 2024

> Also can you add a test rejecting the nested structure?

Just added a test.
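
The added test itself is not shown in this conversation. As a rough sketch of the kind of check such a negative test exercises, assume a hypothetical helper `checkFlatClusterRef` that works on the textual form of a reference rather than on MLIR attributes. Both the helper and its diagnostic wording are illustrative assumptions, not MLIR's actual messages.

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Hypothetical check (not the MLIR parser or verifier): returns an empty
// string when `ref` is a flat reference such as "@mesh0", and a diagnostic
// message otherwise. The messages are made up for this sketch.
std::string checkFlatClusterRef(std::string_view ref) {
  if (ref.size() < 2 || ref.front() != '@')
    return "expected a symbol reference starting with '@'";
  // A nested reference chains names with "::", e.g. "@outer::@mesh0".
  if (ref.find("::") != std::string_view::npos)
    return "expected a flat symbol reference, got nested reference '" +
           std::string(ref) + "'";
  return {};
}

// Example use:
//   assert(checkFlatClusterRef("@mesh0").empty());
```

A negative test then feeds in a nested reference and expects a non-empty diagnostic, while the flat form passes.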

yaochengji (Member) commented

LGTM, thanks

sogartar merged commit fc18b13 into llvm:main Jan 5, 2024
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Jan 26, 2024
…be57544a8

Local branch amd-gfx 599be57 Merged main:f7f7574afe4cfc11ebe5d8cb811d5cd28dc862f6 into amd-gfx:1656d3862359
Remote branch main fc18b13 [mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (llvm#76886)