Skip to content

Commit 8aeb96d

Browse files
committed
implementing OMPIImplTraits::getMPIOp
1 parent 80179ce commit 8aeb96d

File tree

3 files changed

+82
-13
lines changed

3 files changed

+82
-13
lines changed

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,10 @@ class OMPIImplTraits : public MPIImplTraits {
291291

292292
auto context = rewriter.getContext();
293293
// get external opaque struct pointer type
294-
auto commStructT =
294+
auto typeStructT =
295295
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
296296
// make sure global op definition exists
297-
getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
297+
getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
298298
// get address of symbol
299299
return rewriter.create<LLVM::AddressOfOp>(
300300
loc, LLVM::LLVMPointerType::get(context),
@@ -303,8 +303,61 @@ class OMPIImplTraits : public MPIImplTraits {
303303

304304
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
305305
mpi::MPI_OpClassEnum opAttr) override {
306-
llvm_unreachable("getMPIOp not implemented for OpenMPI");
307-
return Value();
306+
StringRef op;
307+
switch (opAttr) {
308+
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
309+
op = "ompi_mpi_no_op";
310+
break;
311+
case mpi::MPI_OpClassEnum::MPI_MAX:
312+
op = "ompi_mpi_max";
313+
break;
314+
case mpi::MPI_OpClassEnum::MPI_MIN:
315+
op = "ompi_mpi_min";
316+
break;
317+
case mpi::MPI_OpClassEnum::MPI_SUM:
318+
op = "ompi_mpi_sum";
319+
break;
320+
case mpi::MPI_OpClassEnum::MPI_PROD:
321+
op = "ompi_mpi_prod";
322+
break;
323+
case mpi::MPI_OpClassEnum::MPI_LAND:
324+
op = "ompi_mpi_land";
325+
break;
326+
case mpi::MPI_OpClassEnum::MPI_BAND:
327+
op = "ompi_mpi_band";
328+
break;
329+
case mpi::MPI_OpClassEnum::MPI_LOR:
330+
op = "ompi_mpi_lor";
331+
break;
332+
case mpi::MPI_OpClassEnum::MPI_BOR:
333+
op = "ompi_mpi_bor";
334+
break;
335+
case mpi::MPI_OpClassEnum::MPI_LXOR:
336+
op = "ompi_mpi_lxor";
337+
break;
338+
case mpi::MPI_OpClassEnum::MPI_BXOR:
339+
op = "ompi_mpi_bxor";
340+
break;
341+
case mpi::MPI_OpClassEnum::MPI_MINLOC:
342+
op = "ompi_mpi_minloc";
343+
break;
344+
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
345+
op = "ompi_mpi_maxloc";
346+
break;
347+
case mpi::MPI_OpClassEnum::MPI_REPLACE:
348+
op = "ompi_mpi_replace";
349+
break;
350+
}
351+
auto context = rewriter.getContext();
352+
// get external opaque struct pointer type
353+
auto opStructT =
354+
LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
355+
// make sure global op definition exists
356+
getOrDefineExternalStruct(loc, rewriter, op, opStructT);
357+
// get address of symbol
358+
return rewriter.create<LLVM::AddressOfOp>(
359+
loc, LLVM::LLVMPointerType::get(context),
360+
SymbolRefAttr::get(context, op));
308361
}
309362
};
310363

mlir/test/Conversion/MPIToLLVM/ops.mlir renamed to mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
7272
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
7373
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
7474
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
75-
76-
// CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
77-
// CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
75+
76+
// CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
77+
// CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
7878
// CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
79-
// CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
79+
// CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
8080
// CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
81-
// CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
82-
// CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
81+
// CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
82+
// CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
8383
// CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
84-
// CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
84+
// CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
8585
// CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
8686
// CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
8787
// CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
@@ -172,6 +172,22 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
172172
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
173173
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
174174

175+
// CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
176+
// CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
177+
// CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
178+
// CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
179+
// CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
180+
// CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
181+
// CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
182+
// CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
183+
// CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
184+
// CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
185+
// CHECK-NEXT: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
186+
// CHECK-NEXT: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
187+
// CHECK-NEXT: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
188+
// CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
189+
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
190+
175191
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
176192
%3 = mpi.finalize : !mpi.retval
177193

mlir/test/Dialect/MPI/ops.mlir renamed to mlir/test/Dialect/MPI/mpiops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
4848
// CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
4949
%err7 = mpi.barrier : !mpi.retval
5050

51-
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
51+
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
5252
mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
5353

54-
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
54+
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
5555
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
5656

5757
// CHECK-NEXT: %7 = mpi.finalize : !mpi.retval

0 commit comments

Comments
 (0)