Skip to content

[mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp #77838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
Example:

```
mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
mesh.cluster @mesh0(shape = 2x2x4)

// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
Expand Down
45 changes: 15 additions & 30 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
cluster. This name serves as a symbolic reference to the cluster throughout
the MLIR module, allowing for consistent referencing and easier debugging.

2. `rank`: This attribute specifies the number of axes of the cluster. The
rank indicates the dimensionality of the mesh cluster and can be used to
determine the layout and the addressing space of the computation distributed
across the mesh.

3. `dim_sizes`: This attribute represents the shape of the device cluster.
2. `shape`: This attribute represents the shape of the device cluster.
It uses the same notation as a tensor shape. Also allowing for dynamic
dimensions.
This flexibility allows for dynamic device assignment or configurations
Expand All @@ -53,19 +48,19 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
```
// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
mesh.cluster @mesh0(shape = 4x8x12)

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
mesh.cluster @mesh1(shape = 4x?)

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
mesh.cluster @mesh2(shape = ?x4)

// A device mesh cluster with 2 axes, the number of devices along both axes
// is unknown
mesh.cluster @mesh3(rank = 2)
mesh.cluster @mesh3(shape = ?x?)

// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
Expand All @@ -74,24 +69,14 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
I64Attr:$rank,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
DenseI64ArrayAttr:$shape
);
let assemblyFormat = [{
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
$sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
attr-dict
}];
let extraClassDeclaration = [{
// The `dim_sizes` attribute may have size less than the rank of the mesh.
// Returns the shape of the mesh with missing trailing dimensions
// explicitly set as dynamic.
::mlir::SmallVector<int64_t> canonicalDimSizes();

template <typename OutIt>
void canonicalDimSizes(OutIt outIt) {
std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
}
int64_t getRank() { return getShape().size(); }
}];
let hasVerifier = 1;
}
Expand Down Expand Up @@ -283,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [

Example:
```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
Expand Down Expand Up @@ -368,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [

Example:
```
mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
mesh.cluster @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
Expand Down Expand Up @@ -425,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [

Example:
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
mesh.cluster @mesh0(shape = 2x2)

%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
Expand Down Expand Up @@ -481,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [

Example:
```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
Expand Down Expand Up @@ -604,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
across the device group.
Example:
```
mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <max> scatter_axis = 0
Expand Down Expand Up @@ -667,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [

Example:
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
mesh.cluster @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
Expand Down Expand Up @@ -763,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [

Example:
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
mesh.cluster @mesh0(shape = 2x4)
%1 = mesh.shift on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
Expand Down
56 changes: 23 additions & 33 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,16 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
//===----------------------------------------------------------------------===//

LogicalResult ClusterOp::verify() {
ArrayRef<int64_t> dimSizes = getDimSizes();
uint64_t rank = getRank();
int64_t rank = getRank();

if (rank == 0)
if (rank <= 0)
return emitOpError("rank of cluster is expected to be a positive integer");

if (dimSizes.size() > rank)
if (getShape().size() > rank)
return emitOpError(
"rank of dim_sizes is not expected to be larger than rank of cluster");
"rank of shape is not expected to be larger than rank of cluster");

for (int64_t dimSize : dimSizes) {
for (int64_t dimSize : getShape()) {
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
return emitOpError("dimension size of a mesh cluster is expected to be "
"non-negative or dynamic");
Expand All @@ -215,13 +214,6 @@ LogicalResult ClusterOp::verify() {
return success();
}

SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
SmallVector<int64_t> result;
canonicalDimSizes(std::back_inserter(result));
result.reserve(getRank());
return result;
}

//===----------------------------------------------------------------------===//
// mesh.cluster_shape op
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -614,7 +606,7 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
gatherAxis, getMeshAxes(),
mesh.value().canonicalDimSizes());
mesh.value().getShape());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
Expand Down Expand Up @@ -648,8 +640,7 @@ LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

return verifyAllToAllOperandAndResultShape(
getOperand(), getResult(), getSplitAxis().getSExtValue(),
getConcatAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
Expand All @@ -667,9 +658,9 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
getRootDynamic(), getMeshAxes(), meshShape))) {
getRootDynamic(), getMeshAxes(),
mesh.value().getShape()))) {
return failure();
}

Expand All @@ -690,16 +681,16 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
getRootDynamic(), getMeshAxes(), meshShape))) {
getRootDynamic(), getMeshAxes(),
mesh.value().getShape()))) {
return failure();
}

auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
getMeshAxes(),
mesh.value().canonicalDimSizes());
mesh.value().getShape());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
Expand All @@ -716,10 +707,10 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (getSource() && failed(verifyInGroupDevice(
getLoc(), getSourceAttrName(), getSource().value(),
getSourceDynamic(), getMeshAxes(), meshShape))) {
if (getSource() &&
failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
getSource().value(), getSourceDynamic(),
getMeshAxes(), mesh.value().getShape()))) {
return failure();
}
return success();
Expand All @@ -739,9 +730,9 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
getRootDynamic(), getMeshAxes(), meshShape))) {
getRootDynamic(), getMeshAxes(),
mesh.value().getShape()))) {
return failure();
}

Expand All @@ -766,7 +757,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

return verifyScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
mesh.value().getShape());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
Expand All @@ -783,16 +774,16 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
getRootDynamic(), getMeshAxes(), meshShape))) {
getRootDynamic(), getMeshAxes(),
mesh.value().getShape()))) {
return failure();
}

auto scatterAxis = getScatterAxis().getSExtValue();
return verifyScatterOperandAndResultShape(getInput(), getResult(),
scatterAxis, getMeshAxes(),
mesh.value().canonicalDimSizes());
mesh.value().getShape());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
Expand All @@ -809,10 +800,9 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
getDestination(), getDestinationDynamic(),
getMeshAxes(), meshShape))) {
getMeshAxes(), mesh.value().getShape()))) {
return failure();
}
return success();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
opMeshAxes = opAxesIota;
}
if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
return ShapedType::isDynamic(mesh.getShape()[axis]);
})) {
// All mesh dimensions are dynamic. Nothing to fold.
return failure();
Expand All @@ -91,7 +91,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
SmallVector<size_t> newToOldResultsIndexMap;

for (size_t i = 0; i < opMeshAxes.size(); ++i) {
auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
if (ShapedType::isDynamic(meshAxisSize)) {
newToOldResultsIndexMap.push_back(i);
newShapeOpMeshAxes.push_back(opMeshAxes[i]);
Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
MeshShardingAttr sharding) {
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
shardShape(shape.getShape(), mesh.canonicalDimSizes(),
sharding.getSplitAxes(), resShapeArr);
shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
resShapeArr);
return shape.clone(resShapeArr);
}

Expand Down Expand Up @@ -212,9 +212,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,

MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
ShapedType targetShape =
targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
mesh.canonicalDimSizes()[splitMeshAxis]);
ShapedType targetShape = targetShapeInSplitLastAxis(
sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);

Value meshAxisSize =
builder
Expand Down Expand Up @@ -391,8 +390,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
splitTensorAxis);
sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
Value allGatherResult = builder.create<AllGatherOp>(
RankedTensorType::get(allGatherResultShape.getShape(),
allGatherResultShape.getElementType()),
Expand Down Expand Up @@ -526,8 +524,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
sourceTensorAxis, targetTensorAxis);
sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
targetTensorAxis);
Value allToAllResult = builder.create<AllToAllOp>(
RankedTensorType::get(allToAllResultShape.getShape(),
allToAllResultShape.getElementType()),
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Mesh/canonicalization.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s

mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
mesh.cluster @mesh0(shape = 2x4)

// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Mesh/folding.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s

mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
mesh.cluster @mesh0(shape = 4x?x2)
mesh.cluster @mesh1(shape = 2x3)

// CHECK-LABEL: func.func @cluster_shape_op_folding
func.func @cluster_shape_op_folding() -> (index, index) {
Expand Down
Loading