Skip to content

Commit 174d629

Browse files
committed
merge conflicts, fixing mpi.all_reduce
1 parent 9f23e95 commit 174d629

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
175175
let arguments = (
176176
ins AnyMemRef : $ref,
177177
I32 : $tag,
178-
I32 : $rank,
178+
I32 : $dest,
179179
MPI_Comm : $comm
180180
);
181181

@@ -184,8 +184,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
184184
MPI_Request : $req
185185
);
186186

187-
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
188-
"`:` type($ref) `,` type($tag) `,` type($rank) "
187+
let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm`)` attr-dict "
188+
"`:` type($ref) `,` type($tag) `,` type($dest) "
189189
"`->` type(results)";
190190
let hasCanonicalizer = 1;
191191
}
@@ -229,11 +229,11 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
229229
//===----------------------------------------------------------------------===//
230230

231231
def MPI_IRecvOp : MPI_Op<"irecv", []> {
232-
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
232+
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
233233
"comm, &req)`";
234234
let description = [{
235235
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
236-
from rank `dest`. The `tag` value and communicator enables the library to
236+
from rank `source`. The `tag` value and communicator enables the library to
237237
determine the matching of multiple sends and receives between the same
238238
ranks.
239239

@@ -244,7 +244,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
244244
let arguments = (
245245
ins AnyMemRef : $ref,
246246
I32 : $tag,
247-
I32 : $rank,
247+
I32 : $source,
248248
MPI_Comm : $comm
249249
);
250250

@@ -253,8 +253,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
253253
MPI_Request : $req
254254
);
255255

256-
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
257-
"`:` type($ref) `,` type($tag) `,` type($rank)"
256+
let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm`)` attr-dict "
257+
"`:` type($ref) `,` type($tag) `,` type($source)"
258258
"`->` type(results)";
259259
let hasCanonicalizer = 1;
260260
}
@@ -281,7 +281,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
281281
let arguments = (
282282
ins AnyMemRef : $sendbuf,
283283
AnyMemRef : $recvbuf,
284-
MPI_OpClassAttr : $op,
284+
MPI_OpClassEnum : $op,
285285
MPI_Comm : $comm
286286
);
287287

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
721721
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
722722
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
723723
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
724-
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
724+
Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
725+
725726
// 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
726727
// MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
727728
auto funcType = LLVM::LLVMFunctionType::get(

mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
100100
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
101101
// CHECK: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
102102
// CHECK: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
103-
// CHECK: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
103+
// CHECK: [[v61:%.*]] = llvm.trunc [[comm]] : i64 to i32
104104
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
105-
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
105+
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
106106

107107
// CHECK: llvm.call @MPI_Finalize() : () -> i32
108108
%3 = mpi.finalize : !mpi.retval
@@ -204,9 +204,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
204204
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
205205
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
206206
// CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
207-
// CHECK: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
207+
// CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
208208
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
209-
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
209+
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
210210

211211
// CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
212212
%color = arith.constant 10 : i32

0 commit comments

Comments
 (0)