@@ -70,6 +70,7 @@ struct ConvertUpdateHaloOp
70
70
71
71
auto array = op.getInput ();
72
72
auto rank = array.getType ().getRank ();
73
+ auto opSplitAxes = op.getSplitAxes ().getAxes ();
73
74
auto mesh = op.getMesh ();
74
75
auto meshOp = getMesh (op, symbolTableCollection);
75
76
auto haloSizes = getMixedValues (op.getStaticHaloSizes (),
@@ -87,32 +88,54 @@ struct ConvertUpdateHaloOp
87
88
// most of the offset/size/stride data is the same for all dims
88
89
SmallVector<OpFoldResult> offsets (rank, rewriter.getIndexAttr (0 ));
89
90
SmallVector<OpFoldResult> strides (rank, rewriter.getIndexAttr (1 ));
90
- SmallVector<OpFoldResult> shape (rank);
91
+ SmallVector<OpFoldResult> shape (rank), dimSizes (rank);
92
+ auto currHaloDim = -1 ; // halo sizes are provided for split dimensions only
91
93
// we need the actual shape to compute offsets and sizes
92
- for (auto [i, s] : llvm::enumerate (array.getType ().getShape ())) {
94
+ for (auto i = 0 ; i < rank; ++i) {
95
+ auto s = array.getType ().getShape ()[i];
93
96
if (ShapedType::isDynamic (s)) {
94
97
shape[i] = rewriter.create <memref::DimOp>(loc, array, s).getResult ();
95
98
} else {
96
99
shape[i] = rewriter.getIndexAttr (s);
97
100
}
101
+
102
+ if ((size_t )i < opSplitAxes.size () && !opSplitAxes[i].empty ()) {
103
+ ++currHaloDim;
104
+ // the offsets for lower dims start after their down halo
105
+ offsets[i] = haloSizes[currHaloDim * 2 ];
106
+
107
+ // prepare shape and offsets of highest dim's halo exchange
108
+ auto _haloSz =
109
+ rewriter
110
+ .create <arith::AddIOp>(loc, toValue (haloSizes[currHaloDim * 2 ]),
111
+ toValue (haloSizes[currHaloDim * 2 + 1 ]))
112
+ .getResult ();
113
+ // the halo shape of lower dims excludes the halos
114
+ dimSizes[i] =
115
+ rewriter.create <arith::SubIOp>(loc, toValue (shape[i]), _haloSz)
116
+ .getResult ();
117
+ } else {
118
+ dimSizes[i] = shape[i];
119
+ }
98
120
}
99
121
100
122
auto tagAttr = rewriter.getI32IntegerAttr (91 ); // we just pick something
101
123
auto tag = rewriter.create <::mlir::arith::ConstantOp>(loc, tagAttr);
102
124
auto zeroAttr = rewriter.getI32IntegerAttr (0 ); // for detecting v<0
103
125
auto zero = rewriter.create <::mlir::arith::ConstantOp>(loc, zeroAttr);
126
+
104
127
SmallVector<Type> indexResultTypes (meshOp.getShape ().size (),
105
128
rewriter.getIndexType ());
106
129
auto myMultiIndex =
107
130
rewriter.create <ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
108
131
.getResult ();
109
- // halo sizes are provided for split dimensions only
110
- auto currHaloDim = 0 ;
111
-
112
- for (auto [dim, splitAxes] : llvm::enumerate (op.getSplitAxes ())) {
132
+ // traverse all split axes from high to low dim
133
+ for (ssize_t dim = opSplitAxes.size () - 1 ; dim >= 0 ; --dim) {
134
+ auto splitAxes = opSplitAxes[dim];
113
135
if (splitAxes.empty ()) {
114
136
continue ;
115
137
}
138
+ assert (currHaloDim >= 0 && (size_t )currHaloDim < haloSizes.size () / 2 );
116
139
// Get the linearized ids of the neighbors (down and up) for the
117
140
// given split
118
141
auto tmp = rewriter
@@ -124,22 +147,24 @@ struct ConvertUpdateHaloOp
124
147
loc, rewriter.getI32Type (), tmp[0 ]),
125
148
rewriter.create <arith::IndexCastOp>(
126
149
loc, rewriter.getI32Type (), tmp[1 ])};
127
- // store for later
128
- auto orgDimSize = shape[dim] ;
129
- // this dim's offset to the start of the upper halo
130
- auto upperOffset = rewriter.create <arith::SubIOp>(
150
+
151
+ auto lowerRecvOffset = rewriter. getIndexAttr ( 0 ) ;
152
+ auto lowerSendOffset = toValue (haloSizes[currHaloDim * 2 ]);
153
+ auto upperRecvOffset = rewriter.create <arith::SubIOp>(
131
154
loc, toValue (shape[dim]), toValue (haloSizes[currHaloDim * 2 + 1 ]));
155
+ auto upperSendOffset = rewriter.create <arith::SubIOp>(
156
+ loc, upperRecvOffset, toValue (haloSizes[currHaloDim * 2 ]));
132
157
133
158
// Make sure we send/recv in a way that does not lead to a dead-lock.
134
159
// The current approach is far from optimal; this should at least
135
160
// be a red-black pattern or using MPI_sendrecv.
136
161
// Also, buffers should be re-used.
137
162
// Still using temporary contiguous buffers for MPI communication...
138
163
// Still yielding a "serialized" communication pattern...
139
- auto genSendRecv = [&](auto dim, bool upperHalo) {
164
+ auto genSendRecv = [&](bool upperHalo) {
140
165
auto orgOffset = offsets[dim];
141
- shape [dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1 ]
142
- : haloSizes[currHaloDim * 2 ];
166
+ dimSizes [dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1 ]
167
+ : haloSizes[currHaloDim * 2 ];
143
168
// Check if we need to send and/or receive
144
169
// Processes on the mesh borders have only one neighbor
145
170
auto to = upperHalo ? neighbourIDs[1 ] : neighbourIDs[0 ];
@@ -149,52 +174,42 @@ struct ConvertUpdateHaloOp
149
174
auto hasTo = rewriter.create <arith::CmpIOp>(
150
175
loc, arith::CmpIPredicate::sge, to, zero);
151
176
auto buffer = rewriter.create <memref::AllocOp>(
152
- loc, shape , array.getType ().getElementType ());
177
+ loc, dimSizes , array.getType ().getElementType ());
153
178
// if has neighbor: copy halo data from array to buffer and send
154
179
rewriter.create <scf::IfOp>(
155
180
loc, hasTo, [&](OpBuilder &builder, Location loc) {
156
- offsets[dim] = upperHalo ? OpFoldResult (builder. getIndexAttr ( 0 ) )
157
- : OpFoldResult (upperOffset );
181
+ offsets[dim] = upperHalo ? OpFoldResult (lowerSendOffset )
182
+ : OpFoldResult (upperSendOffset );
158
183
auto subview = builder.create <memref::SubViewOp>(
159
- loc, array, offsets, shape , strides);
184
+ loc, array, offsets, dimSizes , strides);
160
185
builder.create <memref::CopyOp>(loc, subview, buffer);
161
186
builder.create <mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
162
187
builder.create <scf::YieldOp>(loc);
163
188
});
164
189
// if has neighbor: receive halo data into buffer and copy to array
165
190
rewriter.create <scf::IfOp>(
166
191
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
167
- offsets[dim] = upperHalo ? OpFoldResult (upperOffset )
168
- : OpFoldResult (builder. getIndexAttr ( 0 ) );
192
+ offsets[dim] = upperHalo ? OpFoldResult (upperRecvOffset )
193
+ : OpFoldResult (lowerRecvOffset );
169
194
builder.create <mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
170
195
auto subview = builder.create <memref::SubViewOp>(
171
- loc, array, offsets, shape , strides);
196
+ loc, array, offsets, dimSizes , strides);
172
197
builder.create <memref::CopyOp>(loc, buffer, subview);
173
198
builder.create <scf::YieldOp>(loc);
174
199
});
175
200
rewriter.create <memref::DeallocOp>(loc, buffer);
176
201
offsets[dim] = orgOffset;
177
202
};
178
203
179
- genSendRecv (dim, false );
180
- genSendRecv (dim, true );
181
-
182
- // prepare shape and offsets for next split dim
183
- auto _haloSz =
184
- rewriter
185
- .create <arith::AddIOp>(loc, toValue (haloSizes[currHaloDim * 2 ]),
186
- toValue (haloSizes[currHaloDim * 2 + 1 ]))
187
- .getResult ();
188
- // the shape for next halo excludes the halo on both ends for the
189
- // current dim
190
- shape[dim] =
191
- rewriter.create <arith::SubIOp>(loc, toValue (orgDimSize), _haloSz)
192
- .getResult ();
193
- // the offsets for next halo starts after the down halo for the
194
- // current dim
195
- offsets[dim] = haloSizes[currHaloDim * 2 ];
204
+ genSendRecv (false );
205
+ genSendRecv (true );
206
+
207
+ // the shape for lower dims includes higher dims' halos
208
+ dimSizes[dim] = shape[dim];
209
+ // -> the offset for higher dims is always 0
210
+ offsets[dim] = rewriter.getIndexAttr (0 );
196
211
// on to next halo
197
- ++ currHaloDim;
212
+ -- currHaloDim;
198
213
}
199
214
rewriter.eraseOp (op);
200
215
return mlir::success ();
0 commit comments