
[mlir][mesh] Add resharding spmdization on a 1D device mesh #76179


Merged: 7 commits merged into llvm:main on Jan 2, 2024

Conversation

sogartar (Contributor)

The current implementation supports only sharding of tensor axes whose size is divisible by the mesh axis size.
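For example, a tensor axis of size 6 can be sharded over a mesh axis of size 2 or 3, but not over a mesh axis of size 4.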

llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir labels on Dec 21, 2023
llvmbot (Member) commented Dec 21, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Boian Petkantchin (sogartar)

Changes

The current implementation supports only sharding of tensor axes whose size is divisible by the mesh axis size.


Patch is 83.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76179.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+10-1)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+3)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+47-1)
  • (added) mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc (+621)
  • (added) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+35)
  • (modified) mlir/include/mlir/Support/MathExtras.h (+11)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+160-53)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+3)
  • (added) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+639)
  • (modified) mlir/test/Dialect/Mesh/invalid.mlir (+96)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+49)
  • (added) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+154)
  • (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (+124)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3329fb4..a9d30dfbb9a76e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
 }
+
+def Mesh_MeshAxis : I<16>;
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
 //===----------------------------------------------------------------------===//
 // Mesh Enums.
 //===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     }]>
   ];
 
+  let extraClassDeclaration = [{
+    bool operator==(::mlir::Attribute rhs) const;
+    bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+  }];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2eb0189b7..ce7d5d045122d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
 namespace mlir {
 namespace mesh {
 
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
 bool isReductionLoop(IteratorType iType);
 
 bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb97763ad..1ed54b6519e4d8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_TD
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   let hasVerifier = 1;
 }
 
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the shape of the cluster.";
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+
+  let results = (outs
+    Variadic<Index>:$result
+  );
+
+  let assemblyFormat = [{
+    $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
   let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the index of the current device along the specified mesh axes.";
+  let description = [{
+    It is used in the SPMD form of the IR.
+    The `axes` must be non-negative and less than the total number of mesh axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
       [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
   );
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
new file mode 100644
index 00000000000000..181f07177e0af9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -0,0 +1,621 @@
+Resharding Spmdization Examples
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[1, 0]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 |             |
++----+----+----+             |
+| 12 | 21 | 23 |             |
++----+----+----+             ↓
+
+Swap contents on devices that have the same linear index in the 2 shardings.
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[1]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+All-gather along mesh axis 0.
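+
+A sketch of the op that could implement this step (the element type i8 and
+the mesh symbol @mesh are chosen for illustration):
+%1 = all_gather %0 on @mesh mesh_axes = [0] gather_axis = 0 : tensor<1x1xi8> -> tensor<2x1xi8>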
+
+--------------------------------------------------------------
+
+Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
+
+unsharded 2x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+
+sharded on a 2x3 mesh
+
+sharding = [[], [0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 14 | 15 | 16 |             |
+| 24 | 25 | 26 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[], [0]]
+
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 | 11 12 13 |             |
+| 21 22 23 | 21 22 23 | 21 22 23 |             |
++----------+----------+----------+             |
+| 14 15 16 | 14 15 16 | 14 15 16 |             |
+| 24 25 26 | 24 25 26 | 24 25 26 |             |
++----------+----------+----------+             ↓
+
+all-gather along mesh axis 1
+
+--------------------------------------------------------------
+
+Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
+
+unsharded 4x8 tensor
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh contents:
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 |             |             |
+| 21 22 | 23 24 |             |             |
++-------+-------+             |             |
+| 15 16 | 17 18 |             |             |
+| 25 26 | 27 28 |             |             |
++-------+-------+             ↓             |
++-------+-------+                           |
+| 31 32 | 33 34 |                           |
+| 41 42 | 43 44 |                           |
++-------+-------+                           |
+| 35 36 | 37 38 |                           |
+| 45 46 | 47 48 |                           |
++-------+-------+                           ↓
+
+Transform into
+sharding = [[0], [2]]
+
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             |             |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             ↓             |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           ↓
+
+Can't be done with just an all-gather along mesh axis 1.
+Can be handled by multiple resharding transformations
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x3 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 |             |
+| 21 22 | 23 24 | 25 26 |             |
+| 31 32 | 33 34 | 35 36 |             |
++-------+-------+-------+             |
+| 41 42 | 43 44 | 45 46 |             |
+| 51 52 | 53 54 | 55 56 |             |
+| 61 62 | 63 64 | 65 66 |             |
++-------+-------+-------+             ↓
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 |             |
+| 21 22 23 | 41 42 43 | 61 62 63 |             |
++----------+----------+----------+             |
+| 14 15 16 | 34 35 36 | 54 55 56 |             |
+| 24 25 26 | 44 45 46 | 64 65 66 |             |
++----------+----------+----------+             ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x6 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
++----+----+----+----+----+----+             |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
++----+----+----+----+----+----+             ↓
+
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
++----------+----------+             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
++==========+==========+             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
++----------+----------+             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from [[0], [1]] to [[1], [0]] on MxN mesh.
+
+M x N mesh.
+K x L tensor t.
+d(m, n) the tensor on device (m, n).
+
+sharding = [[0], [1]]
+Tensor shard s on each device has size (K ceildiv M, L ceildiv N).
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+
+substitute
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+
+For the inverse map we get
+t[i, j] -> d(
+    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+    i % (K ceildiv M), j % (L ceildiv N)
+]
+
+Check:
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+    13 floordiv (16 ceildiv 3),
+    17 floordiv (23 ceildiv 4)
+)[
+    13 % (16 ceildiv 3),
+    17 % (23 ceildiv 4)
+]
+= d(
+    13 floordiv 6,
+    17 floordiv 6
+)[
+    13 % 6,
+    17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+    2 * (16 ceildiv 3) + 1,
+    2 * (23 ceildiv 4) + 5
+]
+= t[
+    2 * 6 + 1,
+    2 * 6 + 5
+]
+= t[13, 17]
+
+
+sharding = [[1], [0]]
+Tensor shard s on each device has size (K ceildiv N, L ceildiv M).
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+
+substitute
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+
+For the inverse map we get
+t[i, j] -> d(
+    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+    i % (K ceildiv N), j % (L ceildiv M)
+]
+
+Check:
+i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
+t[9, 19] = d(
+    19 floordiv (14 ceildiv 5),
+    9 floordiv (27 ceildiv 2)
+)[
+    9 % (27 ceildiv 2),
+    19 % (14 ceildiv 5)
+]
+= d(
+    19 floordiv 3,
+    9 floordiv 14
+)[
+    9 % 14,
+    19 % 3
+]
+= d(6, 0)[9, 1]
+= t[
+    0 * (27 ceildiv 2) + 9,
+    6 * (14 ceildiv 5) + 1
+]
+= t[
+    0 * 14 + 9,
+    6 * 3 + 1
+]
+= t[9, 19]
+
+sharding = [[0], [1]]
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+
+sharding = [[1], [0]]
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+
+sharding [[0], [1]] -> [[1], [0]]
+d1(m, n) the tensor on device (m, n) for sharding [[0], [1]].
+d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
+    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
+)[
+    (m * (K ceildiv M) + k) % (K ceildiv N),
+    (n * (L ceildiv N) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
+
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed (m, n, p, q), what is the range of (k, l, u, v)?
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+Device placement on a 2x3 mesh
+11 12 13  <- devices
+21 22 23
+
+sharding [[0], [1]]
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 |               |
++----+----+----+               |
+| 21 | 22 | 23 |               |
++----+----+----+               ↓
+
+transform to
+sharding [[1], [0]]
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 |               |
++----+----+               |
+| 12 | 22 |               |
++----+----+               |
+| 13 | 23 |               |
++----+----+               ↓
+
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       11        |        12       |        13       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       21        |        22       |        23       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          11              +               21         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          12              +               22         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          13              +               23         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     11  11      + 12  11 + 12  21 +     13  21      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+|     11  12      | 12  12 | 12  22 |     13  22      |
++-----------------+--------+--------+-----------------+
+|     21  12      | 22  12 | 22  22 |     23  22      |
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     21  13      + 22  13 + 22  23 +     23  23      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+
+If S and T are the source and target shard sizes along some tensor axis,
+then the cut pattern repeats with a period of (S*T)/gcd(S, T) = lcm(S, T).
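+For example, with S = 4 and T = 6 the period is (4*6)/gcd(4, 6) = 24/2 = 12,
+so the source and target cut boundaries line up every 12 elements.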
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 3 mesh
+
+sharding = [[0], []]
+
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 |             |
+| 21 22 23 24 25 26 |             |
++-------------------+             |
+| 31 32 33 34 35 36 |             |
+| 41 42 43 44 45 46 |             |
++-------------------+             |
+| 51 52 53 54 55 56 |             |
+| 61 62 63 64 65 66 |             |
++-------------------+             ↓
+
+transform to
+sharding = [[], [0]]
+
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
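+
+Here each device splits its 2x6 shard along tensor axis 1 (split_axis = 1)
+into 3 pieces of shape 2x2, exchanges one piece with each device along mesh
+axis 0, and concatenates the 3 received 2x2 pieces along tensor axis 0
+(concat_axis = 0), yielding the 6x2 shard of the target sharding.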
+
+--------------------------------------------------------------
+
+Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.
+
+unsharded 4x4 tensor
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 |             |             |
+| 21 | 22 |             |             |
++----+----+             |             |
+| 13 | 14 |             |             |
+| 23 | 24 |             |             |
++----+----+             ↓             |
++----+----+                           |
+| 31 | 32 |                           |
+| 41...
[truncated]

sogartar requested a review from joker-eph on December 21, 2023 at 20:31
sogartar (Contributor, Author):

@yaochengji, could you review this PR?

sogartar (Contributor, Author) commented Dec 21, 2023:

As part of this PR, I added two new ops to get the shape of the mesh and the index of the current device. I can move them into their own PR.
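For illustration, a minimal usage sketch of the two ops following the assembly formats in the patch (the mesh symbol @mesh0 and its assumed 2x3 shape are hypothetical):

// Assumes @mesh0 is a previously declared 2x3 mesh cluster (hypothetical).
// Number of devices along mesh axis 1 of @mesh0.
%size = mesh.cluster_shape @mesh0 axes = [1] : index
// This device's coordinate along mesh axis 0 of @mesh0.
%idx = mesh.process_index on @mesh0 axes = [0] : index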

@@ -0,0 +1,621 @@
Resharding Spmdization Examples
sogartar (Contributor, Author) commented on this file:
I am not sure this should be here. It is not a user-facing doc, but it may be important to illustrate what is going on in resharding and what the future challenges are for making a communication-efficient implementation.

@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}

def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
yaochengji (Member) commented Jan 1, 2024:
Is there an int value representing the device/process id, rather than a list of ints?

E.g. https://www.tensorflow.org/xla/operation_semantics#allreduce uses a single int.

sogartar (Contributor, Author) replied:

I assume you mean having a correspondence between the multi-index and a flat/linear index.
Do you want an operation in the IR that explicitly extracts the linear index?
If it is only needed during lowering, we can have a C++ function that adds IR to flatten the index. There may already be such a utility somewhere in MLIR.

yaochengji (Member) replied:

Yeah, a flat index is what I mean. I think either an explicit op in the IR or handling it during lowering is fine.

But we should clearly document how the multi-index is converted to a flat index.

E.g., for a 2x2 mesh cluster, whether (1, 0) represents device id 1 or 2 might need to be defined.

We can leave this doc to the next PR.
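For reference, a hypothetical lowering sketch (not part of this PR) that flattens the process multi-index into a linear id on an assumed 2x2 mesh @mesh2x2, using row-major order where mesh axis 0 varies slowest, so (1, 0) maps to id 2:

// Hypothetical: @mesh2x2 is an assumed 2x2 mesh cluster.
%i0, %i1 = mesh.process_index on @mesh2x2 axes = [0, 1] : index, index
%c2 = arith.constant 2 : index             // size of mesh axis 1
%scaled = arith.muli %i0, %c2 : index      // i0 * 2
%linear = arith.addi %scaled, %i1 : index  // flat id = i0 * 2 + i1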

sogartar requested a review from yaochengji on January 2, 2024
yaochengji (Member):

LGTM, thanks

sogartar (Contributor, Author) commented Jan 2, 2024:

I missed committing the doc markdown file in my last commit.
I also rebased on top of main.

sogartar merged commit 1a8fb88 into llvm:main on Jan 2, 2024
Labels
mlir:core (MLIR Core Infrastructure), mlir
3 participants