Skip to content

Commit 6e57326

Browse files
fschlimbDinistro
andauthored
[mlir][mesh, mpi] More on MeshToMPI (#129048)
- do not create MPI operations if no halo exchange is needed - allow returning sharding information through `!mesh.sharding` (gets converted into a tuple of tensors) - lowering `mesh.shard_shape` including fixes to the operation itself - global symbol `static_mpi_rank` replaced by an DLTI attribute (now aligned with MPIToLLVM) - smaller fixes and some minor cleanup --------- Co-authored-by: Christian Ulmann <[email protected]>
1 parent cd1d9a8 commit 6e57326

File tree

10 files changed

+856
-257
lines changed

10 files changed

+856
-257
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
881881
let description = [{
882882
This pass converts communication operations from the Mesh dialect to the
883883
MPI dialect.
884-
If it finds a global named "static_mpi_rank" it will use that splat value
885-
instead of calling MPI_Comm_rank. This allows optimizations like constant
886-
shape propagation and fusion because shard/partition sizes depend on the
887-
rank.
884+
If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
885+
use that integer value instead of calling MPI_Comm_rank. This allows
886+
optimizations like constant shape propagation and fusion because
887+
shard/partition sizes depend on the rank.
888888
}];
889889
let dependentDialects = [
890890
"memref::MemRefDialect",

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
345345
}];
346346
}
347347

348-
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
349-
let summary = "Get the shard shape of a given process/device.";
348+
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
349+
Pure, AttrSizedOperandSegments,
350+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
351+
]> {
352+
let summary = "Get the shard shape for a given process/device.";
350353
let description = [{
351-
The device/process id is a linearized id of the device/process in the mesh.
354+
The device/process id is a multi-index of the device/process in the mesh.
352355
This operation might be used during spmdization when the shard shape depends
353356
on (non-constant) values used in `mesh.sharding`.
354357
}];
355358
let arguments = (ins
356-
DenseI64ArrayAttr:$shape,
359+
DenseI64ArrayAttr:$dims,
360+
Variadic<Index>:$dims_dynamic,
357361
Mesh_Sharding:$sharding,
358-
Index:$device
362+
DenseI64ArrayAttr:$device,
363+
Variadic<Index>:$device_dynamic
359364
);
360365
let results = (outs Variadic<Index>:$result);
361366
let assemblyFormat = [{
362-
custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
367+
`dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
368+
`sharding` `=` $sharding
369+
`device` `=` custom<DynamicIndexList>($device_dynamic, $device)
370+
attr-dict `:` type(results)
363371
}];
364372
let builders = [
365-
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
373+
OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
366374
];
367375
}
368376

mlir/lib/Conversion/MeshToMPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
1111
Core
1212

1313
LINK_LIBS PUBLIC
14+
MLIRDLTIDialect
1415
MLIRFuncDialect
1516
MLIRIR
1617
MLIRLinalgTransforms

0 commit comments

Comments
 (0)