Skip to content

[MLIR] Extend MPI dialect #123255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
1fbae54
Add `MPI_Comm`, `MPI_Request`, `MPI_Status`, `MPI_Op` type definitions
mofeing Jan 16, 2025
dc84ca4
Add `MPI_CommSize`, `MPI_ISend`, `MPI_IRecv` ops
mofeing Jan 16, 2025
2ee10ab
Fix typo
mofeing Jan 16, 2025
539bf43
Finish types
mofeing Jan 25, 2025
662998d
Define `MPI_Op` enum & attr
mofeing Jan 26, 2025
c1ec63c
Add communicator argument to mpi ops as optional input argument
mofeing Jan 26, 2025
7eda791
Add summary of new mpi types
mofeing Jan 26, 2025
b97a541
format code
mofeing Jan 26, 2025
d5725a8
Add `mpi.comm_split` op
mofeing Jan 26, 2025
1a68b34
Add `mpi.barrier` op
mofeing Jan 26, 2025
80a4259
Format code
mofeing Jan 26, 2025
cfb81af
Fix ops returning `MPI_Request`
mofeing Jan 26, 2025
740cf0b
Add `mpi.wait` op
mofeing Jan 26, 2025
1af1425
Add `mpi.allreduce` op
mofeing Jan 26, 2025
c11a60f
Fix assembly formats
mofeing Jan 27, 2025
d971d83
add some tests
mofeing Jan 27, 2025
beb5764
Fix input specifier
mofeing Jan 27, 2025
2317994
Comment predefined constant MPI_Ops
mofeing Jan 27, 2025
63ccc33
Replace `MPI_Op` new type for region
mofeing Jan 27, 2025
d318c60
Go back to only use predefined MPI_Ops
mofeing Jan 28, 2025
8e3aa18
Remove `MPI_Operation` type
mofeing Jan 29, 2025
9c708d4
Add `mpi.comm_world` op to return `MPI_COMM_WORLD`
mofeing Jan 29, 2025
326b13f
Add tests
mofeing Jan 29, 2025
2baf33f
Merge branch 'main' into mlir-mpi
mofeing Jan 29, 2025
1fd5578
Fix anchor of assembly format
mofeing Jan 29, 2025
016b856
Fix more anchors
mofeing Jan 29, 2025
1931b8e
Fix anchors again
mofeing Jan 29, 2025
aec9fbd
fix another anchor
mofeing Jan 29, 2025
d4684fb
fix optional format of `MPI_BarrierOp`
mofeing Jan 29, 2025
794fa25
fix more anchors
mofeing Jan 29, 2025
f0d0f44
fix anchors in `MPI_ISendOp` and `MPI_IRecvOp`
mofeing Jan 29, 2025
92f2cca
fix format
mofeing Jan 29, 2025
3688915
Define `getCanonicalizationPatterns` for `ISendOp`, `IRecvOp`, `AllRe…
mofeing Jan 29, 2025
3abe925
remove duplicated `getCanonicalizationPatterns`
mofeing Jan 29, 2025
7a9fa9c
Remove canonicalization for `AllReduceOp`
mofeing Jan 29, 2025
1926bda
fix test
mofeing Jan 29, 2025
89ec111
fix some assembly formats
mofeing Jan 29, 2025
30fb673
fix syntax
mofeing Jan 29, 2025
6abba5a
Remove MPI_Comm type
mofeing Jan 29, 2025
452f760
fix tests
mofeing Jan 29, 2025
56868e8
change order of results of `MPI_CommRankOp`
mofeing Jan 29, 2025
b9988b3
format code
mofeing Jan 29, 2025
2075c02
format code
mofeing Jan 29, 2025
8477428
refactor assembly format of `isend`, `irecv` and fix tests
mofeing Jan 30, 2025
1259cbc
last fixes
mofeing Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/MPI/IR/MPI.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,43 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
let assemblyFormat = "`<` $value `>`";
}

def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;

def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpNull,
MPI_OpMax,
MPI_OpMin,
MPI_OpSum,
MPI_OpProd,
MPI_OpLand,
MPI_OpBand,
MPI_OpLor,
MPI_OpBor,
MPI_OpLxor,
MPI_OpBxor,
MPI_OpMinloc,
MPI_OpMaxloc,
MPI_OpReplace
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mpi";
}

def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // MLIR_DIALECT_MPI_IR_MPI_TD
195 changes: 186 additions & 9 deletions mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,28 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let assemblyFormat = "attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// CommSizeOp
//===----------------------------------------------------------------------===//

def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
let description = [{
Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);

let assemblyFormat = "attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// SendOp
//===----------------------------------------------------------------------===//
Expand All @@ -71,13 +93,17 @@ def MPI_SendOp : MPI_Op<"send", []> {
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.
Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank
);

let results = (outs Optional<MPI_Retval>:$retval);

Expand All @@ -87,6 +113,42 @@ def MPI_SendOp : MPI_Op<"send", []> {
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// ISendOp
//===----------------------------------------------------------------------===//

def MPI_ISendOp : MPI_Op<"isend", []> {
let summary =
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
let description = [{
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank
);

let results = (
outs Optional<MPI_Retval>:$retval,
MPI_Request : $req
);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
"`->` type(results)";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//
Expand All @@ -100,24 +162,142 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.
Communicators other than `MPI_COMM_WORLD` are not supported for now.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag, I32 : $rank
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
"type($ref) `,` type($tag) `,` type($rank)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// IRecvOp
//===----------------------------------------------------------------------===//

def MPI_IRecvOp : MPI_Op<"irecv", []> {
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
"MPI_COMM_WORLD, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank
);

let results = (
outs Optional<MPI_Retval>:$retval,
MPI_Request : $req
);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
"type($ref) `,` type($tag) `,` type($rank) `->`"
"type(results)";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
"MPI_COMM_WORLD)`";
let description = [{
MPI_Allreduce performs a reduction operation on the values in the sendbuf
array and stores the result in the recvbuf array. The operation is
performed across all processes in the communicator.

The `op` attribute specifies the reduction operation to be performed.
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
MPI_OpClassAttr : $op
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
"type($sendbuf) `,` type($recvbuf)"
"(`->` type($retval)^)?";
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

def MPI_Barrier : MPI_Op<"barrier", []> {
let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
let description = [{
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
}

//===----------------------------------------------------------------------===//
// WaitOp
//===----------------------------------------------------------------------===//

def MPI_Wait : MPI_Op<"wait", []> {
let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Wait blocks execution until the request has completed.

The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Request : $req);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
"(`->` type($retval) ^)?";
}

//===----------------------------------------------------------------------===//
// FinalizeOp
Expand All @@ -139,7 +319,6 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> {
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
}


//===----------------------------------------------------------------------===//
// RetvalCheckOp
//===----------------------------------------------------------------------===//
Expand All @@ -163,10 +342,8 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)";
}



//===----------------------------------------------------------------------===//
// RetvalCheckOp
// ErrorClassOp
//===----------------------------------------------------------------------===//

def MPI_ErrorClassOp : MPI_Op<"error_class", []> {
Expand Down
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,26 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
}];
}

//===----------------------------------------------------------------------===//
// mpi::RequestType
//===----------------------------------------------------------------------===//

def MPI_Request : MPI_Type<"Request", "request"> {
let summary = "MPI asynchronous request handler";
let description = [{
This type represents a handler to an asynchronous request.
}];
}

//===----------------------------------------------------------------------===//
// mpi::StatusType
//===----------------------------------------------------------------------===//

def MPI_Status : MPI_Type<"Status", "status"> {
let summary = "MPI reception operation status type";
let description = [{
This type represents the status of a reception operation.
}];
}

#endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/MPI/IR/MPIOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ void mlir::mpi::RecvOp::getCanonicalizationPatterns(
results.add<FoldCast<mlir::mpi::RecvOp>>(context);
}

void mlir::mpi::ISendOp::getCanonicalizationPatterns(
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
results.add<FoldCast<mlir::mpi::ISendOp>>(context);
}

void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
Loading