Skip to content

Commit 1a8fb88

Browse files
authored
[mlir][mesh] Add resharding spmdization on a 1D device mesh (#76179)
The current implementation supports only sharding of tensor axes that have size divisible by the mesh axis size.
1 parent 7122f55 commit 1a8fb88

File tree

14 files changed

+2004
-55
lines changed

14 files changed

+2004
-55
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
3333
let useDefaultAttributePrinterParser = 1;
3434
let hasConstantMaterializer = 1;
3535
}
36+
37+
def Mesh_MeshAxis : I<16>;
38+
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
39+
3640
//===----------------------------------------------------------------------===//
3741
// Mesh Enums.
3842
//===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
125129

126130
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
127131
// it is also a partial_sum along mesh axis 1.
128-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
132+
tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
129133

130134
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
131135
// it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
158162
}]>
159163
];
160164

165+
let extraClassDeclaration = [{
166+
bool operator==(::mlir::Attribute rhs) const;
167+
bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
168+
}];
169+
161170
let genVerifyDecl = 1;
162171
}
163172

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
namespace mlir {
3131
namespace mesh {
3232

33+
using MeshAxis = int16_t;
34+
using MeshAxesAttr = DenseI16ArrayAttr;
35+
3336
bool isReductionLoop(IteratorType iType);
3437

3538
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
1111

1212
include "mlir/Dialect/Mesh/IR/MeshBase.td"
13+
include "mlir/Dialect/Shape/IR/ShapeBase.td"
1314
include "mlir/Interfaces/InferTypeOpInterface.td"
1415
include "mlir/Interfaces/SideEffectInterfaces.td"
1516
include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
9596
let hasVerifier = 1;
9697
}
9798

99+
def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
100+
let summary = "Get the shape of the cluster.";
101+
let arguments = (ins
102+
FlatSymbolRefAttr:$mesh,
103+
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
104+
);
105+
106+
let results = (outs
107+
Variadic<Index>:$result
108+
);
109+
110+
let assemblyFormat = [{
111+
$mesh (`axes` `=` $axes^)?
112+
attr-dict `:` type($result)
113+
}];
114+
115+
let builders = [
116+
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
117+
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
118+
];
119+
}
120+
98121
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
99122
let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
100123
let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
186209
}];
187210
}
188211

212+
def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
213+
let summary = "Get the index of the current device along the specified mesh axes.";
214+
let description = [{
215+
It is used in the SPMD format of IR.
216+
The `axes` must be non-negative and less than the total number of mesh axes.
217+
}];
218+
let arguments = (ins
219+
FlatSymbolRefAttr:$mesh,
220+
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
221+
);
222+
let results = (outs
223+
Variadic<Index>:$result
224+
);
225+
let assemblyFormat = [{
226+
`on` $mesh (`axes` `=` $axes^)?
227+
attr-dict `:` type($result)
228+
}];
229+
let builders = [
230+
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
231+
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
232+
];
233+
}
234+
189235
//===----------------------------------------------------------------------===//
190236
// collective communication ops
191237
//===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
197243
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
198244
dag commonArgs = (ins
199245
FlatSymbolRefAttr:$mesh,
200-
DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
246+
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
201247
);
202248
}
203249

0 commit comments

Comments
 (0)