Skip to content

[mlir][mpi] Mandatory Communicator #133280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 86 additions & 46 deletions mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,41 @@ def MPI_InitOp : MPI_Op<"init", []> {
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
}

//===----------------------------------------------------------------------===//
// CommWorldOp
//===----------------------------------------------------------------------===//

def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
let description = [{
This operation returns the predefined MPI_COMM_WORLD communicator.
}];

let results = (outs MPI_Comm : $comm);

let assemblyFormat = "attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// CommRankOp
//===----------------------------------------------------------------------===//

def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Comm : $comm);

let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $rank
);

let assemblyFormat = "attr-dict `:` type(results)";
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
Expand All @@ -65,20 +80,48 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {

def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Comm : $comm);

let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);

let assemblyFormat = "attr-dict `:` type(results)";
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// CommSplitOp
//===----------------------------------------------------------------------===//

def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{
This operation splits the communicator into multiple sub-communicators.
The color value determines the group of processes that will be part of the
new communicator. The key value determines the rank of the calling process
in the new communicator.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);

let results = (
outs Optional<MPI_Retval> : $retval,
MPI_Comm : $newcomm
);

let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
"type(results)";
}

//===----------------------------------------------------------------------===//
Expand All @@ -87,27 +130,26 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {

def MPI_SendOp : MPI_Op<"send", []> {
let summary =
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $dest
I32 : $dest,
MPI_Comm : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($dest)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
Expand All @@ -119,32 +161,31 @@ def MPI_SendOp : MPI_Op<"send", []> {

def MPI_ISendOp : MPI_Op<"isend", []> {
let summary =
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank
I32 : $dest,
MPI_Comm : $comm
);

let results = (
outs Optional<MPI_Retval>:$retval,
MPI_Request : $req
);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($dest) "
"`->` type(results)";
let hasCanonicalizer = 1;
}
Expand All @@ -155,14 +196,13 @@ def MPI_ISendOp : MPI_Op<"isend", []> {

def MPI_RecvOp : MPI_Op<"recv", []> {
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
"comm, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supported for now.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

Expand All @@ -172,13 +212,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag, I32 : $source
I32 : $tag, I32 : $source,
MPI_Comm : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($source)"
let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm `)` attr-dict"
" `:` type($ref) `,` type($tag) `,` type($source) "
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
Expand All @@ -188,34 +229,33 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
//===----------------------------------------------------------------------===//

def MPI_IRecvOp : MPI_Op<"irecv", []> {
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
"MPI_COMM_WORLD, &req)`";
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
"comm, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank
I32 : $source,
MPI_Comm : $comm
);

let results = (
outs Optional<MPI_Retval>:$retval,
MPI_Request : $req
);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
"type($ref) `,` type($tag) `,` type($rank) `->`"
"type(results)";
let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($source)"
"`->` type(results)";
let hasCanonicalizer = 1;
}

Expand All @@ -224,8 +264,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
//===----------------------------------------------------------------------===//

def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
"MPI_COMM_WORLD)`";
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
let description = [{
MPI_Allreduce performs a reduction operation on the values in the sendbuf
array and stores the result in the recvbuf array. The operation is
Expand All @@ -235,22 +274,21 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
MPI_OpClassEnum : $op
MPI_OpClassEnum : $op,
MPI_Comm : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
"type($sendbuf) `,` type($recvbuf)"
let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
"attr-dict `:` type($sendbuf) `,` type($recvbuf) "
"(`->` type($retval)^)?";
}

Expand All @@ -259,20 +297,23 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
//===----------------------------------------------------------------------===//

def MPI_Barrier : MPI_Op<"barrier", []> {
let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
let summary = "Equivalent to `MPI_Barrier(comm)`";
let description = [{
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Comm : $comm);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
let assemblyFormat = [{
`(` $comm `)` attr-dict
(`->` type($retval)^)?
}];
}

//===----------------------------------------------------------------------===//
Expand All @@ -295,8 +336,7 @@ def MPI_Wait : MPI_Op<"wait", []> {

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
"(`->` type($retval) ^)?";
let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?";
}

//===----------------------------------------------------------------------===//
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
}];
}

//===----------------------------------------------------------------------===//
// mpi::CommType
//===----------------------------------------------------------------------===//

def MPI_Comm : MPI_Type<"Comm", "comm"> {
let summary = "MPI communicator handler";
let description = [{
This type represents a handler for the MPI communicator.
}];
}

//===----------------------------------------------------------------------===//
// mpi::RequestType
//===----------------------------------------------------------------------===//
Expand Down
Loading