Skip to content

Commit 628abb6

Browse files
committed
initial hack lowering mesh.update_halo to MPI
1 parent 00c198b commit 628abb6

File tree

9 files changed

+325
-0
lines changed

9 files changed

+325
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
10+
#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
11+
12+
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Support/LLVM.h"
14+
15+
namespace mlir {
16+
class Pass;
17+
18+
#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Lowers Mesh communication operations (updateHalo, AllGater, ...)
22+
/// to MPI primitives.
23+
std::unique_ptr<Pass> createConvertMeshToMPIPass();
24+
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
5252
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
5353
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
54+
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
5455
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
5556
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
5657
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,23 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
878878
];
879879
}
880880

881+
//===----------------------------------------------------------------------===//
882+
// MeshToMPI
883+
//===----------------------------------------------------------------------===//
884+
885+
def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
886+
let summary = "Convert Mesh dialect to MPI dialect.";
887+
let description = [{
888+
This pass converts communication operations
889+
from the Mesh dialect to operations from the MPI dialect.
890+
}];
891+
let dependentDialects = [
892+
"memref::MemRefDialect",
893+
"mpi::MPIDialect",
894+
"scf::SCFDialect"
895+
];
896+
}
897+
881898
//===----------------------------------------------------------------------===//
882899
// NVVMToLLVM
883900
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
155155
];
156156
}
157157

158+
def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
159+
Pure,
160+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
161+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
162+
]> {
163+
let summary =
164+
"For given split axes get the linear index the direct neighbor processes.";
165+
let description = [{
166+
Example:
167+
```
168+
%idx = mesh.neighbor_linear_index on @mesh for $device
169+
split_axes = $split_axes : index
170+
```
171+
Given `@mesh` with shape `(10, 20, 30)`,
172+
`device` = `(1, 2, 3)`
173+
`$split_axes` = `[1]`
174+
it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
175+
and `(1, 3, 3)`: `693`.
176+
177+
A negative value is returned if `$device` has no neighbor in the given
178+
direction along the given `split_axes`.
179+
}];
180+
let arguments = (ins FlatSymbolRefAttr:$mesh,
181+
Variadic<Index>:$device,
182+
Mesh_MeshAxesAttr:$split_axes);
183+
let results = (outs Index:$neighbor_down, Index:$neighbor_up);
184+
let assemblyFormat = [{
185+
`on` $mesh `[` $device `]`
186+
`split_axes` `=` $split_axes
187+
attr-dict `:` type(results)
188+
}];
189+
}
190+
158191
//===----------------------------------------------------------------------===//
159192
// Sharding operations.
160193
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
4141
add_subdirectory(MemRefToEmitC)
4242
add_subdirectory(MemRefToLLVM)
4343
add_subdirectory(MemRefToSPIRV)
44+
add_subdirectory(MeshToMPI)
4445
add_subdirectory(NVGPUToNVVM)
4546
add_subdirectory(NVVMToLLVM)
4647
add_subdirectory(OpenACCToSCF)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
add_mlir_conversion_library(MLIRMeshToMPI
2+
MeshToMPI.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRFuncDialect
15+
MLIRIR
16+
MLIRLinalgTransforms
17+
MLIRMemRefDialect
18+
MLIRPass
19+
MLIRMeshDialect
20+
MLIRMPIDialect
21+
MLIRTransforms
22+
)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a translation of Mesh communicatin ops tp MPI ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/MPI/IR/MPI.h"
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
19+
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
21+
#include "mlir/IR/Builders.h"
22+
#include "mlir/IR/BuiltinAttributes.h"
23+
#include "mlir/IR/BuiltinTypes.h"
24+
25+
#define DEBUG_TYPE "mesh-to-mpi"
26+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27+
28+
namespace mlir {
29+
#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
30+
#include "mlir/Conversion/Passes.h.inc"
31+
} // namespace mlir
32+
33+
using namespace mlir;
34+
using namespace mlir::mesh;
35+
36+
namespace {
37+
struct ConvertMeshToMPIPass
38+
: public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
39+
using Base::Base;
40+
41+
/// Run the dialect converter on the module.
42+
void runOnOperation() override {
43+
getOperation()->walk([&](UpdateHaloOp op) {
44+
SymbolTableCollection symbolTableCollection;
45+
OpBuilder builder(op);
46+
auto loc = op.getLoc();
47+
48+
auto toValue = [&builder, &loc](OpFoldResult &v) {
49+
return v.is<Value>()
50+
? v.get<Value>()
51+
: builder.create<::mlir::arith::ConstantOp>(
52+
loc,
53+
builder.getIndexAttr(
54+
cast<IntegerAttr>(v.get<Attribute>()).getInt()));
55+
};
56+
57+
auto array = op.getInput();
58+
auto rank = array.getType().getRank();
59+
auto mesh = op.getMesh();
60+
auto meshOp = getMesh(op, symbolTableCollection);
61+
auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
62+
op.getDynamicHaloSizes(), builder);
63+
for (auto &sz : haloSizes) {
64+
if (sz.is<Value>()) {
65+
sz = builder
66+
.create<arith::IndexCastOp>(loc, builder.getIndexType(),
67+
sz.get<Value>())
68+
.getResult();
69+
}
70+
}
71+
72+
SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
73+
SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
74+
SmallVector<OpFoldResult> shape(rank);
75+
for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
76+
if (ShapedType::isDynamic(s)) {
77+
shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
78+
} else {
79+
shape[i] = builder.getIndexAttr(s);
80+
}
81+
}
82+
83+
auto tagAttr = builder.getI32IntegerAttr(91); // whatever
84+
auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
85+
auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
86+
auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
87+
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
88+
builder.getIndexType());
89+
auto myMultiIndex =
90+
builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
91+
.getResult();
92+
auto currHaloDim = 0;
93+
94+
for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
95+
if (!splitAxes.empty()) {
96+
auto tmp = builder
97+
.create<NeighborsLinearIndicesOp>(
98+
loc, mesh, myMultiIndex, splitAxes)
99+
.getResults();
100+
Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
101+
loc, builder.getI32Type(), tmp[0]),
102+
builder.create<arith::IndexCastOp>(
103+
loc, builder.getI32Type(), tmp[1])};
104+
auto orgDimSize = shape[dim];
105+
auto upperOffset = builder.create<arith::SubIOp>(
106+
loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
107+
108+
// make sure we send/recv in a way that does not lead to a dead-lock
109+
// This is by far not optimal, this should be at least MPI_sendrecv
110+
// and - probably even more importantly - buffers should be re-used
111+
// Currently using temporary, contiguous buffer for MPI communication
112+
auto genSendRecv = [&](auto dim, bool upperHalo) {
113+
auto orgOffset = offsets[dim];
114+
shape[dim] =
115+
upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
116+
auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
117+
auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
118+
auto hasFrom = builder.create<arith::CmpIOp>(
119+
loc, arith::CmpIPredicate::sge, from, zero);
120+
auto hasTo = builder.create<arith::CmpIOp>(
121+
loc, arith::CmpIPredicate::sge, to, zero);
122+
auto buffer = builder.create<memref::AllocOp>(
123+
loc, shape, array.getType().getElementType());
124+
builder.create<scf::IfOp>(
125+
loc, hasTo, [&](OpBuilder &builder, Location loc) {
126+
offsets[dim] = upperHalo
127+
? OpFoldResult(builder.getIndexAttr(0))
128+
: OpFoldResult(upperOffset);
129+
auto subview = builder.create<memref::SubViewOp>(
130+
loc, array, offsets, shape, strides);
131+
builder.create<memref::CopyOp>(loc, subview, buffer);
132+
builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
133+
to);
134+
builder.create<scf::YieldOp>(loc);
135+
});
136+
builder.create<scf::IfOp>(
137+
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
138+
offsets[dim] = upperHalo
139+
? OpFoldResult(upperOffset)
140+
: OpFoldResult(builder.getIndexAttr(0));
141+
builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
142+
from);
143+
auto subview = builder.create<memref::SubViewOp>(
144+
loc, array, offsets, shape, strides);
145+
builder.create<memref::CopyOp>(loc, buffer, subview);
146+
builder.create<scf::YieldOp>(loc);
147+
});
148+
builder.create<memref::DeallocOp>(loc, buffer);
149+
offsets[dim] = orgOffset;
150+
};
151+
152+
genSendRecv(dim, false);
153+
genSendRecv(dim, true);
154+
155+
shape[dim] = builder
156+
.create<arith::SubIOp>(
157+
loc, toValue(orgDimSize),
158+
builder
159+
.create<arith::AddIOp>(
160+
loc, toValue(haloSizes[dim * 2]),
161+
toValue(haloSizes[dim * 2 + 1]))
162+
.getResult())
163+
.getResult();
164+
offsets[dim] = haloSizes[dim * 2];
165+
++currHaloDim;
166+
}
167+
}
168+
});
169+
}
170+
};
171+
} // namespace

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,25 @@ void ProcessLinearIndexOp::getAsmResultNames(
730730
setNameFn(getResult(), "proc_linear_idx");
731731
}
732732

733+
//===----------------------------------------------------------------------===//
734+
// mesh.neighbors_linear_indices op
735+
//===----------------------------------------------------------------------===//
736+
737+
LogicalResult
738+
NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
739+
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
740+
if (failed(mesh)) {
741+
return failure();
742+
}
743+
return success();
744+
}
745+
746+
void NeighborsLinearIndicesOp::getAsmResultNames(
747+
function_ref<void(Value, StringRef)> setNameFn) {
748+
setNameFn(getNeighborDown(), "down_linear_idx");
749+
setNameFn(getNeighborUp(), "up_linear_idx");
750+
}
751+
733752
//===----------------------------------------------------------------------===//
734753
// collective communication ops
735754
//===----------------------------------------------------------------------===//
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
2+
3+
// CHECK: mesh.mesh @mesh0
4+
mesh.mesh @mesh0(shape = 2x2x4)
5+
6+
// -----
7+
8+
// CHECK-LABEL: func @update_halo
9+
func.func @update_halo_1d(
10+
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
11+
%arg0 : memref<12x12xi8>) {
12+
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
13+
// CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
14+
// CHECK-SAME: split_axes = {{\[\[}}0]]
15+
// CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
16+
%c2 = arith.constant 2 : i64
17+
mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
18+
halo_sizes = [2, %c2] : memref<12x12xi8>
19+
return
20+
}
21+
22+
func.func @update_halo_2d(
23+
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
24+
%arg0 : memref<12x12xi8>) {
25+
%c2 = arith.constant 2 : i64
26+
// CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
27+
// CHECK-SAME: split_axes = {{\[\[}}0], [1]]
28+
// CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
29+
// CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
30+
mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
31+
halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
32+
: memref<12x12xi8>
33+
return
34+
}

0 commit comments

Comments
 (0)