6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // This file implements a translation of Mesh communicatin ops tp MPI ops.
9
+ // This file implements a translation of Mesh communication ops to MPI ops.
10
10
//
11
11
// ===----------------------------------------------------------------------===//
12
12
21
21
#include " mlir/IR/Builders.h"
22
22
#include " mlir/IR/BuiltinAttributes.h"
23
23
#include " mlir/IR/BuiltinTypes.h"
24
+ #include " mlir/IR/PatternMatch.h"
25
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
24
26
25
27
#define DEBUG_TYPE " mesh-to-mpi"
26
28
#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
@@ -34,138 +36,190 @@ using namespace mlir;
34
36
using namespace mlir ::mesh;
35
37
36
38
namespace {
39
+
40
+ // This pattern converts the mesh.update_halo operation to MPI calls
41
+ struct ConvertUpdateHaloOp
42
+ : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
43
+ using OpRewritePattern::OpRewritePattern;
44
+
45
+ mlir::LogicalResult
46
+ matchAndRewrite (mlir::mesh::UpdateHaloOp op,
47
+ mlir::PatternRewriter &rewriter) const override {
48
+ // Halos are exchanged as 2 blocks per dimension (one for each side: down
49
+ // and up). It is assumed that the last dim in a default memref is
50
+ // contiguous, hence iteration starts with the complete halo on the first
51
+ // dim which should be contiguous (unless the source is not). The size of
52
+ // the exchanged data will decrease when iterating over dimensions. That's
53
+ // good because the halos of last dim will be most fragmented.
54
+ // memref.subview is used to read and write the halo data from and to the
55
+ // local data. subviews and halos have dynamic and static values, so
56
+ // OpFoldResults are used whenever possible.
57
+
58
+ SymbolTableCollection symbolTableCollection;
59
+ auto loc = op.getLoc ();
60
+
61
+ // convert a OpFoldResult into a Value
62
+ auto toValue = [&rewriter, &loc](OpFoldResult &v) {
63
+ return v.is <Value>()
64
+ ? v.get <Value>()
65
+ : rewriter.create <::mlir::arith::ConstantOp>(
66
+ loc,
67
+ rewriter.getIndexAttr (
68
+ cast<IntegerAttr>(v.get <Attribute>()).getInt ()));
69
+ };
70
+
71
+ auto array = op.getInput ();
72
+ auto rank = array.getType ().getRank ();
73
+ auto mesh = op.getMesh ();
74
+ auto meshOp = getMesh (op, symbolTableCollection);
75
+ auto haloSizes = getMixedValues (op.getStaticHaloSizes (),
76
+ op.getDynamicHaloSizes (), rewriter);
77
+ // subviews need Index values
78
+ for (auto &sz : haloSizes) {
79
+ if (sz.is <Value>()) {
80
+ sz = rewriter
81
+ .create <arith::IndexCastOp>(loc, rewriter.getIndexType (),
82
+ sz.get <Value>())
83
+ .getResult ();
84
+ }
85
+ }
86
+
87
+ // most of the offset/size/stride data is the same for all dims
88
+ SmallVector<OpFoldResult> offsets (rank, rewriter.getIndexAttr (0 ));
89
+ SmallVector<OpFoldResult> strides (rank, rewriter.getIndexAttr (1 ));
90
+ SmallVector<OpFoldResult> shape (rank);
91
+ // we need the actual shape to compute offsets and sizes
92
+ for (auto [i, s] : llvm::enumerate (array.getType ().getShape ())) {
93
+ if (ShapedType::isDynamic (s)) {
94
+ shape[i] = rewriter.create <memref::DimOp>(loc, array, s).getResult ();
95
+ } else {
96
+ shape[i] = rewriter.getIndexAttr (s);
97
+ }
98
+ }
99
+
100
+ auto tagAttr = rewriter.getI32IntegerAttr (91 ); // we just pick something
101
+ auto tag = rewriter.create <::mlir::arith::ConstantOp>(loc, tagAttr);
102
+ auto zeroAttr = rewriter.getI32IntegerAttr (0 ); // for detecting v<0
103
+ auto zero = rewriter.create <::mlir::arith::ConstantOp>(loc, zeroAttr);
104
+ SmallVector<Type> indexResultTypes (meshOp.getShape ().size (),
105
+ rewriter.getIndexType ());
106
+ auto myMultiIndex =
107
+ rewriter.create <ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
108
+ .getResult ();
109
+ // halo sizes are provided for split dimensions only
110
+ auto currHaloDim = 0 ;
111
+
112
+ for (auto [dim, splitAxes] : llvm::enumerate (op.getSplitAxes ())) {
113
+ if (splitAxes.empty ()) {
114
+ continue ;
115
+ }
116
+ // Get the linearized ids of the neighbors (down and up) for the
117
+ // given split
118
+ auto tmp = rewriter
119
+ .create <NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
120
+ splitAxes)
121
+ .getResults ();
122
+ // MPI operates on i32...
123
+ Value neighbourIDs[2 ] = {rewriter.create <arith::IndexCastOp>(
124
+ loc, rewriter.getI32Type (), tmp[0 ]),
125
+ rewriter.create <arith::IndexCastOp>(
126
+ loc, rewriter.getI32Type (), tmp[1 ])};
127
+ // store for later
128
+ auto orgDimSize = shape[dim];
129
+ // this dim's offset to the start of the upper halo
130
+ auto upperOffset = rewriter.create <arith::SubIOp>(
131
+ loc, toValue (shape[dim]), toValue (haloSizes[currHaloDim * 2 + 1 ]));
132
+
133
+ // Make sure we send/recv in a way that does not lead to a dead-lock.
134
+ // The current approach is by far not optimal, this should be at least
135
+ // be a red-black pattern or using MPI_sendrecv.
136
+ // Also, buffers should be re-used.
137
+ // Still using temporary contiguous buffers for MPI communication...
138
+ // Still yielding a "serialized" communication pattern...
139
+ auto genSendRecv = [&](auto dim, bool upperHalo) {
140
+ auto orgOffset = offsets[dim];
141
+ shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1 ]
142
+ : haloSizes[currHaloDim * 2 ];
143
+ // Check if we need to send and/or receive
144
+ // Processes on the mesh borders have only one neighbor
145
+ auto to = upperHalo ? neighbourIDs[1 ] : neighbourIDs[0 ];
146
+ auto from = upperHalo ? neighbourIDs[0 ] : neighbourIDs[1 ];
147
+ auto hasFrom = rewriter.create <arith::CmpIOp>(
148
+ loc, arith::CmpIPredicate::sge, from, zero);
149
+ auto hasTo = rewriter.create <arith::CmpIOp>(
150
+ loc, arith::CmpIPredicate::sge, to, zero);
151
+ auto buffer = rewriter.create <memref::AllocOp>(
152
+ loc, shape, array.getType ().getElementType ());
153
+ // if has neighbor: copy halo data from array to buffer and send
154
+ rewriter.create <scf::IfOp>(
155
+ loc, hasTo, [&](OpBuilder &builder, Location loc) {
156
+ offsets[dim] = upperHalo ? OpFoldResult (builder.getIndexAttr (0 ))
157
+ : OpFoldResult (upperOffset);
158
+ auto subview = builder.create <memref::SubViewOp>(
159
+ loc, array, offsets, shape, strides);
160
+ builder.create <memref::CopyOp>(loc, subview, buffer);
161
+ builder.create <mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
162
+ builder.create <scf::YieldOp>(loc);
163
+ });
164
+ // if has neighbor: receive halo data into buffer and copy to array
165
+ rewriter.create <scf::IfOp>(
166
+ loc, hasFrom, [&](OpBuilder &builder, Location loc) {
167
+ offsets[dim] = upperHalo ? OpFoldResult (upperOffset)
168
+ : OpFoldResult (builder.getIndexAttr (0 ));
169
+ builder.create <mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
170
+ auto subview = builder.create <memref::SubViewOp>(
171
+ loc, array, offsets, shape, strides);
172
+ builder.create <memref::CopyOp>(loc, buffer, subview);
173
+ builder.create <scf::YieldOp>(loc);
174
+ });
175
+ rewriter.create <memref::DeallocOp>(loc, buffer);
176
+ offsets[dim] = orgOffset;
177
+ };
178
+
179
+ genSendRecv (dim, false );
180
+ genSendRecv (dim, true );
181
+
182
+ // prepare shape and offsets for next split dim
183
+ auto _haloSz =
184
+ rewriter
185
+ .create <arith::AddIOp>(loc, toValue (haloSizes[currHaloDim * 2 ]),
186
+ toValue (haloSizes[currHaloDim * 2 + 1 ]))
187
+ .getResult ();
188
+ // the shape for next halo excludes the halo on both ends for the
189
+ // current dim
190
+ shape[dim] =
191
+ rewriter.create <arith::SubIOp>(loc, toValue (orgDimSize), _haloSz)
192
+ .getResult ();
193
+ // the offsets for next halo starts after the down halo for the
194
+ // current dim
195
+ offsets[dim] = haloSizes[currHaloDim * 2 ];
196
+ // on to next halo
197
+ ++currHaloDim;
198
+ }
199
+ rewriter.eraseOp (op);
200
+ return mlir::success ();
201
+ }
202
+ };
203
+
37
204
struct ConvertMeshToMPIPass
38
205
: public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
39
206
using Base::Base;
40
207
41
208
// / Run the dialect converter on the module.
42
209
void runOnOperation () override {
43
- getOperation ()->walk ([&](UpdateHaloOp op) {
44
- SymbolTableCollection symbolTableCollection;
45
- OpBuilder builder (op);
46
- auto loc = op.getLoc ();
47
-
48
- auto toValue = [&builder, &loc](OpFoldResult &v) {
49
- return v.is <Value>()
50
- ? v.get <Value>()
51
- : builder.create <::mlir::arith::ConstantOp>(
52
- loc,
53
- builder.getIndexAttr (
54
- cast<IntegerAttr>(v.get <Attribute>()).getInt ()));
55
- };
210
+ auto *ctx = &getContext ();
211
+ mlir::RewritePatternSet patterns (ctx);
56
212
57
- auto array = op.getInput ();
58
- auto rank = array.getType ().getRank ();
59
- auto mesh = op.getMesh ();
60
- auto meshOp = getMesh (op, symbolTableCollection);
61
- auto haloSizes = getMixedValues (op.getStaticHaloSizes (),
62
- op.getDynamicHaloSizes (), builder);
63
- for (auto &sz : haloSizes) {
64
- if (sz.is <Value>()) {
65
- sz = builder
66
- .create <arith::IndexCastOp>(loc, builder.getIndexType (),
67
- sz.get <Value>())
68
- .getResult ();
69
- }
70
- }
71
-
72
- SmallVector<OpFoldResult> offsets (rank, builder.getIndexAttr (0 ));
73
- SmallVector<OpFoldResult> strides (rank, builder.getIndexAttr (1 ));
74
- SmallVector<OpFoldResult> shape (rank);
75
- for (auto [i, s] : llvm::enumerate (array.getType ().getShape ())) {
76
- if (ShapedType::isDynamic (s)) {
77
- shape[i] = builder.create <memref::DimOp>(loc, array, s).getResult ();
78
- } else {
79
- shape[i] = builder.getIndexAttr (s);
80
- }
81
- }
213
+ patterns.insert <ConvertUpdateHaloOp>(ctx);
82
214
83
- auto tagAttr = builder.getI32IntegerAttr (91 ); // whatever
84
- auto tag = builder.create <::mlir::arith::ConstantOp>(loc, tagAttr);
85
- auto zeroAttr = builder.getI32IntegerAttr (0 ); // whatever
86
- auto zero = builder.create <::mlir::arith::ConstantOp>(loc, zeroAttr);
87
- SmallVector<Type> indexResultTypes (meshOp.getShape ().size (),
88
- builder.getIndexType ());
89
- auto myMultiIndex =
90
- builder.create <ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
91
- .getResult ();
92
- auto currHaloDim = 0 ;
93
-
94
- for (auto [dim, splitAxes] : llvm::enumerate (op.getSplitAxes ())) {
95
- if (!splitAxes.empty ()) {
96
- auto tmp = builder
97
- .create <NeighborsLinearIndicesOp>(
98
- loc, mesh, myMultiIndex, splitAxes)
99
- .getResults ();
100
- Value neighbourIDs[2 ] = {builder.create <arith::IndexCastOp>(
101
- loc, builder.getI32Type (), tmp[0 ]),
102
- builder.create <arith::IndexCastOp>(
103
- loc, builder.getI32Type (), tmp[1 ])};
104
- auto orgDimSize = shape[dim];
105
- auto upperOffset = builder.create <arith::SubIOp>(
106
- loc, toValue (shape[dim]), toValue (haloSizes[dim * 2 + 1 ]));
107
-
108
- // make sure we send/recv in a way that does not lead to a dead-lock
109
- // This is by far not optimal, this should be at least MPI_sendrecv
110
- // and - probably even more importantly - buffers should be re-used
111
- // Currently using temporary, contiguous buffer for MPI communication
112
- auto genSendRecv = [&](auto dim, bool upperHalo) {
113
- auto orgOffset = offsets[dim];
114
- shape[dim] =
115
- upperHalo ? haloSizes[dim * 2 + 1 ] : haloSizes[dim * 2 ];
116
- auto to = upperHalo ? neighbourIDs[1 ] : neighbourIDs[0 ];
117
- auto from = upperHalo ? neighbourIDs[0 ] : neighbourIDs[1 ];
118
- auto hasFrom = builder.create <arith::CmpIOp>(
119
- loc, arith::CmpIPredicate::sge, from, zero);
120
- auto hasTo = builder.create <arith::CmpIOp>(
121
- loc, arith::CmpIPredicate::sge, to, zero);
122
- auto buffer = builder.create <memref::AllocOp>(
123
- loc, shape, array.getType ().getElementType ());
124
- builder.create <scf::IfOp>(
125
- loc, hasTo, [&](OpBuilder &builder, Location loc) {
126
- offsets[dim] = upperHalo
127
- ? OpFoldResult (builder.getIndexAttr (0 ))
128
- : OpFoldResult (upperOffset);
129
- auto subview = builder.create <memref::SubViewOp>(
130
- loc, array, offsets, shape, strides);
131
- builder.create <memref::CopyOp>(loc, subview, buffer);
132
- builder.create <mpi::SendOp>(loc, TypeRange{}, buffer, tag,
133
- to);
134
- builder.create <scf::YieldOp>(loc);
135
- });
136
- builder.create <scf::IfOp>(
137
- loc, hasFrom, [&](OpBuilder &builder, Location loc) {
138
- offsets[dim] = upperHalo
139
- ? OpFoldResult (upperOffset)
140
- : OpFoldResult (builder.getIndexAttr (0 ));
141
- builder.create <mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
142
- from);
143
- auto subview = builder.create <memref::SubViewOp>(
144
- loc, array, offsets, shape, strides);
145
- builder.create <memref::CopyOp>(loc, buffer, subview);
146
- builder.create <scf::YieldOp>(loc);
147
- });
148
- builder.create <memref::DeallocOp>(loc, buffer);
149
- offsets[dim] = orgOffset;
150
- };
151
-
152
- genSendRecv (dim, false );
153
- genSendRecv (dim, true );
154
-
155
- shape[dim] = builder
156
- .create <arith::SubIOp>(
157
- loc, toValue (orgDimSize),
158
- builder
159
- .create <arith::AddIOp>(
160
- loc, toValue (haloSizes[dim * 2 ]),
161
- toValue (haloSizes[dim * 2 + 1 ]))
162
- .getResult ())
163
- .getResult ();
164
- offsets[dim] = haloSizes[dim * 2 ];
165
- ++currHaloDim;
166
- }
167
- }
168
- });
215
+ (void )mlir::applyPatternsAndFoldGreedily (getOperation (),
216
+ std::move (patterns));
169
217
}
170
218
};
171
- } // namespace
219
+
220
+ } // namespace
221
+
222
+ // Create a pass that convert Mesh to MPI
223
+ std::unique_ptr<::mlir::OperationPass<void >> createConvertMeshToMPIPass () {
224
+ return std::make_unique<ConvertMeshToMPIPass>();
225
+ }
0 commit comments