Skip to content

Commit 79eb406

Browse files
authored
[mlir][mesh, MPI] Mesh2mpi (#104566)
Pass for lowering `Mesh` to `MPI`. Initial commit lowers `UpdateHaloOp` only.
1 parent 0c0f765 commit 79eb406

File tree

14 files changed

+836
-25
lines changed

14 files changed

+836
-25
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
10+
#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
11+
12+
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Support/LLVM.h"
14+
15+
namespace mlir {
16+
class Pass;
17+
18+
#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Lowers Mesh communication operations (updateHalo, AllGater, ...)
22+
/// to MPI primitives.
23+
std::unique_ptr<::mlir::Pass> createConvertMeshToMPIPass();
24+
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
5252
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
5353
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
54+
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
5455
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
5556
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
5657
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,29 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
883883
];
884884
}
885885

886+
//===----------------------------------------------------------------------===//
887+
// MeshToMPI
888+
//===----------------------------------------------------------------------===//
889+
890+
def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
891+
let summary = "Convert Mesh dialect to MPI dialect.";
892+
let constructor = "mlir::createConvertMeshToMPIPass()";
893+
let description = [{
894+
This pass converts communication operations from the Mesh dialect to the
895+
MPI dialect.
896+
If it finds a global named "static_mpi_rank" it will use that splat value
897+
instead of calling MPI_Comm_rank. This allows optimizations like constant
898+
shape propagation and fusion because shard/partition sizes depend on the
899+
rank.
900+
}];
901+
let dependentDialects = [
902+
"memref::MemRefDialect",
903+
"mpi::MPIDialect",
904+
"scf::SCFDialect",
905+
"bufferization::BufferizationDialect"
906+
];
907+
}
908+
886909
//===----------------------------------------------------------------------===//
887910
// NVVMToLLVM
888911
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
8484
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
8585
"type($ref) `,` type($tag) `,` type($rank)"
8686
"(`->` type($retval)^)?";
87+
let hasCanonicalizer = 1;
8788
}
8889

8990
//===----------------------------------------------------------------------===//
@@ -114,6 +115,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
114115
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
115116
"type($ref) `,` type($tag) `,` type($rank)"
116117
"(`->` type($retval)^)?";
118+
let hasCanonicalizer = 1;
117119
}
118120

119121

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,40 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
156156
];
157157
}
158158

159+
def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
160+
Pure,
161+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
162+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
163+
]> {
164+
let summary =
165+
"For given mesh index get the linear indices of the direct neighbor processes along the given split.";
166+
let description = [{
167+
Example:
168+
```
169+
mesh.mesh @mesh0(shape = 10x20x30)
170+
%c1 = arith.constant 1 : index
171+
%c2 = arith.constant 2 : index
172+
%c3 = arith.constant 3 : index
173+
%idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index
174+
```
175+
The above returns two indices, `633` and `693`, which correspond to the
176+
index of the previous process `(1, 1, 3)`, and the next process
177+
`(1, 3, 3) along the split axis `1`.
178+
179+
A negative value is returned if there is no neighbor in the respective
180+
direction along the given `split_axes`.
181+
}];
182+
let arguments = (ins FlatSymbolRefAttr:$mesh,
183+
Variadic<Index>:$device,
184+
Mesh_MeshAxesAttr:$split_axes);
185+
let results = (outs Index:$neighbor_down, Index:$neighbor_up);
186+
let assemblyFormat = [{
187+
`on` $mesh `[` $device `]`
188+
`split_axes` `=` $split_axes
189+
attr-dict `:` type(results)
190+
}];
191+
}
192+
159193
//===----------------------------------------------------------------------===//
160194
// Sharding operations.
161195
//===----------------------------------------------------------------------===//
@@ -1058,12 +1092,12 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
10581092
}
10591093

10601094
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
1095+
Pure,
10611096
DestinationStyleOpInterface,
10621097
TypesMatchWith<
10631098
"result has same type as destination",
10641099
"result", "destination", "$_self">,
1065-
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1066-
AttrSizedOperandSegments
1100+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
10671101
]> {
10681102
let summary = "Update halo data.";
10691103
let description = [{
@@ -1072,7 +1106,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10721106
on the remote devices. Changes might be caused by mutating operations
10731107
and/or if the new halo regions are larger than the existing ones.
10741108

1075-
Source and destination might have different halo sizes.
1109+
Destination is supposed to be initialized with the local data (not halos).
10761110

10771111
Assumes all devices hold tensors with same-sized halo data as specified
10781112
by `source_halo_sizes/static_source_halo_sizes` and
@@ -1084,25 +1118,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10841118

10851119
}];
10861120
let arguments = (ins
1087-
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
10881121
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
10891122
FlatSymbolRefAttr:$mesh,
10901123
Mesh_MeshAxesArrayAttr:$split_axes,
1091-
Variadic<I64>:$source_halo_sizes,
1092-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1093-
Variadic<I64>:$destination_halo_sizes,
1094-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
1124+
Variadic<I64>:$halo_sizes,
1125+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
10951126
);
10961127
let results = (outs
10971128
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
10981129
);
10991130
let assemblyFormat = [{
1100-
$source `into` $destination
1131+
$destination
11011132
`on` $mesh
11021133
`split_axes` `=` $split_axes
1103-
(`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1104-
(`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
1105-
attr-dict `:` type($source) `->` type($result)
1134+
(`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
1135+
attr-dict `:` type($result)
11061136
}];
11071137
let extraClassDeclaration = [{
11081138
MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
4141
add_subdirectory(MemRefToEmitC)
4242
add_subdirectory(MemRefToLLVM)
4343
add_subdirectory(MemRefToSPIRV)
44+
add_subdirectory(MeshToMPI)
4445
add_subdirectory(NVGPUToNVVM)
4546
add_subdirectory(NVVMToLLVM)
4647
add_subdirectory(OpenACCToSCF)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
add_mlir_conversion_library(MLIRMeshToMPI
2+
MeshToMPI.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRFuncDialect
15+
MLIRIR
16+
MLIRLinalgTransforms
17+
MLIRMemRefDialect
18+
MLIRPass
19+
MLIRMeshDialect
20+
MLIRMPIDialect
21+
MLIRTransforms
22+
)

0 commit comments

Comments
 (0)