[mlir][mesh] Add resharding spmdization on a 1D device mesh #76179
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core
Author: Boian Petkantchin (sogartar)
Changes
The current implementation supports only sharding of tensor axes that have size divisible by the mesh axis size.
Patch is 83.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76179.diff
15 Files Affected:
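To make the stated limitation concrete, here is a minimal sketch (a hypothetical helper in plain C++, not code from this patch) of the divisibility requirement and the resulting per-device shard size:

#include <cassert>
#include <cstdint>

// Hypothetical helper: sharding a tensor axis of size `dimSize` over a mesh
// axis of size `meshAxisSize` is only supported when the sizes divide evenly;
// each device then holds dimSize / meshAxisSize elements along that axis.
int64_t shardedDimSize(int64_t dimSize, int64_t meshAxisSize) {
  assert(dimSize % meshAxisSize == 0 &&
         "unsupported: tensor axis size must be divisible by mesh axis size");
  return dimSize / meshAxisSize;
}

For example, a tensor axis of size 6 sharded over a mesh axis of size 3 yields shards of size 2 on each device, while an axis of size 7 would be rejected.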
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3329fb4..a9d30dfbb9a76e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
}
+
+def Mesh_MeshAxis : I<16>;
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
//===----------------------------------------------------------------------===//
// Mesh Enums.
//===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
- tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}]>
];
+ let extraClassDeclaration = [{
+ bool operator==(::mlir::Attribute rhs) const;
+ bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2eb0189b7..ce7d5d045122d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
namespace mlir {
namespace mesh {
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
bool isReductionLoop(IteratorType iType);
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb97763ad..1ed54b6519e4d8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Get the shape of the cluster.";
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+
+ let results = (outs
+ Variadic<Index>:$result
+ );
+
+ let assemblyFormat = [{
+ $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ ];
+}
+
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Get the index of the current device along the specified mesh axes.";
+ let description = [{
+ It is used in the SPMD format of IR.
+ The `axes` must be non-negative and less than the total number of mesh axes.
+ }];
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+ let results = (outs
+ Variadic<Index>:$result
+ );
+ let assemblyFormat = [{
+ `on` $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
);
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
new file mode 100644
index 00000000000000..181f07177e0af9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -0,0 +1,621 @@
+Resharding Spmdization Examples
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[1, 0]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 | |
++----+----+----+ |
+| 12 | 21 | 23 | |
++----+----+----+ ↓
+
+Swap contents on devices that have the same linear index in the 2 shardings.
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[1]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+All-gather along mesh axis 0.
+
+--------------------------------------------------------------
+
+Reshard 4x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
+
+unsharded 4x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+
+sharded on a 2x3 mesh
+
+sharding = [[], [0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 14 | 15 | 16 | |
+| 24 | 25 | 26 | |
++----+----+----+ ↓
+
+Transform into
+sharding = [[], [0]]
+
+mesh axis 1
+----------->
++----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 | |
+| 21 22 23 | 21 22 23 | |
++----------+----------+ |
+| 14 15 16 | 14 15 16 | |
+| 24 25 26 | 24 25 26 | |
++----------+----------+ ↓
+
+all-gather along mesh axis 1
+
+--------------------------------------------------------------
+
+Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
+
+unsharded 4x8 tensor
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh contents:
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 | | |
+| 21 22 | 23 24 | | |
++-------+-------+ | |
+| 15 16 | 17 18 | | |
+| 25 26 | 27 28 | | |
++-------+-------+ ↓ |
++-------+-------+ |
+| 31 32 | 33 34 | |
+| 41 42 | 43 44 | |
++-------+-------+ |
+| 35 36 | 37 38 | |
+| 45 46 | 47 48 | |
++-------+-------+ ↓
+
+Transform into
+sharding = [[0], [2]]
+
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ | |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ ↓ |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ ↓
+
+Can't be done with just an all-gather along mesh axis 1.
+Can be handled by multiple resharding transformations
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x3 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 | |
+| 21 22 | 23 24 | 25 26 | |
+| 31 32 | 33 34 | 35 36 | |
++-------+-------+-------+ |
+| 41 42 | 43 44 | 45 46 | |
+| 51 52 | 53 54 | 55 56 | |
+| 61 62 | 63 64 | 65 66 | |
++-------+-------+-------+ ↓
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 | |
+| 21 22 23 | 41 42 43 | 61 62 63 | |
++----------+----------+----------+ |
+| 14 15 16 | 34 35 36 | 54 55 56 | |
+| 24 25 26 | 44 45 46 | 64 65 66 | |
++----------+----------+----------+ ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+shard on 2x6 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 | |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 |                                 |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 | |
++----+----+----+----+----+----+ |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 | |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 | |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 | |
++----+----+----+----+----+----+ ↓
+
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
++----------+----------+ |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
++==========+==========+ |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
++----------+----------+ |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from [[0], [1]] to [[1], [0]] on MxN mesh.
+
+M x N mesh.
+K x L tensor t.
+d(m, n) the tensor on device (m, n).
+
+sharding = [[0], [1]]
+Tensor shard s on each device has size (K ceildiv M, L ceildiv N).
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+
+substitute
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+
+For the inverse map we get
+t[i, j] -> d(
+ i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+ i % (K ceildiv M), j % (L ceildiv N)
+]
+
+Check:
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+ 13 floordiv (16 ceildiv 3),
+ 17 floordiv (23 ceildiv 4)
+)[
+ 13 % (16 ceildiv 3),
+ 17 % (23 ceildiv 4)
+]
+= d(
+ 13 floordiv 6,
+ 17 floordiv 6
+)[
+ 13 % 6,
+ 17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+ 2 * (16 ceildiv 3) + 1,
+ 2 * (23 ceildiv 4) + 5
+]
+= t[
+ 2 * 6 + 1,
+ 2 * 6 + 5
+]
+= t[13, 17]
+
+
+sharding = [[1], [0]]
+Tensor shard s on each device has size (K ceildiv N, L ceildiv M).
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+
+substitute
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+
+For the inverse map we get
+t[i, j] -> d(
+ j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+ i % (K ceildiv N), j % (L ceildiv M)
+]
+
+Check:
+i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
+t[9, 19] = d(
+ 19 floordiv (14 ceildiv 5),
+ 9 floordiv (27 ceildiv 2)
+)[
+ 9 % (27 ceildiv 2),
+ 19 % (14 ceildiv 5)
+]
+= d(
+ 19 floordiv 3,
+ 9 floordiv 14
+)[
+ 9 % 14,
+ 19 % 3
+]
+= d(6, 0)[9, 1]
+= t[
+ 0 * (27 ceildiv 2) + 9,
+ 6 * (14 ceildiv 5) + 1
+]
+= t[
+ 0 * 14 + 9,
+ 6 * 3 + 1
+]
+= t[9, 19]
+
+sharding = [[0], [1]]
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+
+sharding = [[1], [0]]
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+
+sharding [[0], [1]] -> [[1], [0]]
+d1(m, n) the tensor on device (m, n) for sharding [[0], [1]].
+d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+ (m * (L ceildiv M) + l) floordiv (L ceildiv M),
+ (n * (K ceildiv N) + k) floordiv (K ceildiv N)
+)[
+ (n * (K ceildiv N) + k) % (K ceildiv N),
+ (m * (L ceildiv M) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
+
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed (m, n, p, q), what is the range of (k, l, u, v)?
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+Device placement on a 2x3 mesh
+11 12 13 <- devices
+21 22 23
+
+sharding [[0], [1]]
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+
+transform to
+sharding [[1], [0]]
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 | |
++----+----+ |
+| 12 | 22 | |
++----+----+ |
+| 13 | 23 | |
++----+----+ ↓
+
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 11 | 12 | 13 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 21 | 22 | 23 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | |
++ 11 + 21 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 12 + 22 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 13 + 23 +
+| | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | | | |
++ 11 11 + 12 11 + 12 21 + 13 21 +
+| | | | |
++-----------------+--------+--------+-----------------+
+| 11 12 | 12 12 | 12 22 | 13 22 |
++-----------------+--------+--------+-----------------+
+| 21 12 | 22 12 | 22 22 | 23 22 |
++-----------------+--------+--------+-----------------+
+| | | | |
++ 21 13 + 22 13 + 22 23 + 23 23 +
+| | | | |
++-----------------+--------+--------+-----------------+
+
+If S and T are the source and target shard sizes along some tensor axis,
+then the cut pattern repeats with a period of (S*T)/gcd(S, T), i.e. lcm(S, T).
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 3 mesh
+
+sharding = [[0], []]
+
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 | |
+| 21 22 23 24 25 26 | |
++-------------------+ |
+| 31 32 33 34 35 36 | |
+| 41 42 43 44 45 46 | |
++-------------------+ |
+| 51 52 53 54 55 56 | |
+| 61 62 63 64 65 66 | |
++-------------------+ ↓
+
+transform to
+sharding = [[], [0]]
+
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+
+--------------------------------------------------------------
+
+Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.
+
+unsharded 4x4 tensor
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 | | |
+| 21 | 22 | | |
++----+----+ | |
+| 13 | 14 | | |
+| 23 | 24 | | |
++----+----+ ↓ |
++----+----+ |
+| 31 | 32 | |
+| 41...
[truncated]
@yaochengji, could you review this PR?
As part of this PR I added two new ops to get the shape of the mesh and the index of the current device. I can move them into their own PR.
@@ -0,0 +1,621 @@
Resharding Spmdization Examples
I am not sure that this should be here. It is not a user-facing doc, but it may be important to illustrate what is going on in resharding and what the future challenges are in making a communication-efficient implementation.
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
  }];
}

def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
Is there an int value representing the device/process id, rather than a list of ints?
E.g. https://www.tensorflow.org/xla/operation_semantics#allreduce uses a single int.
I assume you mean the correspondence between the multi-index and a flat/linear index.
Do you want an explicit operation in the IR that extracts the linear index?
If it is to be used during lowering, we can have a C++ function that adds IR to flatten the index. There may already be such a thing somewhere in MLIR.
Yeah, a flat index is what I mean. I think either an explicit IR op or an implementation during lowering is fine.
But we'd better document clearly how the multi-index is converted to a flat index.
E.g., for a 2x2 mesh cluster, whether (1, 0) represents device id 1 or 2 might need to be defined.
We can leave this doc to the next PR.
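For illustration, here is a minimal sketch in plain C++ (not part of the PR) of one possible convention, row-major linearization of the process multi-index; as the comment above notes, which convention the dialect actually uses is exactly what still needs to be documented:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major linearization of a process multi-index: the last mesh axis varies
// fastest. This is an assumed convention for the example only.
int64_t linearIndex(const std::vector<int64_t> &multiIndex,
                    const std::vector<int64_t> &meshShape) {
  assert(multiIndex.size() == meshShape.size());
  int64_t linear = 0;
  for (size_t i = 0; i < meshShape.size(); ++i)
    linear = linear * meshShape[i] + multiIndex[i];
  return linear;
}

int main() {
  // On a 2x2 mesh, process (1, 0) gets linear index 1*2 + 0 = 2 under this
  // convention; a column-major convention would give 1 instead.
  std::printf("%lld\n", static_cast<long long>(linearIndex({1, 0}, {2, 2})));
  return 0;
}

Under row-major order the device id of (1, 0) on a 2x2 mesh is 2; under column-major order it would be 1, which is why the chosen convention has to be spelled out in the doc.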
LGTM, thanks
The current implementation supports only sharding of tensor axes that have size divisible by the mesh axis size.
…nDoc Co-authored-by: Chengji Yao <[email protected]>
…nDoc Co-authored-by: Chengji Yao <[email protected]>
Force-pushed from 8b5fceb to a483720
I forgot to commit the doc markdown file in my last commit.
The current implementation supports only sharding of tensor axes that have size divisible by the mesh axis size.