Skip to content

Commit 2003e17

Browse files
committed
using i64 as intermediate type for !mpi.comm and appropriate casting
1 parent 46a19b8 commit 2003e17

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,17 @@ class MPIImplTraits {
6767
ModuleOp &getModuleOp() { return moduleOp; }
6868

6969
/// Gets or creates MPI_COMM_WORLD as a Value.
70+
/// Different MPI implementations have different types for communicator.
71+
/// Using i64 as a portable, intermediate type.
72+
/// Appropriate cast needs to take place before calling MPI functions.
7073
virtual Value getCommWorld(const Location loc,
7174
ConversionPatternRewriter &rewriter) = 0;
7275

76+
/// The type converter maps the communicator type to i64.
77+
/// Converts it to the implementation's native type, e.g. a pointer or an integer.
78+
virtual Value castComm(const Location loc,
79+
ConversionPatternRewriter &rewriter, Value comm) = 0;
80+
7381
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
7482
virtual intptr_t getStatusIgnore() = 0;
7583

@@ -103,10 +111,15 @@ class MPICHImplTraits : public MPIImplTraits {
103111
Value getCommWorld(const Location loc,
104112
ConversionPatternRewriter &rewriter) override {
105113
static constexpr int MPI_COMM_WORLD = 0x44000000;
106-
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
114+
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
107115
MPI_COMM_WORLD);
108116
}
109117

118+
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
119+
Value comm) override {
120+
return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
121+
}
122+
110123
intptr_t getStatusIgnore() override { return 1; }
111124

112125
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
@@ -170,9 +183,16 @@ class OMPIImplTraits : public MPIImplTraits {
170183
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
171184

172185
// get address of symbol
173-
return rewriter.create<LLVM::AddressOfOp>(
186+
auto comm = rewriter.create<LLVM::AddressOfOp>(
174187
loc, LLVM::LLVMPointerType::get(context),
175188
SymbolRefAttr::get(context, name));
189+
return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
190+
}
191+
192+
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
193+
Value comm) override {
194+
return rewriter.create<LLVM::IntToPtrOp>(
195+
loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
176196
}
177197

178198
intptr_t getStatusIgnore() override { return 0; }
@@ -338,7 +358,7 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
338358

339359
auto mpiTraits = MPIImplTraits::get(moduleOp);
340360
// get communicator
341-
Value comm = adaptor.getComm();
361+
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
342362

343363
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
344364
auto rankFuncType =
@@ -406,7 +426,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
406426
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
407427
auto mpiTraits = MPIImplTraits::get(moduleOp);
408428
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
409-
Value comm = adaptor.getComm();
429+
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
410430

411431
// LLVM Function type representing `i32 MPI_Send(data, count, datatype, dst,
412432
// tag, comm)`
@@ -465,7 +485,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
465485
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
466486
auto mpiTraits = MPIImplTraits::get(moduleOp);
467487
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
468-
Value comm = adaptor.getComm();
488+
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
469489
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
470490
loc, i64, mpiTraits->getStatusIgnore());
471491
statusIgnore =
@@ -517,10 +537,12 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
517537

518538
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
519539
RewritePatternSet &patterns) {
520-
// FIXME: Need tldi info to get mpi implementation to know the Communicator
521-
// type
522-
Type commType = IntegerType::get(&converter.getContext(), 32);
523-
converter.addConversion([&](mpi::CommType type) { return commType; });
540+
// Using i64 as a portable, intermediate type for !mpi.comm.
541+
// It would be nicer to somehow get the right type directly, but DLTI is not
542+
// available here.
543+
converter.addConversion([](mpi::CommType type) {
544+
return IntegerType::get(type.getContext(), 64);
545+
});
524546
patterns.add<CommRankOpLowering, CommWorldOpLowering, FinalizeOpLowering,
525547
InitOpLowering, SendOpLowering, RecvOpLowering>(converter);
526548
}

mlir/test/Conversion/MPIToLLVM/ops.mlir

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
2222
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
2323
%0 = mpi.init : !mpi.retval
2424

25+
// CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
2526
%comm = mpi.comm_world : !mpi.comm
26-
// CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
27+
28+
// CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32
2729
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
2830
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
2931
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
@@ -36,7 +38,8 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
3638
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
3739
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
3840
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
39-
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v8]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
41+
// CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
42+
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
4043
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
4144

4245
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -45,7 +48,8 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
4548
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
4649
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
4750
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
48-
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v8]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
51+
// CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
52+
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
4953
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
5054

5155
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -54,9 +58,10 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
5458
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
5559
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
5660
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
61+
// CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
5762
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
5863
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
59-
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v8]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
64+
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
6065
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
6166

6267
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -65,9 +70,10 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
6570
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
6671
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
6772
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
73+
// CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
6874
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
6975
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
70-
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v8]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
76+
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
7177
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
7278

7379
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
@@ -105,9 +111,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
105111

106112
%comm = mpi.comm_world : !mpi.comm
107113
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
114+
// CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64
115+
// CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
108116
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
109117
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
110-
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
118+
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
111119
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
112120

113121
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
@@ -117,7 +125,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
117125
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
118126
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
119127
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
120-
// CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
128+
// CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
121129
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
122130
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
123131

@@ -127,7 +135,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
127135
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
128136
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
129137
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
130-
// CHECK: [[v27:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
138+
// CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
131139
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
132140
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
133141

@@ -137,7 +145,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
137145
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
138146
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
139147
// CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
140-
// CHECK: [[v35:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
148+
// CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
141149
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
142150
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
143151
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
@@ -149,7 +157,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
149157
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
150158
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
151159
// CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
152-
// CHECK: [[v45:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
160+
// CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
153161
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
154162
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
155163
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32

0 commit comments

Comments
 (0)