Skip to content

Commit 80f4b5f

Browse files
committed
dim fixes, proper testing
1 parent 827c300 commit 80f4b5f

File tree

2 files changed

+339
-146
lines changed

2 files changed

+339
-146
lines changed

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 180 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements a translation of Mesh communicatin ops tp MPI ops.
9+
// This file implements a translation of Mesh communication ops to MPI ops.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

@@ -21,6 +21,8 @@
2121
#include "mlir/IR/Builders.h"
2222
#include "mlir/IR/BuiltinAttributes.h"
2323
#include "mlir/IR/BuiltinTypes.h"
24+
#include "mlir/IR/PatternMatch.h"
25+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2426

2527
#define DEBUG_TYPE "mesh-to-mpi"
2628
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -34,138 +36,190 @@ using namespace mlir;
3436
using namespace mlir::mesh;
3537

3638
namespace {
39+
40+
// This pattern converts the mesh.update_halo operation to MPI calls
41+
struct ConvertUpdateHaloOp
42+
: public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
43+
using OpRewritePattern::OpRewritePattern;
44+
45+
mlir::LogicalResult
46+
matchAndRewrite(mlir::mesh::UpdateHaloOp op,
47+
mlir::PatternRewriter &rewriter) const override {
48+
// Halos are exchanged as 2 blocks per dimension (one for each side: down
49+
// and up). It is assumed that the last dim in a default memref is
50+
// contiguous, hence iteration starts with the complete halo on the first
51+
// dim which should be contiguous (unless the source is not). The size of
52+
// the exchanged data will decrease when iterating over dimensions. That's
53+
// good because the halos of last dim will be most fragmented.
54+
// memref.subview is used to read and write the halo data from and to the
55+
// local data. subviews and halos have dynamic and static values, so
56+
// OpFoldResults are used whenever possible.
57+
58+
SymbolTableCollection symbolTableCollection;
59+
auto loc = op.getLoc();
60+
61+
// convert a OpFoldResult into a Value
62+
auto toValue = [&rewriter, &loc](OpFoldResult &v) {
63+
return v.is<Value>()
64+
? v.get<Value>()
65+
: rewriter.create<::mlir::arith::ConstantOp>(
66+
loc,
67+
rewriter.getIndexAttr(
68+
cast<IntegerAttr>(v.get<Attribute>()).getInt()));
69+
};
70+
71+
auto array = op.getInput();
72+
auto rank = array.getType().getRank();
73+
auto mesh = op.getMesh();
74+
auto meshOp = getMesh(op, symbolTableCollection);
75+
auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
76+
op.getDynamicHaloSizes(), rewriter);
77+
// subviews need Index values
78+
for (auto &sz : haloSizes) {
79+
if (sz.is<Value>()) {
80+
sz = rewriter
81+
.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
82+
sz.get<Value>())
83+
.getResult();
84+
}
85+
}
86+
87+
// most of the offset/size/stride data is the same for all dims
88+
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
89+
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
90+
SmallVector<OpFoldResult> shape(rank);
91+
// we need the actual shape to compute offsets and sizes
92+
for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
93+
if (ShapedType::isDynamic(s)) {
94+
shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
95+
} else {
96+
shape[i] = rewriter.getIndexAttr(s);
97+
}
98+
}
99+
100+
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
101+
auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
102+
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
103+
auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
104+
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
105+
rewriter.getIndexType());
106+
auto myMultiIndex =
107+
rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
108+
.getResult();
109+
// halo sizes are provided for split dimensions only
110+
auto currHaloDim = 0;
111+
112+
for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
113+
if (splitAxes.empty()) {
114+
continue;
115+
}
116+
// Get the linearized ids of the neighbors (down and up) for the
117+
// given split
118+
auto tmp = rewriter
119+
.create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
120+
splitAxes)
121+
.getResults();
122+
// MPI operates on i32...
123+
Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
124+
loc, rewriter.getI32Type(), tmp[0]),
125+
rewriter.create<arith::IndexCastOp>(
126+
loc, rewriter.getI32Type(), tmp[1])};
127+
// store for later
128+
auto orgDimSize = shape[dim];
129+
// this dim's offset to the start of the upper halo
130+
auto upperOffset = rewriter.create<arith::SubIOp>(
131+
loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
132+
133+
// Make sure we send/recv in a way that does not lead to a dead-lock.
134+
// The current approach is by far not optimal, this should be at least
135+
// be a red-black pattern or using MPI_sendrecv.
136+
// Also, buffers should be re-used.
137+
// Still using temporary contiguous buffers for MPI communication...
138+
// Still yielding a "serialized" communication pattern...
139+
auto genSendRecv = [&](auto dim, bool upperHalo) {
140+
auto orgOffset = offsets[dim];
141+
shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
142+
: haloSizes[currHaloDim * 2];
143+
// Check if we need to send and/or receive
144+
// Processes on the mesh borders have only one neighbor
145+
auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
146+
auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
147+
auto hasFrom = rewriter.create<arith::CmpIOp>(
148+
loc, arith::CmpIPredicate::sge, from, zero);
149+
auto hasTo = rewriter.create<arith::CmpIOp>(
150+
loc, arith::CmpIPredicate::sge, to, zero);
151+
auto buffer = rewriter.create<memref::AllocOp>(
152+
loc, shape, array.getType().getElementType());
153+
// if has neighbor: copy halo data from array to buffer and send
154+
rewriter.create<scf::IfOp>(
155+
loc, hasTo, [&](OpBuilder &builder, Location loc) {
156+
offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
157+
: OpFoldResult(upperOffset);
158+
auto subview = builder.create<memref::SubViewOp>(
159+
loc, array, offsets, shape, strides);
160+
builder.create<memref::CopyOp>(loc, subview, buffer);
161+
builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
162+
builder.create<scf::YieldOp>(loc);
163+
});
164+
// if has neighbor: receive halo data into buffer and copy to array
165+
rewriter.create<scf::IfOp>(
166+
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
167+
offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
168+
: OpFoldResult(builder.getIndexAttr(0));
169+
builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
170+
auto subview = builder.create<memref::SubViewOp>(
171+
loc, array, offsets, shape, strides);
172+
builder.create<memref::CopyOp>(loc, buffer, subview);
173+
builder.create<scf::YieldOp>(loc);
174+
});
175+
rewriter.create<memref::DeallocOp>(loc, buffer);
176+
offsets[dim] = orgOffset;
177+
};
178+
179+
genSendRecv(dim, false);
180+
genSendRecv(dim, true);
181+
182+
// prepare shape and offsets for next split dim
183+
auto _haloSz =
184+
rewriter
185+
.create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
186+
toValue(haloSizes[currHaloDim * 2 + 1]))
187+
.getResult();
188+
// the shape for next halo excludes the halo on both ends for the
189+
// current dim
190+
shape[dim] =
191+
rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
192+
.getResult();
193+
// the offsets for next halo starts after the down halo for the
194+
// current dim
195+
offsets[dim] = haloSizes[currHaloDim * 2];
196+
// on to next halo
197+
++currHaloDim;
198+
}
199+
rewriter.eraseOp(op);
200+
return mlir::success();
201+
}
202+
};
203+
37204
struct ConvertMeshToMPIPass
38205
: public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
39206
using Base::Base;
40207

41208
/// Run the dialect converter on the module.
42209
void runOnOperation() override {
43-
getOperation()->walk([&](UpdateHaloOp op) {
44-
SymbolTableCollection symbolTableCollection;
45-
OpBuilder builder(op);
46-
auto loc = op.getLoc();
47-
48-
auto toValue = [&builder, &loc](OpFoldResult &v) {
49-
return v.is<Value>()
50-
? v.get<Value>()
51-
: builder.create<::mlir::arith::ConstantOp>(
52-
loc,
53-
builder.getIndexAttr(
54-
cast<IntegerAttr>(v.get<Attribute>()).getInt()));
55-
};
210+
auto *ctx = &getContext();
211+
mlir::RewritePatternSet patterns(ctx);
56212

57-
auto array = op.getInput();
58-
auto rank = array.getType().getRank();
59-
auto mesh = op.getMesh();
60-
auto meshOp = getMesh(op, symbolTableCollection);
61-
auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
62-
op.getDynamicHaloSizes(), builder);
63-
for (auto &sz : haloSizes) {
64-
if (sz.is<Value>()) {
65-
sz = builder
66-
.create<arith::IndexCastOp>(loc, builder.getIndexType(),
67-
sz.get<Value>())
68-
.getResult();
69-
}
70-
}
71-
72-
SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
73-
SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
74-
SmallVector<OpFoldResult> shape(rank);
75-
for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
76-
if (ShapedType::isDynamic(s)) {
77-
shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
78-
} else {
79-
shape[i] = builder.getIndexAttr(s);
80-
}
81-
}
213+
patterns.insert<ConvertUpdateHaloOp>(ctx);
82214

83-
auto tagAttr = builder.getI32IntegerAttr(91); // whatever
84-
auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
85-
auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
86-
auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
87-
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
88-
builder.getIndexType());
89-
auto myMultiIndex =
90-
builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
91-
.getResult();
92-
auto currHaloDim = 0;
93-
94-
for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
95-
if (!splitAxes.empty()) {
96-
auto tmp = builder
97-
.create<NeighborsLinearIndicesOp>(
98-
loc, mesh, myMultiIndex, splitAxes)
99-
.getResults();
100-
Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
101-
loc, builder.getI32Type(), tmp[0]),
102-
builder.create<arith::IndexCastOp>(
103-
loc, builder.getI32Type(), tmp[1])};
104-
auto orgDimSize = shape[dim];
105-
auto upperOffset = builder.create<arith::SubIOp>(
106-
loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
107-
108-
// make sure we send/recv in a way that does not lead to a dead-lock
109-
// This is by far not optimal, this should be at least MPI_sendrecv
110-
// and - probably even more importantly - buffers should be re-used
111-
// Currently using temporary, contiguous buffer for MPI communication
112-
auto genSendRecv = [&](auto dim, bool upperHalo) {
113-
auto orgOffset = offsets[dim];
114-
shape[dim] =
115-
upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
116-
auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
117-
auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
118-
auto hasFrom = builder.create<arith::CmpIOp>(
119-
loc, arith::CmpIPredicate::sge, from, zero);
120-
auto hasTo = builder.create<arith::CmpIOp>(
121-
loc, arith::CmpIPredicate::sge, to, zero);
122-
auto buffer = builder.create<memref::AllocOp>(
123-
loc, shape, array.getType().getElementType());
124-
builder.create<scf::IfOp>(
125-
loc, hasTo, [&](OpBuilder &builder, Location loc) {
126-
offsets[dim] = upperHalo
127-
? OpFoldResult(builder.getIndexAttr(0))
128-
: OpFoldResult(upperOffset);
129-
auto subview = builder.create<memref::SubViewOp>(
130-
loc, array, offsets, shape, strides);
131-
builder.create<memref::CopyOp>(loc, subview, buffer);
132-
builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
133-
to);
134-
builder.create<scf::YieldOp>(loc);
135-
});
136-
builder.create<scf::IfOp>(
137-
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
138-
offsets[dim] = upperHalo
139-
? OpFoldResult(upperOffset)
140-
: OpFoldResult(builder.getIndexAttr(0));
141-
builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
142-
from);
143-
auto subview = builder.create<memref::SubViewOp>(
144-
loc, array, offsets, shape, strides);
145-
builder.create<memref::CopyOp>(loc, buffer, subview);
146-
builder.create<scf::YieldOp>(loc);
147-
});
148-
builder.create<memref::DeallocOp>(loc, buffer);
149-
offsets[dim] = orgOffset;
150-
};
151-
152-
genSendRecv(dim, false);
153-
genSendRecv(dim, true);
154-
155-
shape[dim] = builder
156-
.create<arith::SubIOp>(
157-
loc, toValue(orgDimSize),
158-
builder
159-
.create<arith::AddIOp>(
160-
loc, toValue(haloSizes[dim * 2]),
161-
toValue(haloSizes[dim * 2 + 1]))
162-
.getResult())
163-
.getResult();
164-
offsets[dim] = haloSizes[dim * 2];
165-
++currHaloDim;
166-
}
167-
}
168-
});
215+
(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
216+
std::move(patterns));
169217
}
170218
};
171-
} // namespace
219+
220+
} // namespace
221+
222+
// Create a pass that convert Mesh to MPI
223+
std::unique_ptr<::mlir::OperationPass<void>> createConvertMeshToMPIPass() {
224+
return std::make_unique<ConvertMeshToMPIPass>();
225+
}

0 commit comments

Comments
 (0)