Skip to content

Commit ffc7fea

Browse files
authored
[mlir][mesh] Handling changed halo region sizes during spmdization (#114238)
* Changed `MeshSharding::sharded_dims_sizes` from representing sizes per shard to offsets to origin per shard. - Local shard size are now a simple subtraction - Offsets are now readily available without a reduction operation - Enables constant value/shape propagation through standard canonicalization - Renamed to `sharded_dims_offsets` accordingly. * First spmdization pattern for halo regions. - Triggers when source and destination shardings differ only in their halo sizes - Copies local data from source into a new tensor and calls update_halo - Supports arbitrary mesh dimensions (unlike the other patterns which work on 1d meshes only) * `UpdateHaloOp` implements `DestinationStyleOpInterface` and accepts tensors and memrefs - also accepts target and source halo sizes; both are required for proper lowering * minor refactoring for testing partial MeshSharding equality * Canonicalization for ShardingOp folding constant values into respective `static_*` attributes
1 parent a912c81 commit ffc7fea

File tree

9 files changed

+401
-132
lines changed

9 files changed

+401
-132
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/OpDefinition.h"
1616
#include "mlir/IR/PatternMatch.h"
1717
#include "mlir/IR/SymbolTable.h"
18+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1819
#include "mlir/Interfaces/InferTypeOpInterface.h"
1920
#include "mlir/Interfaces/SideEffectInterfaces.h"
2021
#include "llvm/Support/MathExtras.h"
@@ -45,9 +46,9 @@ class MeshSharding {
4546
SmallVector<MeshAxis> partial_axes;
4647
ReductionKind partial_type;
4748
SmallVector<int64_t> static_halo_sizes;
48-
SmallVector<int64_t> static_sharded_dims_sizes;
49+
SmallVector<int64_t> static_sharded_dims_offsets;
4950
SmallVector<Value> dynamic_halo_sizes;
50-
SmallVector<Value> dynamic_sharded_dims_sizes;
51+
SmallVector<Value> dynamic_sharded_dims_offsets;
5152

5253
public:
5354
MeshSharding() = default;
@@ -57,21 +58,21 @@ class MeshSharding {
5758
ArrayRef<MeshAxis> partial_axes_ = {},
5859
ReductionKind partial_type_ = ReductionKind::Sum,
5960
ArrayRef<int64_t> static_halo_sizes_ = {},
60-
ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
61+
ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
6162
ArrayRef<Value> dynamic_halo_sizes_ = {},
62-
ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
63+
ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
6364
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
6465
::llvm::StringRef getMesh() const { return mesh.getValue(); }
6566
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
6667
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
6768
ReductionKind getPartialType() const { return partial_type; }
6869
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
69-
ArrayRef<int64_t> getStaticShardedDimsSizes() const {
70-
return static_sharded_dims_sizes;
70+
ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
71+
return static_sharded_dims_offsets;
7172
}
7273
ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
73-
ArrayRef<Value> getDynamicShardedDimsSizes() const {
74-
return dynamic_sharded_dims_sizes;
74+
ArrayRef<Value> getDynamicShardedDimsOffsets() const {
75+
return dynamic_sharded_dims_offsets;
7576
}
7677
operator bool() const { return (!mesh) == false; }
7778
bool operator==(Value rhs) const;
@@ -80,6 +81,8 @@ class MeshSharding {
8081
bool operator!=(const MeshSharding &rhs) const;
8182
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
8283
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
84+
bool equalHaloSizes(const MeshSharding &rhs) const;
85+
bool equalShardSizes(const MeshSharding &rhs) const;
8386
};
8487

8588
} // namespace mesh

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/Mesh/IR/MeshBase.td"
1313
include "mlir/Dialect/Shape/IR/ShapeBase.td"
14+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1415
include "mlir/Interfaces/InferTypeOpInterface.td"
1516
include "mlir/Interfaces/SideEffectInterfaces.td"
1617
include "mlir/IR/BuiltinTypes.td"
@@ -189,23 +190,27 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
189190
`generic`: is not an allowed value inside a shard attribute.
190191

191192
5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
192-
`halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
193-
`halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
194-
halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
195-
`halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
196-
e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
197-
`?` indicates dynamic halo sizes.
193+
`halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
194+
sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
195+
gets an additional halo of size 1 at the start of the first dimension and a halo
196+
size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
197+
sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
198+
seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
199+
200+
6. [Optional] Offsets for each shard and sharded tensor dimension.
201+
`sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
202+
sharded tensor dimension the offsets (starting index) of all shards in that
203+
dimension and an additional value for the end of the last shard are provided.
204+
For a 1d sharding this means that position `i` has the exclusive prefix sum for
205+
shard `i`, and since only contiguous sharding is supported, its inclusive prefix
206+
sum is at position 'i+1'.
198207

199-
6. [Optional] Sizes of sharded dimensions of each shard.
200-
`sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
201-
device-mesh one value for each sharded tensor dimension.
202208
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
203-
`sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
204-
the device-mesh will get a shard of shape 16x8x32 and the second device will get a
205-
shard of shape 16x24x32.
206-
`?` indicates dynamic shard dimensions.
209+
`sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
210+
the device-mesh will get a shard of shape 24x20x32 and the second device will get
211+
a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
207212

208-
`halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
213+
`halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
209214

210215
Examples:
211216

@@ -240,7 +245,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
240245
// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
241246
// and it has pre-defined shard sizes. The shards of the devices will have
242247
// the following shapes: [4x2, 4x3, 4x4, 4x5]
243-
%sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
248+
%sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
244249
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
245250
```
246251
}];
@@ -250,8 +255,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
250255
Mesh_MeshAxesArrayAttr:$split_axes,
251256
OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
252257
OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
253-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
254-
Variadic<I64>:$dynamic_sharded_dims_sizes,
258+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
259+
Variadic<I64>:$dynamic_sharded_dims_offsets,
255260
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
256261
Variadic<I64>:$dynamic_halo_sizes
257262
);
@@ -263,7 +268,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
263268
`split_axes` `=` $split_axes
264269
(`partial` `=` $partial_type $partial_axes^)?
265270
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
266-
(`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
271+
(`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
267272
attr-dict `:` type($result)
268273
}];
269274
let builders = [
@@ -272,16 +277,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
272277
"ArrayRef<MeshAxis>":$partial_axes,
273278
"mesh::ReductionKind":$partial_type,
274279
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
275-
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
280+
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
276281
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
277282
"ArrayRef<MeshAxesAttr>":$split_axes)>,
278283
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
279284
"ArrayRef<MeshAxesAttr>":$split_axes,
280285
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
281-
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
286+
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
282287
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
283288
];
284289
let hasVerifier = 1;
290+
let hasCanonicalizer = 1;
285291
}
286292

287293
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1058,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
10521058
}
10531059

10541060
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1055-
DeclareOpInterfaceMethods<SymbolUserOpInterface>
1061+
DestinationStyleOpInterface,
1062+
TypesMatchWith<
1063+
"result has same type as destination",
1064+
"result", "destination", "$_self">,
1065+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1066+
AttrSizedOperandSegments
10561067
]> {
10571068
let summary = "Update halo data.";
10581069
let description = [{
10591070
This operation updates halo regions of shards, e.g. if their sharding
1060-
specified halos and the actual tensor data might have changed
1071+
specified halos and the actual tensor/memref data might have changed
10611072
on the remote devices. Changes might be caused by mutating operations
10621073
and/or if the new halo regions are larger than the existing ones.
10631074

1075+
Source and destination might have different halo sizes.
1076+
10641077
Assumes all devices hold tensors with same-sized halo data as specified
1065-
by `dynamic/static_halo_sizes`.
1078+
by `source_halo_sizes/static_source_halo_sizes` and
1079+
`destination_halo_sizes/static_destination_halo_sizes` in source shard
1080+
and destination/result shard.
10661081

10671082
`split_axes` specifies for each tensor axis along which mesh axes its halo
10681083
data is updated.
10691084

1070-
Optionally resizes to new halo sizes `target_halo_sizes`.
10711085
}];
10721086
let arguments = (ins
1073-
AnyNon0RankedMemRef:$input,
1087+
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
1088+
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
10741089
FlatSymbolRefAttr:$mesh,
10751090
Mesh_MeshAxesArrayAttr:$split_axes,
1076-
Variadic<I64>:$dynamic_halo_sizes,
1077-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
1078-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
1091+
Variadic<I64>:$source_halo_sizes,
1092+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1093+
Variadic<I64>:$destination_halo_sizes,
1094+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
1095+
);
1096+
let results = (outs
1097+
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
10791098
);
10801099
let assemblyFormat = [{
1081-
$input `on` $mesh
1100+
$source `into` $destination
1101+
`on` $mesh
10821102
`split_axes` `=` $split_axes
1083-
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
1084-
(`target_halo_sizes` `=` $target_halo_sizes^)?
1085-
attr-dict `:` type($input)
1103+
(`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1104+
(`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
1105+
attr-dict `:` type($source) `->` type($result)
1106+
}];
1107+
let extraClassDeclaration = [{
1108+
MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
10861109
}];
10871110
}
10881111
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD

0 commit comments

Comments
 (0)