Skip to content

[mlir][mesh] adding shard-size control #98145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)

set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)

set(LLVM_TARGET_DEFINITIONS MeshOps.td)
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
Expand Down
112 changes: 22 additions & 90 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
Expand All @@ -31,11 +32,13 @@ def Mesh_Dialect : Dialect {
];

let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
let hasConstantMaterializer = 1;
}

def Mesh_MeshAxis : I<16>;
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;

//===----------------------------------------------------------------------===//
// Mesh Enums.
Expand All @@ -59,104 +62,33 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
}

def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
let assemblyFormat = "`<` $value `>`";
let assemblyFormat = "$value";
}

// Common base for type definitions of the Mesh dialect.
//
// Forwards `name`, `traits` and `baseCppClass` to `TypeDef` with the dialect
// fixed to `Mesh_Dialect`, and installs `typeMnemonic` as the type's mnemonic.
// `baseCppClass` defaults to `::mlir::Type` and `traits` to the empty list.
class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
let mnemonic = typeMnemonic;
}

// Mesh dialect type named "Sharding" with mnemonic `sharding`; per its summary
// it stands for a sharding definition. No parameters are declared here and the
// assembly format is empty.
// NOTE(review): presumably this is the type of SSA values that describe a
// sharding and is spelled `!mesh.sharding` in IR -- confirm against the op
// definitions.
def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
let summary = "sharding definition";
let assemblyFormat = "";
}

//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//

def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let mnemonic = "shard";

let parameters = (ins
AttrParameter<"::mlir::FlatSymbolRefAttr",
"The mesh on which tensors are sharded.">:$mesh,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
);

let summary = "Attribute that extends tensor type to distributed tensor type.";

let description = [{
The MeshSharding attribute is used in a `mesh.shard` operation.
It specifies how a tensor is sharded and distributed across the process
mesh.

1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
mesh where the distributed tensor is placed. The symbol must resolve to a
`mesh.mesh` operation.

2. `split_axes`: is an array composed of int16_t sub-arrays. The outer array's
maximum size is the `rank` of the related tensor. For the i-th sub-array, if
its value is [x, y], it indicates that the tensor's i-th dimension is split
along the x and y axes of the device mesh.

3. `partial_axes`: if not empty, this signifies that the tensor is a partial
one along the specified mesh axes. An all-reduce should be applied to obtain
the complete tensor, with reduction type being specified by `partial_type`.

4. `partial_type`: indicates the reduction type of the possible all-reduce
op. It has 4 possible values:
`generic`: is not an allowed value inside a shard attribute.

Example:

```
mesh.mesh @mesh0(shape = 2x2x4)

// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
#mesh.shard<@mesh0, [[]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0
#mesh.shard<@mesh0, [[0]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
#mesh.shard<@mesh0, [[0]], partial = max[1]>

// Could be used in the attribute of mesh.shard op
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
```
}];
let assemblyFormat = [{
`<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
$partial_axes^ `]`)? `>`
}];

let builders = [
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::ReductionKind": $partial_type), [{
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
split_axes, [&](ArrayRef<MeshAxis> array) {
return MeshAxesAttr::get($_ctxt, array);
});
return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
partial_type);
}]>,
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
}]>
];

def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
let mnemonic = "axisarray";
let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
let assemblyFormat = "`[` $axes `]`";
let extraClassDeclaration = [{
bool operator==(::mlir::Attribute rhs) const;
bool operator!=(::mlir::Attribute rhs) const;
bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
size_t size() const { return getAxes().size(); }
auto begin() const { return getAxes().begin(); }
auto end() const { return getAxes().end(); }
}];

let genVerifyDecl = 1;
}

#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
79 changes: 68 additions & 11 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ namespace mesh {

using MeshAxis = int16_t;
using MeshAxesAttr = DenseI16ArrayAttr;
using ShardShapeAttr = DenseI64ArrayAttr;
using HaloSizePairAttr = DenseI64ArrayAttr;

} // namespace mesh
} // namespace mlir
Expand All @@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"

namespace mlir {
namespace mesh {

/// Plain C++ description of how a tensor is sharded over a mesh.
///
/// The object bundles the mesh symbol, the per-dimension split axes, the
/// partial (pending reduction) axes with their reduction kind, and halo /
/// sharded-dimension sizes, each of the latter kept as a static `int64_t`
/// list plus a list of dynamic `Value`s.
///
/// A default-constructed object has a null mesh symbol and converts to
/// `false`; see `operator bool`.
class MeshSharding {
private:
  // Symbol naming the mesh. Null means "no sharding set".
  ::mlir::FlatSymbolRefAttr mesh;
  // One axes attribute per entry; presumably one entry per tensor dimension
  // (TODO confirm against the implementation file).
  SmallVector<MeshAxesAttr> split_axes;
  // Mesh axes along which the tensor is only a partial result.
  SmallVector<MeshAxis> partial_axes;
  // Reduction kind associated with `partial_axes`. Initialized explicitly to
  // the same default `get` uses, so that `MeshSharding s;` never carries an
  // indeterminate enum value into comparisons or getters.
  ReductionKind partial_type = ReductionKind::Sum;
  // Static and dynamic halo sizes / sharded-dimension sizes.
  // NOTE(review): how static and dynamic entries are paired is not visible
  // from this header -- confirm in the implementation.
  SmallVector<int64_t> static_halo_sizes;
  SmallVector<int64_t> static_sharded_dims_sizes;
  SmallVector<Value> dynamic_halo_sizes;
  SmallVector<Value> dynamic_sharded_dims_sizes;

public:
  /// Creates an unset sharding: null mesh, empty lists, `ReductionKind::Sum`.
  MeshSharding() = default;

  /// Builds a sharding from an SSA value. Defined out of line.
  /// NOTE(review): presumably `rhs` must be produced by a sharding-defining
  /// op -- confirm against the definition.
  MeshSharding(Value rhs);

  /// Factory taking every component explicitly. All arguments after
  /// `split_axes_` are optional; lists default to empty and the reduction
  /// kind defaults to `ReductionKind::Sum`. Defined out of line.
  static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
                          ArrayRef<MeshAxesAttr> split_axes_,
                          ArrayRef<MeshAxis> partial_axes_ = {},
                          ReductionKind partial_type_ = ReductionKind::Sum,
                          ArrayRef<int64_t> static_halo_sizes_ = {},
                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
                          ArrayRef<Value> dynamic_halo_sizes_ = {},
                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});

  /// Returns the mesh symbol attribute (may be null).
  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
  /// Returns the mesh symbol name. Dereferences the attribute, so it must
  /// only be called when the mesh symbol is set.
  ::llvm::StringRef getMesh() const { return mesh.getValue(); }

  // The accessors below return non-owning views into this object's storage;
  // the views are only valid while this object is alive and unmodified.
  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
  ReductionKind getPartialType() const { return partial_type; }
  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
    return static_sharded_dims_sizes;
  }
  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
  ArrayRef<Value> getDynamicShardedDimsSizes() const {
    return dynamic_sharded_dims_sizes;
  }

  /// True iff a mesh symbol is set. Kept implicit so existing callers that
  /// rely on the implicit conversion keep compiling.
  operator bool() const { return static_cast<bool>(mesh); }

  // Comparisons; all are defined out of line.
  bool operator==(Value rhs) const;
  bool operator!=(Value rhs) const;
  bool operator==(const MeshSharding &rhs) const;
  bool operator!=(const MeshSharding &rhs) const;
  // NOTE(review): judging by the names, these compare only the split/partial
  // axes and only the halo/sharded-dims sizes respectively -- confirm in the
  // implementation.
  bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
  bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
};

} // namespace mesh
} // namespace mlir

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"

Expand All @@ -50,9 +105,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
}

// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshShardingAttr attr) {
return attr.getPartialAxes().empty() &&
llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
inline bool isFullReplication(MeshSharding sharding) {
return sharding.getPartialAxes().empty() &&
llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
return axes.asArrayRef().empty();
});
}
Expand Down Expand Up @@ -80,8 +135,10 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
template <>
inline mesh::MeshOp
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
symbolTableCollection);
return getMesh(
op.getOperation(),
cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
symbolTableCollection);
}

// Get the number of processes that participate in each group
Expand Down Expand Up @@ -131,22 +188,22 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
MeshShardingAttr sharding);
MeshSharding sharding);

// If ranked tensor type return its sharded counterpart.
//
// If not ranked tensor type return `type`.
// `sharding` in that case must be null.
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
Type shardType(Type type, MeshOp mesh, MeshSharding sharding);

// Insert shard op if there is not one that already has the same sharding.
// May insert resharding if required.
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
OpResult result, OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);

Expand Down
Loading