Skip to content

Commit d023e65

Browse files
committed
lowering mpi.comm_split
1 parent 42c99a9 commit d023e65

File tree

4 files changed

+86
-10
lines changed

4 files changed

+86
-10
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
100100
// CommSplitOp
101101
//===----------------------------------------------------------------------===//
102102

103-
def MPI_CommSplit : MPI_Op<"comm_split", []> {
103+
def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
104104
let summary = "Partition the group associated to the given communicator into "
105105
"disjoint subgroups";
106106
let description = [{
@@ -121,7 +121,6 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> {
121121
);
122122

123123
let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
124-
"type($color) `,` type($key) `->` "
125124
"type(results)";
126125
}
127126

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,58 @@ struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
480480
}
481481
};
482482

483+
//===----------------------------------------------------------------------===//
484+
// CommSplitOpLowering
485+
//===----------------------------------------------------------------------===//
486+
487+
struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
488+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
489+
490+
LogicalResult
491+
matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
492+
ConversionPatternRewriter &rewriter) const override {
493+
// grab a reference to the global module op:
494+
auto moduleOp = op->getParentOfType<ModuleOp>();
495+
auto mpiTraits = MPIImplTraits::get(moduleOp);
496+
Type i32 = rewriter.getI32Type();
497+
Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
498+
Location loc = op.getLoc();
499+
500+
// get communicator
501+
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
502+
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
503+
auto outPtr =
504+
rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
505+
506+
// int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
507+
auto funcType =
508+
LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
509+
// get or create function declaration:
510+
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
511+
"MPI_Comm_split", funcType);
512+
513+
auto callOp = rewriter.create<LLVM::CallOp>(
514+
loc, funcDecl,
515+
ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
516+
outPtr.getRes()});
517+
518+
// load the communicator into a register
519+
auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
520+
521+
// if retval is checked, replace uses of retval with the results from the
522+
// call op
523+
SmallVector<Value> replacements;
524+
if (op.getRetval())
525+
replacements.push_back(callOp.getResult());
526+
527+
// replace op
528+
replacements.push_back(res.getRes());
529+
rewriter.replaceOp(op, replacements);
530+
531+
return success();
532+
}
533+
};
534+
483535
//===----------------------------------------------------------------------===//
484536
// CommRankOpLowering
485537
//===----------------------------------------------------------------------===//
@@ -512,7 +564,7 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
512564
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
513565
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
514566

515-
// replace init with function call
567+
// replace with function call
516568
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
517569
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
518570
auto callOp = rewriter.create<LLVM::CallOp>(
@@ -722,9 +774,10 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
722774
converter.addConversion([](mpi::CommType type) {
723775
return IntegerType::get(type.getContext(), 64);
724776
});
725-
patterns.add<CommRankOpLowering, CommWorldOpLowering, FinalizeOpLowering,
726-
InitOpLowering, SendOpLowering, RecvOpLowering,
727-
AllReduceOpLowering>(converter);
777+
patterns
778+
.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
779+
FinalizeOpLowering, InitOpLowering, SendOpLowering, RecvOpLowering,
780+
AllReduceOpLowering>(converter);
728781
}
729782

730783
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// COM: Test MPICH ABI
44
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
55
// CHECK: llvm.func @MPI_Finalize() -> i32
6+
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
67
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
78
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
89
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
@@ -75,6 +76,17 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
7576
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
7677
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
7778
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
79+
80+
// CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
81+
%color = arith.constant 10 : i32
82+
// CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
83+
%key = arith.constant 22 : i32
84+
// CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32
85+
// CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
86+
// CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr
87+
// CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32
88+
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
89+
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
7890

7991
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
8092
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -104,6 +116,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
104116
// COM: Test OpenMPI ABI
105117
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
106118
// CHECK: llvm.func @MPI_Finalize() -> i32
119+
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
107120
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
108121
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
109122
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -195,6 +208,17 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
195208
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
196209
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
197210

211+
// CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
212+
%color = arith.constant 10 : i32
213+
// CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
214+
%key = arith.constant 22 : i32
215+
// CHECK: [[v53:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
216+
// CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
217+
// CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x !llvm.ptr : (i32) -> !llvm.ptr
218+
// CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
219+
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
220+
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
221+
198222
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
199223
%3 = mpi.finalize : !mpi.retval
200224

mlir/test/Dialect/MPI/mpiops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
2323
// CHECK-NEXT: [[vretval_1:%.*]], [[vsize_2:%.*]] = mpi.comm_size([[v1]]) : !mpi.retval, i32
2424
%retval_0, %size_1 = mpi.comm_size(%comm) : !mpi.retval, i32
2525

26-
// CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : i32, i32 -> !mpi.comm
27-
%new_comm = mpi.comm_split(%comm, %rank, %rank) : i32, i32 -> !mpi.comm
26+
// CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.comm
27+
%new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm
2828

29-
// CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : i32, i32 -> !mpi.retval, !mpi.comm
30-
%retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : i32, i32 -> !mpi.retval, !mpi.comm
29+
// CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.retval, !mpi.comm
30+
%retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, !mpi.comm
3131

3232
// CHECK-NEXT: mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
3333
mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32

0 commit comments

Comments
 (0)