11
11
12
12
include "mlir/Dialect/Mesh/IR/MeshBase.td"
13
13
include "mlir/Dialect/Shape/IR/ShapeBase.td"
14
+ include "mlir/Interfaces/DestinationStyleOpInterface.td"
14
15
include "mlir/Interfaces/InferTypeOpInterface.td"
15
16
include "mlir/Interfaces/SideEffectInterfaces.td"
16
17
include "mlir/IR/BuiltinTypes.td"
@@ -189,23 +190,27 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
189
190
`generic`: is not an allowed value inside a shard attribute.
190
191
191
192
5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
192
- `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
193
- `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
194
- halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
195
- `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
196
- e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
197
- `?` indicates dynamic halo sizes.
193
+ `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
194
+ sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
195
+ gets an additional halo of size 1 at the start of the first dimension and a halo
196
+ of size 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
197
+ sharded dimensions, e.g. the first sharded dimension gets `[1,2]` halos and the
198
+ second gets `[2,3]` halos. `?` indicates dynamic halo sizes.
199
+
200
+ 6. [Optional] Offsets for each shard and sharded tensor dimension.
201
+ `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
202
+ sharded tensor dimension the offsets (starting index) of all shards in that
203
+ dimension and an additional value for the end of the last shard are provided.
204
+ For a 1d sharding this means that position `i` has the exclusive prefix sum for
205
+ shard `i`, and since only contiguous sharding is supported, its inclusive prefix
206
+ sum is at position `i+1`.
198
207
199
- 6. [Optional] Sizes of sharded dimensions of each shard.
200
- `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
201
- device-mesh one value for each sharded tensor dimension.
202
208
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
203
- `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
204
- the device-mesh will get a shard of shape 16x8x32 and the second device will get a
205
- shard of shape 16x24x32.
206
- `?` indicates dynamic shard dimensions.
209
+ `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
210
+ the device-mesh will get a shard of shape 24x20x32 and the second device will get
211
+ a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
207
212
208
- `halo_sizes` and `sharded_dims_sizes ` are mutually exclusive.
213
+ `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
209
214
210
215
Examples:
211
216
@@ -240,7 +245,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
240
245
// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
241
246
// and it has pre-defined shard sizes. The shards of the devices will have
242
247
// the following shapes: [4x2, 4x3, 4x4, 4x5]
243
- %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5 ]
248
+ %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14 ]
244
249
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
245
250
```
246
251
}];
@@ -250,8 +255,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
250
255
Mesh_MeshAxesArrayAttr:$split_axes,
251
256
OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
252
257
OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
253
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes ,
254
- Variadic<I64>:$dynamic_sharded_dims_sizes ,
258
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets ,
259
+ Variadic<I64>:$dynamic_sharded_dims_offsets ,
255
260
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
256
261
Variadic<I64>:$dynamic_halo_sizes
257
262
);
@@ -263,7 +268,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
263
268
`split_axes` `=` $split_axes
264
269
(`partial` `=` $partial_type $partial_axes^)?
265
270
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
266
- (`sharded_dims_sizes ` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes , $static_sharded_dims_sizes )^)?
271
+ (`sharded_dims_offsets ` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets , $static_sharded_dims_offsets )^)?
267
272
attr-dict `:` type($result)
268
273
}];
269
274
let builders = [
@@ -272,16 +277,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
272
277
"ArrayRef<MeshAxis>":$partial_axes,
273
278
"mesh::ReductionKind":$partial_type,
274
279
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
275
- CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes )>,
280
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets )>,
276
281
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
277
282
"ArrayRef<MeshAxesAttr>":$split_axes)>,
278
283
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
279
284
"ArrayRef<MeshAxesAttr>":$split_axes,
280
285
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
281
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes )>,
286
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets )>,
282
287
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
283
288
];
284
289
let hasVerifier = 1;
290
+ let hasCanonicalizer = 1;
285
291
}
286
292
287
293
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1058,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
1052
1058
}
1053
1059
1054
1060
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1055
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
1061
+ DestinationStyleOpInterface,
1062
+ TypesMatchWith<
1063
+ "result has same type as destination",
1064
+ "result", "destination", "$_self">,
1065
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1066
+ AttrSizedOperandSegments
1056
1067
]> {
1057
1068
let summary = "Update halo data.";
1058
1069
let description = [{
1059
1070
This operation updates halo regions of shards, e.g. if their sharding
1060
- specified halos and the actual tensor data might have changed
1071
+ specified halos and the actual tensor/memref data might have changed
1061
1072
on the remote devices. Changes might be caused by mutating operations
1062
1073
and/or if the new halo regions are larger than the existing ones.
1063
1074
1075
+ Source and destination might have different halo sizes.
1076
+
1064
1077
Assumes all devices hold tensors with same-sized halo data as specified
1065
- by `dynamic/static_halo_sizes`.
1078
+ by `source_halo_sizes/static_source_halo_sizes` and
1079
+ `destination_halo_sizes/static_destination_halo_sizes` in source shard
1080
+ and destination/result shard.
1066
1081
1067
1082
`split_axes` specifies for each tensor axis along which mesh axes its halo
1068
1083
data is updated.
1069
1084
1070
- Optionally resizes to new halo sizes `target_halo_sizes`.
1071
1085
}];
1072
1086
let arguments = (ins
1073
- AnyNon0RankedMemRef:$input,
1087
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
1088
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
1074
1089
FlatSymbolRefAttr:$mesh,
1075
1090
Mesh_MeshAxesArrayAttr:$split_axes,
1076
- Variadic<I64>:$dynamic_halo_sizes,
1077
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
1078
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
1091
+ Variadic<I64>:$source_halo_sizes,
1092
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1093
+ Variadic<I64>:$destination_halo_sizes,
1094
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
1095
+ );
1096
+ let results = (outs
1097
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
1079
1098
);
1080
1099
let assemblyFormat = [{
1081
- $input `on` $mesh
1100
+ $source `into` $destination
1101
+ `on` $mesh
1082
1102
`split_axes` `=` $split_axes
1083
- (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
1084
- (`target_halo_sizes` `=` $target_halo_sizes^)?
1085
- attr-dict `:` type($input)
1103
+ (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1104
+ (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
1105
+ attr-dict `:` type($source) `->` type($result)
1106
+ }];
1107
+ let extraClassDeclaration = [{
1108
+ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
1086
1109
}];
1087
1110
}
1088
1111
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
0 commit comments