-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] adding shard-size control #98145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-linalg Author: Frank Schlimbach (fschlimb) Changes
What previously was %sharded0 = mesh.shard %arg0 <@<!-- -->mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@<!-- -->mesh0, [[0]]> annotate_for_users : tensor<16x8xf32> is now %sharding = mesh.sharding @<!-- -->mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32> and allows additional annotations to control the shard sizes: mesh.mesh @<!-- -->mesh1d_4(shape = 4)
%sharding0 = mesh.sharding @<!-- -->mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding0 = mesh.sharding @<!-- -->mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
@sogartar @yaochengji Patch is 207.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98145.diff 28 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 7ba966d8cab7c..f26c6285efd89 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -13,6 +13,10 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
+
set(LLVM_TARGET_DEFINITIONS MeshOps.td)
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 3a85bf2d552f3..61403ac178980 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -12,6 +12,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/EnumAttr.td"
//===----------------------------------------------------------------------===//
@@ -31,11 +32,13 @@ def Mesh_Dialect : Dialect {
];
let useDefaultAttributePrinterParser = 1;
+ let useDefaultTypePrinterParser = 1;
let hasConstantMaterializer = 1;
}
def Mesh_MeshAxis : I<16>;
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
//===----------------------------------------------------------------------===//
// Mesh Enums.
@@ -59,104 +62,33 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
}
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
- let assemblyFormat = "`<` $value `>`";
+ let assemblyFormat = "$value";
+}
+
+class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
+ string baseCppClass = "::mlir::Type">
+ : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
+ let mnemonic = typeMnemonic;
+}
+
+def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
+ let summary = "sharding definition";
+ let assemblyFormat = "";
}
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
-def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
- let mnemonic = "shard";
-
- let parameters = (ins
- AttrParameter<"::mlir::FlatSymbolRefAttr",
- "The mesh on which tensors are sharded.">:$mesh,
- ArrayRefParameter<"MeshAxesAttr">:$split_axes,
- OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
- OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
- );
-
- let summary = "Attribute that extends tensor type to distributed tensor type.";
-
- let description = [{
- The MeshSharding attribute is used in a `mesh.shard` operation.
- It specifies how a tensor is sharded and distributed across the process
- mesh.
-
- 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
- mesh where the distributed tensor is placed. The symbol must resolve to a
- `mesh.mesh` operation.
-
- 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
- maximum size is the `rank` of the related tensor. For the i-th sub-array, if
- its value is [x, y], it indicates that the tensor's i-th dimension is splitted
- along the x and y axes of the device mesh.
-
- 3. `partial_axes`: if not empty, this signifies that the tensor is partial
- one along the specified mesh axes. An all-reduce should be applied to obtain
- the complete tensor, with reduction type being specified by `partial_type`.
-
- 4. `partial_type`: indicates the reduction type of the possible all-reduce
- op. It has 4 possible values:
- `generic`: is not an allowed value inside a shard attribute.
-
- Example:
-
- ```
- mesh.mesh @mesh0(shape = 2x2x4)
-
- // The tensor is fully replicated on @mesh0.
- // Currently, there must be at least one sub-array present in axes, even
- // if it's empty. Otherwise, a parsing error will occur.
- #mesh.shard<@mesh0, [[]]>
-
- // The tensor is sharded on the first dimension along axis 0 of @mesh0
- #mesh.shard<@mesh0, [[0]]>
-
- // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
- // it is also a partial_sum along mesh axis 1.
- #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
-
- // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
- // it is also a partial_max along mesh axis 1.
- #mesh.shard<@mesh0, [[0]], partial = max[1]>
-
- // Could be used in the attribute of mesh.shard op
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
- ```
- }];
- let assemblyFormat = [{
- `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
- $partial_axes^ `]`)? `>`
- }];
-
- let builders = [
- AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
- "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
- "ArrayRef<MeshAxis>": $partial_axes,
- "mesh::ReductionKind": $partial_type), [{
- SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
- split_axes, [&](ArrayRef<MeshAxis> array) {
- return MeshAxesAttr::get($_ctxt, array);
- });
- return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
- partial_type);
- }]>,
- AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
- "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
- return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
- }]>
- ];
-
+def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
+ let mnemonic = "axisarray";
+ let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
+ let assemblyFormat = "`[` $axes `]`";
let extraClassDeclaration = [{
- bool operator==(::mlir::Attribute rhs) const;
- bool operator!=(::mlir::Attribute rhs) const;
- bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
- bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
+ size_t size() const { return getAxes().size(); }
+ auto begin() const { return getAxes().begin(); }
+ auto end() const { return getAxes().end(); }
}];
-
- let genVerifyDecl = 1;
}
#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index b27c9e81b3293..3c467d6f95948 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,6 +24,8 @@ namespace mesh {
using MeshAxis = int16_t;
using MeshAxesAttr = DenseI16ArrayAttr;
+using ShardShapeAttr = DenseI64ArrayAttr;
+using HaloSizePairAttr = DenseI64ArrayAttr;
} // namespace mesh
} // namespace mlir
@@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
+namespace mlir {
+namespace mesh {
+
+class MeshSharding {
+private:
+ ::mlir::FlatSymbolRefAttr mesh;
+ SmallVector<MeshAxesAttr> split_axes;
+ SmallVector<MeshAxis> partial_axes;
+ ReductionKind partial_type;
+ SmallVector<int64_t> static_halo_sizes;
+ SmallVector<int64_t> static_sharded_dims_sizes;
+ SmallVector<Value> dynamic_halo_sizes;
+ SmallVector<Value> dynamic_sharded_dims_sizes;
+
+public:
+ MeshSharding() = default;
+ MeshSharding(Value rhs);
+ static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
+ ArrayRef<MeshAxesAttr> split_axes_,
+ ArrayRef<MeshAxis> partial_axes_ = {},
+ ReductionKind partial_type_ = ReductionKind::Sum,
+ ArrayRef<int64_t> static_halo_sizes_ = {},
+ ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+ ArrayRef<Value> dynamic_halo_sizes_ = {},
+ ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+ ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+ ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+ ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+ ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+ ReductionKind getPartialType() const { return partial_type; }
+ ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+ ArrayRef<int64_t> getStaticShardedDimsSizes() const {
+ return static_sharded_dims_sizes;
+ }
+ ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+ ArrayRef<Value> getDynamicShardedDimsSizes() const {
+ return dynamic_sharded_dims_sizes;
+ }
+ operator bool() const { return (!mesh) == false; }
+ bool operator==(Value rhs) const;
+ bool operator!=(Value rhs) const;
+ bool operator==(const MeshSharding &rhs) const;
+ bool operator!=(const MeshSharding &rhs) const;
+ bool sameExceptConstraint(const MeshSharding &rhs) const;
+ bool sameConstraint(const MeshSharding &rhs) const;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -50,9 +105,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
}
// Is the same tensor replicated on all processes.
-inline bool isFullReplication(MeshShardingAttr attr) {
- return attr.getPartialAxes().empty() &&
- llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
+inline bool isFullReplication(MeshSharding sharding) {
+ return sharding.getPartialAxes().empty() &&
+ llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
return axes.asArrayRef().empty();
});
}
@@ -80,8 +135,10 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
template <>
inline mesh::MeshOp
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
- return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
- symbolTableCollection);
+ return getMesh(
+ op.getOperation(),
+ cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
+ symbolTableCollection);
}
// Get the number of processes that participate in each group
@@ -131,22 +188,22 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
- MeshShardingAttr sharding);
+ MeshSharding sharding);
// If ranked tensor type return its sharded counterpart.
//
// If not ranked tensor type return `type`.
// `sharding` in that case must be null.
-Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Insert shard op if there is not one that already has the same sharding.
// May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
- OpResult result, OpBuilder &builder);
-void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
+ OpBuilder &builder);
+void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8e1e475463585..49c4037942f6f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -20,7 +20,7 @@ include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
-// Mesh Dialect operations.
+// Mesh operations.
//===----------------------------------------------------------------------===//
class Mesh_Op<string mnemonic, list<Trait> traits = []> :
@@ -105,22 +105,221 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the multi index of current device along specified mesh axes.";
+ let description = [{
+ It is used in the SPMD format of IR.
+ The `axes` mush be non-negative and less than the total number of mesh axes.
+ If the axes are empty then get the index along all axes.
+ }];
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+ let results = (outs
+ Variadic<Index>:$result
+ );
+ let assemblyFormat = [{
+ `on` $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+ ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the linear index of the current device.";
+ let description = [{
+ Example:
+ ```
+ %idx = mesh.process_linear_index on @mesh : index
+ ```
+ if `@mesh` has shape `(10, 20, 30)`, a device with multi
+ index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
+ }];
+ let arguments = (ins FlatSymbolRefAttr:$mesh);
+ let results = (outs Index:$result);
+ let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+ Pure,
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Define a sharding of a tensor.";
+ let description = [{
+ The MeshSharding specifies how a tensor is sharded and distributed across the
+ process mesh. It is typically used in a `mesh.shard` operation.
+ The operation has the follwing attributes and operands:
+
+ 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+ mesh where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.mesh` operation.
+
+ 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
+ maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+ its value is [x, y], it indicates that the tensor's i-th dimension is splitted
+ along the x and y axes of the device mesh.
+
+ 3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
+ one along the specified mesh axes. An all-reduce should be applied to obtain
+ the complete tensor, with reduction type being specified by `partial_type`.
+
+ 4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
+ op. It has 4 possible values:
+ `generic`: is not an allowed value inside a shard attribute.
+
+ 5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+ `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+ `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+ halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
+ `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
+ e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
+ `?` indicates dynamic halo sizes.
+
+ 6. [Optional] Sizes of sharded dimensions of each shard.
+ `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
+ device-mesh one value for each sharded tensor dimension.
+ Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
+ `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+ the device-mesh will get a shard of shape 16x8x32 and the second device will get a
+ shard of shape 16x24x32.
+ `?` indicates dynamic shard dimensions.
+
+ `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+
+ Examples:
+
+ ```
+ mesh.mesh @mesh0(shape = 2x2x4)
+ mesh.mesh @mesh1d_4(shape = 4)
+
+ // The tensor is fully replicated on @mesh0.
+ // Currently, there must be at least one sub-array present in axes, even
+ // if it's empty. Otherwise, a parsing error will occur.
+ %sharding0 = mesh.sharding @mesh0, [[]]
+
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0
+ %sharding1 = mesh.sharding @mesh0, [[0]]
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // it is also a partial_sum along mesh axis 1.
+ %sharding2 = mesh.sharding @mesh0, [[0], []] partial = sum[1]
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // it is also a partial_max along mesh axis 1.
+ %sharding3 = mesh.sharding @mesh0, [[0]] partial = max[1]
+
+ // Could be used for a mesh.shard op
+ %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // and it has halo-sizes of 1 and 2 on the sharded dim.
+ %halo_sharding = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2]
+ %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+
+ // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+ // and it has pre-defined shard sizes. The shards of the devices will have
+ // the following shapes: [4x2, 4x3, 4x4, 4x5]
+ %sharding4 = mesh.sharding @mesh1d_4, [[], [0]] sharded_dims_sizes = [2, 3, 4, 5]
+ %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ Mesh_MeshAxesArrayAttr:$split_axes,
+ OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
+ OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
+ Variadic<I64>:$dynamic_sharded_dims_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
+ Variadic<I64>:$dynamic_halo_sizes
+ );
+ let results = (outs
+ Mesh_Sharding:$result
+ );
+ let assemblyFormat = [{
+ $mesh `,` $split_axes
+ (`partial` `=` $partial_type $partial_axes^)?
+ (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
+ (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes,
+ "ArrayRef<MeshAxis>":$partial_axes,
+ "mesh::ReductionKind":$partial_type,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+ OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes)>,
+ OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+ "ArrayRef<MeshAxesAttr...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for contributing. I have a few remarks.
Thanks for your detailed review! See my comments/modifications. |
4. `force`: A unit attribute requesting an explicit sharding of the data, | ||
therefore not allowing to be optimizied away. This is useful in the presence | ||
of halos and inplace semantics. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add explanation in the doc with an example?
I still don't understand. It would be desirable to avoid the complexity of another attribute.
Shouldn't insert_slice
know how to handle its spmdization? It would have requirements that the destination-passing style operand and the result are related and that they need to have the same sharding. If this constraint can not be satisfied then resharding will be inserted during sharding propagation.
✅ With the latest revision this PR passed the C/C++ code formatter. |
@sogartar You're probably right and we can go without the the I also added a check in a symbolverifier to disallow sharded_dims_sizes on dynamic meshes. As far as I can tell I addressed all your concerns and suggestions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fschlimb, thank you for your contribution. Could you rebase before merging since this PR has been opened for a while and there may be some conflicts.
The commit/pr title should be prefixed with [mlir][mesh]
.
Adding halo_sizes and shard_dims_sizes to sharding. First spmdization of halo annotated sharding
@fschlimb Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested Please check whether problems have been caused by your change specifically, as How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/89/builds/3580 Here is the relevant piece of the build log for the reference:
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/2153 Here is the relevant piece of the build log for the reference:
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/80/builds/1865 Here is the relevant piece of the build log for the reference:
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/130/builds/1817 Here is the relevant piece of the build log for the reference:
|
Reverted along with the fixup. Please build the projects that are failing in CI and run the tests above before creating a new PR. |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/35/builds/1679 Here is the relevant piece of the build log for the reference:
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/124/builds/111 Here is the relevant piece of the build log for the reference:
|
This is a fixed copy of #98145 (necessary after it got reverted). @sogartar @yaochengji This PR adds the following to #98145: - `UpdateHaloOp` accepts a `memref` (instead of a tensor) and not returning a result to clarify its inplace-semantics - `UpdateHaloOp` accepts `split_axis` to allow multiple mesh-axes per tensor/memref-axis (similar to `mesh.sharding`) - The implementation of `Shardinginterface` for tensor operation (`tensor.empty` for now) moved from the tensor library to the mesh interface library. `spmdize` uses features from `mesh` dialect. @rengolin agreed that `tensor` should not depend on `mesh` so this functionality cannot live in a `tensor`s lib. The unfulfilled dependency caused the issues leading to reverting #98145. Such cases are generally possible and might lead to re-considering the current structure (like for tosa ops). - rebased onto latest main -------------------------- Replacing `#mesh.sharding` attribute with operation `mesh.sharding` - extended semantics now allow providing optional `halo_sizes` and `sharded_dims_sizes` - internally a sharding is represented as a non-IR class `mesh::MeshSharding` What previously was ```mlir %sharded0 = mesh.shard %arg0 <@Mesh0, [[0]]> : tensor<4x8xf32> %sharded1 = mesh.shard %arg1 <@Mesh0, [[0]]> annotate_for_users : tensor<16x8xf32> ``` is now ```mlir %sharding = mesh.sharding @Mesh0, [[0]] : !mesh.sharding %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32> %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32> ``` and allows additional annotations to control the shard sizes: ```mlir mesh.mesh @Mesh0 (shape = 4) %sharding0 = mesh.sharding @Mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding %0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32> %sharding1 = mesh.sharding @Mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding %1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32> ``` - `mesh.shard` op accepts additional optional attribute `force`, useful for halo updates - Some initial spmdization support for the new semantics - Support for `tensor.empty` reacting on `sharded_dims_sizes` and `halo_sizes` in the sharding - New collective operation `mesh.update_halo` as a spmdized target for shardings with `halo_sizes` --------- Co-authored-by: frank.schlimbach <[email protected]> Co-authored-by: Jie Fu <[email protected]>
#mesh.sharding
attribute with operationmesh.sharding
halo_sizes
andsharded_dims_sizes
mesh::MeshSharding
What previously was
is now
and allows additional annotations to control the shard sizes:
mesh.shard
op accepts additional optional attributeforce
, useful for halo updatestensor.empty
reacting onsharded_dims_sizes
andhalo_sizes
in the shardingmesh.update_halo
as a spmdized target for shardings withhalo_sizes
@sogartar @yaochengji