Skip to content

[mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp #77838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2024

Conversation

sogartar
Copy link
Contributor

@sogartar sogartar commented Jan 11, 2024

Remove the somewhat redundant rank attribute.
Before this change

mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)

After

mesh.cluster @mesh(shape = 2x3x?)

The rank is instead determined by the provided shape. With this change no longer getDimSizes() can be wrongly assumed to have size equal to the cluster rank.
Now getShape().size() will always equal getRank().

…lusterOp

Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @mesh(rank = 3, shape = 2x3)
```
After
```
mesh.cluster @mesh(shape = 2x3x?)
```

The rank is instead determined by the provided shape.
With this change no longer `getDimSizes()` can be wrongly assumed to have size
equal to the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
@llvmbot llvmbot added the mlir label Jan 11, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 11, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Remove the somewhat redundant rank attribute.
Before this change

mesh.cluster @<!-- -->mesh(rank = 3, shape = 2x3)

After

mesh.cluster @<!-- -->mesh(shape = 2x3x?)

The rank is instead determined by the provided shape. With this change no longer getDimSizes() can be wrongly assumed to have size equal to the cluster rank.
Now getShape().size() will always equal getRank().


Patch is 38.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77838.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+15-25)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+23-33)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+7-9)
  • (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+1-1)
  • (modified) mlir/test/Dialect/Mesh/folding.mlir (+2-2)
  • (modified) mlir/test/Dialect/Mesh/invalid.mlir (+65-70)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+10-10)
  • (modified) mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir (+1-1)
  • (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+2-2)
  • (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+3-3)
  • (modified) mlir/test/Dialect/Mesh/simplifications.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index bda6467e9c5d4b..07f954459ca49d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     Example:
 
     ```
-    mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+    mesh.cluster @mesh0(shape = 2x2x4)
 
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a9068562f5c903..c25996cf122c1c 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -41,7 +41,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     determine the layout and the addressing space of the computation distributed
     across the mesh.
 
-    3. `dim_sizes`: This attribute represents the shape of the device cluster.
+    3. `shape`: This attribute represents the shape of the device cluster.
     It uses the same notation as a tensor shape. Also allowing for dynamic
     dimensions.
     This flexibility allows for dynamic device assignment or configurations
@@ -53,19 +53,19 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     ```
     // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
+    mesh.cluster @mesh0(shape = 4x8x12)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
+    mesh.cluster @mesh1(shape = 4x?)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
+    mesh.cluster @mesh2(shape = ?x4)
 
     // A device mesh cluster with 2 axes, the number of devices along both axes
     // is unknown
-    mesh.cluster @mesh3(rank = 2)
+    mesh.cluster @mesh3(shape = ?x?)
 
     // Used in the mesh sharding attribute to extend the standard tensor to
     // distributed
@@ -74,24 +74,14 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   }];
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    I64Attr:$rank,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
+    DenseI64ArrayAttr:$shape
   );
   let assemblyFormat = [{
-    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
+    $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
       attr-dict
   }];
   let extraClassDeclaration = [{
-    // The `dim_sizes` attribute may have size less than the rank of the mesh.
-    // Returns the shape of the mesh with missing trailing dimensions
-    // explicitly set as dynamic.
-    ::mlir::SmallVector<int64_t> canonicalDimSizes();
-
-    template <typename OutIt>
-    void canonicalDimSizes(OutIt outIt) {
-      std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
-      std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
-    }
+    int64_t getRank() { return getShape().size(); }
   }];
   let hasVerifier = 1;
 }
@@ -283,7 +273,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
@@ -368,7 +358,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+    mesh.cluster @mesh0(shape = 3)
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
@@ -425,7 +415,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
 
     %1 = mesh.broadcast %0 on @mesh0
       mesh_axes = [0]
@@ -481,7 +471,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
       gather_axis = 1 root = [1]
@@ -604,7 +594,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     across the device group.
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
       reduction = <max> scatter_axis = 0
@@ -667,7 +657,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
     %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
       scatter_axis = 0
       root = [1]
@@ -763,7 +753,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+    mesh.cluster @mesh0(shape = 2x4)
     %1 = mesh.shift on @mesh0 mesh_axes = [1]
       shift_axis = 1 offset = 2 rotate
       : tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 957b380efd516b..fa9da596a34587 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -196,17 +196,16 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ClusterOp::verify() {
-  ArrayRef<int64_t> dimSizes = getDimSizes();
-  uint64_t rank = getRank();
+  int64_t rank = getRank();
 
-  if (rank == 0)
+  if (rank <= 0)
     return emitOpError("rank of cluster is expected to be a positive integer");
 
-  if (dimSizes.size() > rank)
+  if (getShape().size() > rank)
     return emitOpError(
-        "rank of dim_sizes is not expected to be larger than rank of cluster");
+        "rank of shape is not expected to be larger than rank of cluster");
 
-  for (int64_t dimSize : dimSizes) {
+  for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh cluster is expected to be "
                          "non-negative or dynamic");
@@ -215,13 +214,6 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
-SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
-  SmallVector<int64_t> result;
-  canonicalDimSizes(std::back_inserter(result));
-  result.reserve(getRank());
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // mesh.cluster_shape op
 //===----------------------------------------------------------------------===//
@@ -614,7 +606,7 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                            gatherAxis, getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+                                           mesh.value().getShape());
 }
 
 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -648,8 +640,7 @@ LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
   return verifyAllToAllOperandAndResultShape(
       getOperand(), getResult(), getSplitAxis().getSExtValue(),
-      getConcatAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().canonicalDimSizes());
+      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
 }
 
 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -667,9 +658,9 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
@@ -690,16 +681,16 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                            getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+                                           mesh.value().getShape());
 }
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -716,10 +707,10 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
-  if (getSource() && failed(verifyInGroupDevice(
-                         getLoc(), getSourceAttrName(), getSource().value(),
-                         getSourceDynamic(), getMeshAxes(), meshShape))) {
+  if (getSource() &&
+      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
+                                 getSource().value(), getSourceDynamic(),
+                                 getMeshAxes(), mesh.value().getShape()))) {
     return failure();
   }
   return success();
@@ -739,9 +730,9 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
@@ -766,7 +757,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
   return verifyScatterOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().canonicalDimSizes());
+      mesh.value().getShape());
 }
 
 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -783,16 +774,16 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
   auto scatterAxis = getScatterAxis().getSExtValue();
   return verifyScatterOperandAndResultShape(getInput(), getResult(),
                                             scatterAxis, getMeshAxes(),
-                                            mesh.value().canonicalDimSizes());
+                                            mesh.value().getShape());
 }
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -809,10 +800,9 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                  getDestination(), getDestinationDynamic(),
-                                 getMeshAxes(), meshShape))) {
+                                 getMeshAxes(), mesh.value().getShape()))) {
     return failure();
   }
   return success();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c9275ad5ad4551..c478b6da4c27b3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -80,7 +80,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
       opMeshAxes = opAxesIota;
     }
     if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
-          return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+          return ShapedType::isDynamic(mesh.getShape()[axis]);
         })) {
       // All mesh dimensions are dynamic. Nothing to fold.
       return failure();
@@ -91,7 +91,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
     SmallVector<size_t> newToOldResultsIndexMap;
 
     for (size_t i = 0; i < opMeshAxes.size(); ++i) {
-      auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+      auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
       if (ShapedType::isDynamic(meshAxisSize)) {
         newToOldResultsIndexMap.push_back(i);
         newShapeOpMeshAxes.push_back(opMeshAxes[i]);
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 0e83c024fc08f8..9478b2e4ee5cb2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -88,8 +88,8 @@ ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
                            MeshShardingAttr sharding) {
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
-  shardShape(shape.getShape(), mesh.canonicalDimSizes(),
-             sharding.getSplitAxes(), resShapeArr);
+  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+             resShapeArr);
   return shape.clone(resShapeArr);
 }
 
@@ -212,9 +212,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
 
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
       ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
-  ShapedType targetShape =
-      targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
-                                 mesh.canonicalDimSizes()[splitMeshAxis]);
+  ShapedType targetShape = targetShapeInSplitLastAxis(
+      sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
 
   Value meshAxisSize =
       builder
@@ -391,8 +390,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
   MeshShardingAttr targetSharding =
       targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
   ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
-      sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
-      splitTensorAxis);
+      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
   Value allGatherResult = builder.create<AllGatherOp>(
       RankedTensorType::get(allGatherResultShape.getShape(),
                             allGatherResultShape.getElementType()),
@@ -526,8 +524,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
   MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
       ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
   ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
-      sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
-      sourceTensorAxis, targetTensorAxis);
+      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+      targetTensorAxis);
   Value allToAllResult = builder.create<AllToAllOp>(
       RankedTensorType::get(allToAllResultShape.getShape(),
                             allToAllResultShape.getElementType()),
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 0a00ab41268d01..4cc009ef24eb3c 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --canonicalize %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes
 func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index dd64d746341b83..9162dc57ecfdf4 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
-mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+mesh.cluster @mesh0(shape = 4x?x2)
+mesh.cluster @mesh1(shape = 2x3)
 
 // CHECK-LABEL: func.func @cluster_shape_op_folding
 func.func @cluster_shape_op_folding() -> (index, index) {
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index f3524a82a1b9d2..8a1fb80065573b 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -1,21 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
 // expected-error@+1 {{rank of cluster is expected to be a positive integer}}
-mesh.cluster @mesh0(rank = 0)
-
-// -----
-
-// expected-error@+1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
+mesh.cluster @mesh0(shape = [])
 
 // -----
 
 // expected-error@+1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
+mesh.cluster @mesh0(shape = -1)
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
     // expected-error@+1 {{mesh axis duplicated}}
@@ -26,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
     // expected-error@+1 {{mesh axis duplicated}}
@@ -37,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
     // expected-error@+1 {{mesh axis duplicated}}
@@ -48,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
     // expected-error@+1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
     // expected-error@+1 {{mesh axis is expected to be non-negative}}
@@ -78,7 +73,7 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
   // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -88,7 +83,7 @@ func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
 
 func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
...
[truncated]

@sogartar
Copy link
Contributor Author

@yaochengji, will you take a look?

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. But pleas wait for @yaochengji to have a look

@@ -41,7 +41,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
determine the layout and the addressing space of the computation distributed
across the mesh.

3. `dim_sizes`: This attribute represents the shape of the device cluster.
3. `shape`: This attribute represents the shape of the device cluster.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also remove the doc for rank and renumber the rest?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thank you. I removed it.

@yaochengji
Copy link
Member

LGTM, Thanks.

@sogartar sogartar merged commit 5df2c00 into llvm:main Jan 15, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…lusterOp (llvm#77838)

Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After
```
mesh.cluster @mesh(shape = 2x3x?)
```

The rank is instead determined by the provided shape. With this change
no longer `getDimSizes()` can be wrongly assumed to have size equal to
the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants