Skip to content

Commit fc18b13

Browse files
authored
[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (#76886)
Analogous to func.call, use FlatSymbolRefAttr to reference the corresponding mesh.
1 parent 10b03e6 commit fc18b13

File tree

6 files changed

+18
-10
lines changed

6 files changed

+18
-10
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
7979
let mnemonic = "shard";
8080

8181
let parameters = (ins
82-
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
82+
AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
8383
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
8484
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
8585
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
9191
The MeshSharding attribute could be used in the encoding of a
9292
`RankedTensorType` or the mesh.shard op. It contains three sub-attributes:
9393

94-
1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
94+
1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
9595
cluster where the distributed tensor is placed. The symbol must resolve to a
9696
`mesh.cluster` operation.
9797

@@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
145145
}];
146146

147147
let builders = [
148-
AttrBuilder<(ins "SymbolRefAttr":$cluster,
148+
AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
149149
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
150150
"ArrayRef<MeshAxis>": $partial_axes,
151151
"mesh::Partial": $partial_type), [{
@@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
156156
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
157157
partial_type);
158158
}]>,
159-
AttrBuilder<(ins "SymbolRefAttr":$cluster,
159+
AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
160160
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
161161
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
162162
}]>

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
196196
```
197197
}];
198198
let arguments = (ins
199-
Builtin_RankedTensor:$src,
199+
AnyRankedTensor:$src,
200200
MeshSharding:$shard,
201201
UnitAttr:$annotate_for_users
202202
);
203203
let results = (outs
204-
Builtin_RankedTensor:$result
204+
AnyRankedTensor:$result
205205
);
206206
let assemblyFormat = [{
207207
$src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ struct ShardingOption {
2525
// An array of int arrays. The sub-array at the i-th position signifies the
2626
// mesh axes the i-th loop will be sharded on.
2727
ShardingArray shardingArray = {};
28-
SymbolRefAttr cluster = nullptr;
28+
FlatSymbolRefAttr cluster = nullptr;
2929
// `empty` being true indicates that no sharding information can be inferred
3030
// at present. Note that it is different from the case where an operation is
3131
// not sharded.
3232
bool empty = false;
3333
ShardingOption() = default;
34-
ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
34+
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
3535
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
3636
};
3737

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
266266

267267
LogicalResult
268268
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
269-
SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
269+
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
270270
ArrayRef<MeshAxis> partialAxes, Partial) {
271271
// TODO: At present cluster symbol ref is not verified. This is due to the
272272
// difficulty in fetching the corresponding symbol op based on an attribute.

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ namespace {
215215
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
216216
static LogicalResult fillShardingOption(Operation *op,
217217
ShardingOption &shardingOption,
218-
SymbolRefAttr cluster,
218+
FlatSymbolRefAttr cluster,
219219
ArrayRef<MeshAxis> meshAxes,
220220
unsigned loopIdx) {
221221
if ((shardingOption.cluster && cluster &&

mlir/test/Dialect/Mesh/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ func.func @mesh_axis_negtive_in_partial(
7070

7171
// -----
7272

73+
func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
74+
// expected-error@+2 {{custom op 'mesh.shard' invalid kind of attribute specified}}
75+
// expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'cluster' which is to be a `::mlir::FlatSymbolRefAttr`}}
76+
%0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32>
77+
}
78+
79+
// -----
80+
7381
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
7482

7583
func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {

0 commit comments

Comments
 (0)