Skip to content

Commit cd85636

Browse files
committed
fixed corner halos by reversing data-exchanges from high to low dims
1 parent c106abc commit cd85636

File tree

2 files changed

+137
-115
lines changed

2 files changed

+137
-115
lines changed

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct ConvertUpdateHaloOp
7070

7171
auto array = op.getInput();
7272
auto rank = array.getType().getRank();
73+
auto opSplitAxes = op.getSplitAxes().getAxes();
7374
auto mesh = op.getMesh();
7475
auto meshOp = getMesh(op, symbolTableCollection);
7576
auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
@@ -87,32 +88,54 @@ struct ConvertUpdateHaloOp
8788
// most of the offset/size/stride data is the same for all dims
8889
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
8990
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
90-
SmallVector<OpFoldResult> shape(rank);
91+
SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
92+
auto currHaloDim = -1; // halo sizes are provided for split dimensions only
9193
// we need the actual shape to compute offsets and sizes
92-
for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
94+
for (auto i = 0; i < rank; ++i) {
95+
auto s = array.getType().getShape()[i];
9396
if (ShapedType::isDynamic(s)) {
9497
shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
9598
} else {
9699
shape[i] = rewriter.getIndexAttr(s);
97100
}
101+
102+
if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
103+
++currHaloDim;
104+
// the offsets for lower dims start after their down halo
105+
offsets[i] = haloSizes[currHaloDim * 2];
106+
107+
// prepare shape and offsets of highest dim's halo exchange
108+
auto _haloSz =
109+
rewriter
110+
.create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
111+
toValue(haloSizes[currHaloDim * 2 + 1]))
112+
.getResult();
113+
// the halo shapes of lower dims exclude the halos
114+
dimSizes[i] =
115+
rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
116+
.getResult();
117+
} else {
118+
dimSizes[i] = shape[i];
119+
}
98120
}
99121

100122
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
101123
auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
102124
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
103125
auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
126+
104127
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
105128
rewriter.getIndexType());
106129
auto myMultiIndex =
107130
rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
108131
.getResult();
109-
// halo sizes are provided for split dimensions only
110-
auto currHaloDim = 0;
111-
112-
for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
132+
// traverse all split axes from high to low dim
133+
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
134+
auto splitAxes = opSplitAxes[dim];
113135
if (splitAxes.empty()) {
114136
continue;
115137
}
138+
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
116139
// Get the linearized ids of the neighbors (down and up) for the
117140
// given split
118141
auto tmp = rewriter
@@ -124,22 +147,24 @@ struct ConvertUpdateHaloOp
124147
loc, rewriter.getI32Type(), tmp[0]),
125148
rewriter.create<arith::IndexCastOp>(
126149
loc, rewriter.getI32Type(), tmp[1])};
127-
// store for later
128-
auto orgDimSize = shape[dim];
129-
// this dim's offset to the start of the upper halo
130-
auto upperOffset = rewriter.create<arith::SubIOp>(
150+
151+
auto lowerRecvOffset = rewriter.getIndexAttr(0);
152+
auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
153+
auto upperRecvOffset = rewriter.create<arith::SubIOp>(
131154
loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
155+
auto upperSendOffset = rewriter.create<arith::SubIOp>(
156+
loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
132157

133158
// Make sure we send/recv in a way that does not lead to a dead-lock.
134159
// The current approach is far from optimal; it should at least
135160
// be a red-black pattern or use MPI_sendrecv.
136161
// Also, buffers should be re-used.
137162
// Still using temporary contiguous buffers for MPI communication...
138163
// Still yielding a "serialized" communication pattern...
139-
auto genSendRecv = [&](auto dim, bool upperHalo) {
164+
auto genSendRecv = [&](bool upperHalo) {
140165
auto orgOffset = offsets[dim];
141-
shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
142-
: haloSizes[currHaloDim * 2];
166+
dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
167+
: haloSizes[currHaloDim * 2];
143168
// Check if we need to send and/or receive
144169
// Processes on the mesh borders have only one neighbor
145170
auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
@@ -149,52 +174,42 @@ struct ConvertUpdateHaloOp
149174
auto hasTo = rewriter.create<arith::CmpIOp>(
150175
loc, arith::CmpIPredicate::sge, to, zero);
151176
auto buffer = rewriter.create<memref::AllocOp>(
152-
loc, shape, array.getType().getElementType());
177+
loc, dimSizes, array.getType().getElementType());
153178
// if has neighbor: copy halo data from array to buffer and send
154179
rewriter.create<scf::IfOp>(
155180
loc, hasTo, [&](OpBuilder &builder, Location loc) {
156-
offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
157-
: OpFoldResult(upperOffset);
181+
offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
182+
: OpFoldResult(upperSendOffset);
158183
auto subview = builder.create<memref::SubViewOp>(
159-
loc, array, offsets, shape, strides);
184+
loc, array, offsets, dimSizes, strides);
160185
builder.create<memref::CopyOp>(loc, subview, buffer);
161186
builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
162187
builder.create<scf::YieldOp>(loc);
163188
});
164189
// if has neighbor: receive halo data into buffer and copy to array
165190
rewriter.create<scf::IfOp>(
166191
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
167-
offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
168-
: OpFoldResult(builder.getIndexAttr(0));
192+
offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
193+
: OpFoldResult(lowerRecvOffset);
169194
builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
170195
auto subview = builder.create<memref::SubViewOp>(
171-
loc, array, offsets, shape, strides);
196+
loc, array, offsets, dimSizes, strides);
172197
builder.create<memref::CopyOp>(loc, buffer, subview);
173198
builder.create<scf::YieldOp>(loc);
174199
});
175200
rewriter.create<memref::DeallocOp>(loc, buffer);
176201
offsets[dim] = orgOffset;
177202
};
178203

179-
genSendRecv(dim, false);
180-
genSendRecv(dim, true);
181-
182-
// prepare shape and offsets for next split dim
183-
auto _haloSz =
184-
rewriter
185-
.create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
186-
toValue(haloSizes[currHaloDim * 2 + 1]))
187-
.getResult();
188-
// the shape for next halo excludes the halo on both ends for the
189-
// current dim
190-
shape[dim] =
191-
rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
192-
.getResult();
193-
// the offsets for next halo starts after the down halo for the
194-
// current dim
195-
offsets[dim] = haloSizes[currHaloDim * 2];
204+
genSendRecv(false);
205+
genSendRecv(true);
206+
207+
// the shape for lower dims includes higher dims' halos
208+
dimSizes[dim] = shape[dim];
209+
// -> the offset for higher dims is always 0
210+
offsets[dim] = rewriter.getIndexAttr(0);
196211
// on to next halo
197-
++currHaloDim;
212+
--currHaloDim;
198213
}
199214
rewriter.eraseOp(op);
200215
return mlir::success();

0 commit comments

Comments
 (0)