10
10
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
11
11
12
12
include "mlir/Dialect/Mesh/IR/MeshBase.td"
13
+ include "mlir/Dialect/Shape/IR/ShapeBase.td"
13
14
include "mlir/Interfaces/InferTypeOpInterface.td"
14
15
include "mlir/Interfaces/SideEffectInterfaces.td"
15
16
include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
95
96
let hasVerifier = 1;
96
97
}
97
98
99
+ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
100
+ let summary = "Get the shape of the cluster.";
101
+ let arguments = (ins
102
+ FlatSymbolRefAttr:$mesh,
103
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
104
+ );
105
+
106
+ let results = (outs
107
+ Variadic<Index>:$result
108
+ );
109
+
110
+ let assemblyFormat = [{
111
+ $mesh (`axes` `=` $axes^)?
112
+ attr-dict `:` type($result)
113
+ }];
114
+
115
+ let builders = [
116
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
117
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
118
+ ];
119
+ }
120
+
98
121
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
99
122
let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
100
123
let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
186
209
}];
187
210
}
188
211
212
+ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
213
+ let summary = "Get the index of current device along specified mesh axis.";
214
+ let description = [{
215
+ It is used in the SPMD format of IR.
216
+ The `axes` mush be non-negative and less than the total number of mesh axes.
217
+ }];
218
+ let arguments = (ins
219
+ FlatSymbolRefAttr:$mesh,
220
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
221
+ );
222
+ let results = (outs
223
+ Variadic<Index>:$result
224
+ );
225
+ let assemblyFormat = [{
226
+ `on` $mesh (`axes` `=` $axes^)?
227
+ attr-dict `:` type($result)
228
+ }];
229
+ let builders = [
230
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
231
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
232
+ ];
233
+ }
234
+
189
235
//===----------------------------------------------------------------------===//
190
236
// collective communication ops
191
237
//===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
197
243
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
198
244
dag commonArgs = (ins
199
245
FlatSymbolRefAttr:$mesh,
200
- DefaultValuedAttr<DenseI16ArrayAttr , "{}">:$mesh_axes
246
+ DefaultValuedAttr<Mesh_MeshAxesAttr , "{}">:$mesh_axes
201
247
);
202
248
}
203
249
0 commit comments