Skip to content

Commit 49f080a

Browse files
fschlimbmofeing
andauthored
[mlir][mpi] Mandatory Communicator (llvm#133280)
This is replacing llvm#125361 - communicator is mandatory - new mpi.comm_world - new mp.comm_split - lowering and test --------- Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
1 parent aa889ed commit 49f080a

File tree

7 files changed

+389
-175
lines changed

7 files changed

+389
-175
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,41 @@ def MPI_InitOp : MPI_Op<"init", []> {
3737
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
3838
}
3939

40+
//===----------------------------------------------------------------------===//
41+
// CommWorldOp
42+
//===----------------------------------------------------------------------===//
43+
44+
def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
45+
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
46+
let description = [{
47+
This operation returns the predefined MPI_COMM_WORLD communicator.
48+
}];
49+
50+
let results = (outs MPI_Comm : $comm);
51+
52+
let assemblyFormat = "attr-dict `:` type(results)";
53+
}
54+
4055
//===----------------------------------------------------------------------===//
4156
// CommRankOp
4257
//===----------------------------------------------------------------------===//
4358

4459
def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
4560
let summary = "Get the current rank, equivalent to "
46-
"`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
61+
"`MPI_Comm_rank(comm, &rank)`";
4762
let description = [{
48-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
49-
5063
This operation can optionally return an `!mpi.retval` value that can be used
5164
to check for errors.
5265
}];
5366

67+
let arguments = (ins MPI_Comm : $comm);
68+
5469
let results = (
5570
outs Optional<MPI_Retval> : $retval,
5671
I32 : $rank
5772
);
5873

59-
let assemblyFormat = "attr-dict `:` type(results)";
74+
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
6075
}
6176

6277
//===----------------------------------------------------------------------===//
@@ -65,20 +80,48 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
6580

6681
def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
6782
let summary = "Get the size of the group associated to the communicator, "
68-
"equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
83+
"equivalent to `MPI_Comm_size(comm, &size)`";
6984
let description = [{
70-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
71-
7285
This operation can optionally return an `!mpi.retval` value that can be used
7386
to check for errors.
7487
}];
7588

89+
let arguments = (ins MPI_Comm : $comm);
90+
7691
let results = (
7792
outs Optional<MPI_Retval> : $retval,
7893
I32 : $size
7994
);
8095

81-
let assemblyFormat = "attr-dict `:` type(results)";
96+
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
97+
}
98+
99+
//===----------------------------------------------------------------------===//
100+
// CommSplitOp
101+
//===----------------------------------------------------------------------===//
102+
103+
def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
104+
let summary = "Partition the group associated with the given communicator into "
105+
"disjoint subgroups";
106+
let description = [{
107+
This operation splits the communicator into multiple sub-communicators.
108+
The color value determines the group of processes that will be part of the
109+
new communicator. The key value determines the rank of the calling process
110+
in the new communicator.
111+
112+
This operation can optionally return an `!mpi.retval` value that can be used
113+
to check for errors.
114+
}];
115+
116+
let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
117+
118+
let results = (
119+
outs Optional<MPI_Retval> : $retval,
120+
MPI_Comm : $newcomm
121+
);
122+
123+
let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
124+
"type(results)";
82125
}
83126

84127
//===----------------------------------------------------------------------===//
@@ -87,27 +130,26 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
87130

88131
def MPI_SendOp : MPI_Op<"send", []> {
89132
let summary =
90-
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
133+
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
91134
let description = [{
92135
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
93136
`dest`. The `tag` value and communicator enables the library to determine
94137
the matching of multiple sends and receives between the same ranks.
95138

96-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
97-
98139
This operation can optionally return an `!mpi.retval` value that can be used
99140
to check for errors.
100141
}];
101142

102143
let arguments = (
103144
ins AnyMemRef : $ref,
104145
I32 : $tag,
105-
I32 : $dest
146+
I32 : $dest,
147+
MPI_Comm : $comm
106148
);
107149

108150
let results = (outs Optional<MPI_Retval>:$retval);
109151

110-
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
152+
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm `)` attr-dict `:` "
111153
"type($ref) `,` type($tag) `,` type($dest)"
112154
"(`->` type($retval)^)?";
113155
let hasCanonicalizer = 1;
@@ -119,32 +161,31 @@ def MPI_SendOp : MPI_Op<"send", []> {
119161

120162
def MPI_ISendOp : MPI_Op<"isend", []> {
121163
let summary =
122-
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
164+
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
123165
let description = [{
124166
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
125167
rank `dest`. The `tag` value and communicator enables the library to
126168
determine the matching of multiple sends and receives between the same
127169
ranks.
128170

129-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
130-
131171
This operation can optionally return an `!mpi.retval` value that can be used
132172
to check for errors.
133173
}];
134174

135175
let arguments = (
136176
ins AnyMemRef : $ref,
137177
I32 : $tag,
138-
I32 : $rank
178+
I32 : $dest,
179+
MPI_Comm : $comm
139180
);
140181

141182
let results = (
142183
outs Optional<MPI_Retval>:$retval,
143184
MPI_Request : $req
144185
);
145186

146-
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
147-
"`:` type($ref) `,` type($tag) `,` type($rank) "
187+
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm`)` attr-dict "
188+
"`:` type($ref) `,` type($tag) `,` type($dest) "
148189
"`->` type(results)";
149190
let hasCanonicalizer = 1;
150191
}
@@ -155,14 +196,13 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
155196

156197
def MPI_RecvOp : MPI_Op<"recv", []> {
157198
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
158-
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
199+
"comm, MPI_STATUS_IGNORE)`";
159200
let description = [{
160201
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
161202
from rank `source`. The `tag` value and communicator enables the library to
162203
determine the matching of multiple sends and receives between the same
163204
ranks.
164205

165-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
166206
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
167207
is not yet ported to MLIR.
168208

@@ -172,13 +212,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
172212

173213
let arguments = (
174214
ins AnyMemRef : $ref,
175-
I32 : $tag, I32 : $source
215+
I32 : $tag, I32 : $source,
216+
MPI_Comm : $comm
176217
);
177218

178219
let results = (outs Optional<MPI_Retval>:$retval);
179220

180-
let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
181-
"type($ref) `,` type($tag) `,` type($source)"
221+
let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm `)` attr-dict"
222+
" `:` type($ref) `,` type($tag) `,` type($source) "
182223
"(`->` type($retval)^)?";
183224
let hasCanonicalizer = 1;
184225
}
@@ -188,34 +229,33 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
188229
//===----------------------------------------------------------------------===//
189230

190231
def MPI_IRecvOp : MPI_Op<"irecv", []> {
191-
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
192-
"MPI_COMM_WORLD, &req)`";
232+
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
233+
"comm, &req)`";
193234
let description = [{
194235
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
195-
from rank `dest`. The `tag` value and communicator enables the library to
236+
from rank `source`. The `tag` value and communicator enables the library to
196237
determine the matching of multiple sends and receives between the same
197238
ranks.
198239

199-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
200-
201240
This operation can optionally return an `!mpi.retval` value that can be used
202241
to check for errors.
203242
}];
204243

205244
let arguments = (
206245
ins AnyMemRef : $ref,
207246
I32 : $tag,
208-
I32 : $rank
247+
I32 : $source,
248+
MPI_Comm : $comm
209249
);
210250

211251
let results = (
212252
outs Optional<MPI_Retval>:$retval,
213253
MPI_Request : $req
214254
);
215255

216-
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
217-
"type($ref) `,` type($tag) `,` type($rank) `->`"
218-
"type(results)";
256+
let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm`)` attr-dict "
257+
"`:` type($ref) `,` type($tag) `,` type($source)"
258+
"`->` type(results)";
219259
let hasCanonicalizer = 1;
220260
}
221261

@@ -224,8 +264,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
224264
//===----------------------------------------------------------------------===//
225265

226266
def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
227-
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
228-
"MPI_COMM_WORLD)`";
267+
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
229268
let description = [{
230269
MPI_Allreduce performs a reduction operation on the values in the sendbuf
231270
array and stores the result in the recvbuf array. The operation is
@@ -235,22 +274,21 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
235274
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
236275
supported.
237276

238-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
239-
240277
This operation can optionally return an `!mpi.retval` value that can be used
241278
to check for errors.
242279
}];
243280

244281
let arguments = (
245282
ins AnyMemRef : $sendbuf,
246283
AnyMemRef : $recvbuf,
247-
MPI_OpClassEnum : $op
284+
MPI_OpClassEnum : $op,
285+
MPI_Comm : $comm
248286
);
249287

250288
let results = (outs Optional<MPI_Retval>:$retval);
251289

252-
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
253-
"type($sendbuf) `,` type($recvbuf)"
290+
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
291+
"attr-dict `:` type($sendbuf) `,` type($recvbuf) "
254292
"(`->` type($retval)^)?";
255293
}
256294

@@ -259,20 +297,23 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
259297
//===----------------------------------------------------------------------===//
260298

261299
def MPI_Barrier : MPI_Op<"barrier", []> {
262-
let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
300+
let summary = "Equivalent to `MPI_Barrier(comm)`";
263301
let description = [{
264302
MPI_Barrier blocks execution until all processes in the communicator have
265303
reached this routine.
266304

267-
Communicators other than `MPI_COMM_WORLD` are not supported for now.
268-
269305
This operation can optionally return an `!mpi.retval` value that can be used
270306
to check for errors.
271307
}];
272308

309+
let arguments = (ins MPI_Comm : $comm);
310+
273311
let results = (outs Optional<MPI_Retval>:$retval);
274312

275-
let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
313+
let assemblyFormat = [{
314+
`(` $comm `)` attr-dict
315+
(`->` type($retval)^)?
316+
}];
276317
}
277318

278319
//===----------------------------------------------------------------------===//
@@ -295,8 +336,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
295336

296337
let results = (outs Optional<MPI_Retval>:$retval);
297338

298-
let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
299-
"(`->` type($retval) ^)?";
339+
let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?";
300340
}
301341

302342
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MPI/IR/MPITypes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
4040
}];
4141
}
4242

43+
//===----------------------------------------------------------------------===//
44+
// mpi::CommType
45+
//===----------------------------------------------------------------------===//
46+
47+
def MPI_Comm : MPI_Type<"Comm", "comm"> {
48+
let summary = "MPI communicator handler";
49+
let description = [{
50+
This type represents a handler for the MPI communicator.
51+
}];
52+
}
53+
4354
//===----------------------------------------------------------------------===//
4455
// mpi::RequestType
4556
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)