Skip to content

[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr #76886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let mnemonic = "shard";

let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
Expand All @@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
The MeshSharding attribute could be used in the encoding of a
`RankedTensorType` or the mesh.shard op. it contains three sub-attributes:

1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
cluster where the distributed tensor is placed. The symbol must resolve to a
`mesh.cluster` operation.

Expand Down Expand Up @@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}];

let builders = [
AttrBuilder<(ins "SymbolRefAttr":$cluster,
AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
Expand All @@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
partial_type);
}]>,
AttrBuilder<(ins "SymbolRefAttr":$cluster,
AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
}]>
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
```
}];
let arguments = (ins
Builtin_RankedTensor:$src,
AnyRankedTensor:$src,
MeshSharding:$shard,
UnitAttr:$annotate_for_users
);
let results = (outs
Builtin_RankedTensor:$result
AnyRankedTensor:$result
);
let assemblyFormat = [{
$src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
// mesh axes the i-th loop will be sharded on.
ShardingArray shardingArray = {};
SymbolRefAttr cluster = nullptr;
FlatSymbolRefAttr cluster = nullptr;
// `empty` being true indicates that no sharding information can be inferred
// at present. Note that it is different from the case where an operation is
// not sharded.
bool empty = false;
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
};

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,

LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
ArrayRef<MeshAxis> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ namespace {
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
SymbolRefAttr cluster,
FlatSymbolRefAttr cluster,
ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) {
if ((shardingOption.cluster && cluster &&
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Mesh/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ func.func @mesh_axis_negtive_in_partial(

// -----

func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
// expected-error@+2 {{custom op 'mesh.shard' invalid kind of attribute specified}}
// expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'cluster' which is to be a `::mlir::FlatSymbolRefAttr`}}
%0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)

func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
Expand Down