Skip to content

Commit 3a2fd9f

Browse files
committed
remove dependency on mpi.h; TODO: runtime dispatch
1 parent 255de64 commit 3a2fd9f

File tree

3 files changed

+157
-164
lines changed

3 files changed

+157
-164
lines changed
Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
find_path(MPI_C_HEADER_DIR mpi.h
2-
PATHS $ENV{I_MPI_ROOT}/include
3-
$ENV{MPI_HOME}/include
4-
$ENV{MPI_ROOT}/include)
5-
61
add_mlir_conversion_library(MLIRMPIToLLVM
72
MPIToLLVM.cpp
83

@@ -20,11 +15,3 @@ add_mlir_conversion_library(MLIRMPIToLLVM
2015
MLIRLLVMDialect
2116
MLIRMPIDialect
2217
)
23-
24-
if(MPI_C_HEADER_DIR)
25-
message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
26-
target_include_directories(obj.MLIRMPIToLLVM PRIVATE ${MPI_C_HEADER_DIR})
27-
target_compile_definitions(obj.MLIRMPIToLLVM PUBLIC FOUND_MPI_C_HEADER=1)
28-
else()
29-
message(WARNING "MPI not found, falling back to definitions from MPICH")
30-
endif()

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 154 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,11 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#define MPICH_SKIP_MPICXX 1
10-
#define OMPI_SKIP_MPICXX 1
11-
#ifdef FOUND_MPI_C_HEADER
12-
// This must go first (MPI gets confused otherwise)
13-
#include <mpi.h>
14-
#else // not FOUND_MPI_C_HEADER
159
//
1610
// Copyright (C) by Argonne National Laboratory
1711
// See COPYRIGHT in top-level directory
1812
// of MPICH source repository.
1913
//
20-
typedef int MPI_Comm;
21-
#define MPI_COMM_WORLD ((MPI_Comm)0x44000000)
22-
23-
typedef int MPI_Datatype;
24-
#define MPI_FLOAT ((MPI_Datatype)0x4c00040a)
25-
#define MPI_DOUBLE ((MPI_Datatype)0x4c00080b)
26-
#define MPI_INT8_T ((MPI_Datatype)0x4c000137)
27-
#define MPI_INT16_T ((MPI_Datatype)0x4c000238)
28-
#define MPI_INT32_T ((MPI_Datatype)0x4c000439)
29-
#define MPI_INT64_T ((MPI_Datatype)0x4c00083a)
30-
#define MPI_UINT8_T ((MPI_Datatype)0x4c00013b)
31-
#define MPI_UINT16_T ((MPI_Datatype)0x4c00023c)
32-
#define MPI_UINT32_T ((MPI_Datatype)0x4c00043d)
33-
#define MPI_UINT64_T ((MPI_Datatype)0x4c00083e)
34-
35-
typedef struct MPI_Status;
36-
#define MPI_STATUS_IGNORE (MPI_Status *)1
37-
38-
#define _MPI_FALLBACK_DEFS 1
39-
#endif // FOUND_MPI_C_HEADER
4014

4115
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
4216
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -71,143 +45,172 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
7145
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
7246
}
7347

74-
// ****************************************************************************
48+
//===----------------------------------------------------------------------===//
49+
// Implementation details for MPICH ABI compatible MPI implementations
50+
//===----------------------------------------------------------------------===//
51+
struct MPICHImplTraits {
52+
static const int MPI_FLOAT = 0x4c00040a;
53+
static const int MPI_DOUBLE = 0x4c00080b;
54+
static const int MPI_INT8_T = 0x4c000137;
55+
static const int MPI_INT16_T = 0x4c000238;
56+
static const int MPI_INT32_T = 0x4c000439;
57+
static const int MPI_INT64_T = 0x4c00083a;
58+
static const int MPI_UINT8_T = 0x4c00013b;
59+
static const int MPI_UINT16_T = 0x4c00023c;
60+
static const int MPI_UINT32_T = 0x4c00043d;
61+
static const int MPI_UINT64_T = 0x4c00083e;
62+
63+
static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
64+
const mlir::Location loc,
65+
mlir::ConversionPatternRewriter &rewriter) {
66+
static const int MPI_COMM_WORLD = 0x44000000;
67+
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
68+
MPI_COMM_WORLD);
69+
}
70+
71+
static intptr_t getStatusIgnore() { return 1; }
72+
73+
static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
74+
const mlir::Location loc,
75+
mlir::ConversionPatternRewriter &rewriter,
76+
mlir::Type type) {
77+
int32_t mtype = 0;
78+
if (type.isF32())
79+
mtype = MPI_FLOAT;
80+
else if (type.isF64())
81+
mtype = MPI_DOUBLE;
82+
else if (type.isInteger(64) && !type.isUnsignedInteger())
83+
mtype = MPI_INT64_T;
84+
else if (type.isInteger(64))
85+
mtype = MPI_UINT64_T;
86+
else if (type.isInteger(32) && !type.isUnsignedInteger())
87+
mtype = MPI_INT32_T;
88+
else if (type.isInteger(32))
89+
mtype = MPI_UINT32_T;
90+
else if (type.isInteger(16) && !type.isUnsignedInteger())
91+
mtype = MPI_INT16_T;
92+
else if (type.isInteger(16))
93+
mtype = MPI_UINT16_T;
94+
else if (type.isInteger(8) && !type.isUnsignedInteger())
95+
mtype = MPI_INT8_T;
96+
else if (type.isInteger(8))
97+
mtype = MPI_UINT8_T;
98+
else
99+
assert(false && "unsupported type");
100+
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
101+
mtype);
102+
}
103+
};
104+
105+
//===----------------------------------------------------------------------===//
106+
// Implementation details for OpenMPI
107+
//===----------------------------------------------------------------------===//
108+
struct OMPIImplTraits {
109+
110+
static mlir::LLVM::GlobalOp
111+
getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
112+
mlir::ConversionPatternRewriter &rewriter,
113+
mlir::StringRef name,
114+
mlir::LLVM::LLVMStructType type) {
115+
116+
return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
117+
moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
118+
mlir::LLVM::Linkage::External, name,
119+
/*value=*/mlir::Attribute(), /*alignment=*/0, 0);
120+
}
121+
122+
static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
123+
const mlir::Location loc,
124+
mlir::ConversionPatternRewriter &rewriter) {
125+
auto context = rewriter.getContext();
126+
// get external opaque struct pointer type
127+
auto commStructT =
128+
mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
129+
mlir::StringRef name = "ompi_mpi_comm_world";
130+
131+
// make sure global op definition exists
132+
(void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
133+
134+
// get address of symbol
135+
return rewriter.create<mlir::LLVM::AddressOfOp>(
136+
loc, mlir::LLVM::LLVMPointerType::get(context),
137+
mlir::SymbolRefAttr::get(context, name));
138+
}
139+
140+
static intptr_t getStatusIgnore() { return 0; }
141+
142+
static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
143+
const mlir::Location loc,
144+
mlir::ConversionPatternRewriter &rewriter,
145+
mlir::Type type) {
146+
mlir::StringRef mtype;
147+
if (type.isF32())
148+
mtype = "ompi_mpi_float";
149+
else if (type.isF64())
150+
mtype = "ompi_mpi_double";
151+
else if (type.isInteger(64) && !type.isUnsignedInteger())
152+
mtype = "ompi_mpi_int64_t";
153+
else if (type.isInteger(64))
154+
mtype = "ompi_mpi_uint64_t";
155+
else if (type.isInteger(32) && !type.isUnsignedInteger())
156+
mtype = "ompi_mpi_int32_t";
157+
else if (type.isInteger(32))
158+
mtype = "ompi_mpi_uint32_t";
159+
else if (type.isInteger(16) && !type.isUnsignedInteger())
160+
mtype = "ompi_mpi_int16_t";
161+
else if (type.isInteger(16))
162+
mtype = "ompi_mpi_uint16_t";
163+
else if (type.isInteger(8) && !type.isUnsignedInteger())
164+
mtype = "ompi_mpi_int8_t";
165+
else if (type.isInteger(8))
166+
mtype = "ompi_mpi_uint8_t";
167+
else
168+
assert(false && "unsupported type");
169+
170+
auto context = rewriter.getContext();
171+
// get external opaque struct pointer type
172+
auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
173+
"ompi_predefined_datatype_t", context);
174+
// make sure global op definition exists
175+
(void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype,
176+
commStructT);
177+
// get address of symbol
178+
return rewriter.create<mlir::LLVM::AddressOfOp>(
179+
loc, mlir::LLVM::LLVMPointerType::get(context),
180+
mlir::SymbolRefAttr::get(context, mtype));
181+
}
182+
};
183+
184+
//===----------------------------------------------------------------------===//
75185
// When lowering the mpi dialect to functions calls certain details
76186
// differ between various MPI implementations. This class will provide
77-
// these depending on the MPI implementation that got included.
187+
// these in a gnereic way, depending on the MPI implementation that got
188+
// included.
189+
//===----------------------------------------------------------------------===//
78190
struct MPIImplTraits {
79191
// get/create MPI_COMM_WORLD as a mlir::Value
80192
static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
81193
const mlir::Location loc,
82-
mlir::ConversionPatternRewriter &rewriter);
194+
mlir::ConversionPatternRewriter &rewriter) {
195+
// TODO: dispatch based on the MPI implementation
196+
return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
197+
}
198+
// Get the MPI_STATUS_IGNORE value (typically a pointer type).
199+
static intptr_t getStatusIgnore() {
200+
// TODO: dispatch based on the MPI implementation
201+
return MPICHImplTraits::getStatusIgnore();
202+
}
83203
// get/create MPI datatype as a mlir::Value which corresponds to the given
84204
// mlir::Type
85205
static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
86206
const mlir::Location loc,
87207
mlir::ConversionPatternRewriter &rewriter,
88-
mlir::Type type);
208+
mlir::Type type) {
209+
// TODO: dispatch based on the MPI implementation
210+
return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
211+
}
89212
};
90213

91-
// ****************************************************************************
92-
// Intel MPI/MPICH
93-
#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
94-
95-
mlir::Value
96-
MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
97-
mlir::ConversionPatternRewriter &rewriter) {
98-
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
99-
MPI_COMM_WORLD);
100-
}
101-
102-
mlir::Value
103-
MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
104-
mlir::ConversionPatternRewriter &rewriter,
105-
mlir::Type type) {
106-
int32_t mtype = 0;
107-
if (type.isF32())
108-
mtype = MPI_FLOAT;
109-
else if (type.isF64())
110-
mtype = MPI_DOUBLE;
111-
else if (type.isInteger(64) && !type.isUnsignedInteger())
112-
mtype = MPI_INT64_T;
113-
else if (type.isInteger(64))
114-
mtype = MPI_UINT64_T;
115-
else if (type.isInteger(32) && !type.isUnsignedInteger())
116-
mtype = MPI_INT32_T;
117-
else if (type.isInteger(32))
118-
mtype = MPI_UINT32_T;
119-
else if (type.isInteger(16) && !type.isUnsignedInteger())
120-
mtype = MPI_INT16_T;
121-
else if (type.isInteger(16))
122-
mtype = MPI_UINT16_T;
123-
else if (type.isInteger(8) && !type.isUnsignedInteger())
124-
mtype = MPI_INT8_T;
125-
else if (type.isInteger(8))
126-
mtype = MPI_UINT8_T;
127-
else
128-
assert(false && "unsupported type");
129-
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
130-
mtype);
131-
}
132-
133-
// ****************************************************************************
134-
// OpenMPI
135-
#elif defined(OPEN_MPI) && OPEN_MPI == 1
136-
137-
static mlir::LLVM::GlobalOp
138-
getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
139-
mlir::ConversionPatternRewriter &rewriter,
140-
mlir::StringRef name,
141-
mlir::LLVM::LLVMStructType type) {
142-
143-
return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
144-
moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
145-
mlir::LLVM::Linkage::External, name,
146-
/*value=*/mlir::Attribute(), /*alignment=*/0, 0);
147-
}
148-
149-
mlir::Value
150-
MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
151-
mlir::ConversionPatternRewriter &rewriter) {
152-
auto context = rewriter.getContext();
153-
// get external opaque struct pointer type
154-
auto commStructT =
155-
mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
156-
mlir::StringRef name = "ompi_mpi_comm_world";
157-
158-
// make sure global op definition exists
159-
(void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
160-
161-
// get address of symbol
162-
return rewriter.create<mlir::LLVM::AddressOfOp>(
163-
loc, mlir::LLVM::LLVMPointerType::get(context),
164-
mlir::SymbolRefAttr::get(context, name));
165-
}
166-
167-
mlir::Value
168-
MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
169-
mlir::ConversionPatternRewriter &rewriter,
170-
mlir::Type type) {
171-
mlir::StringRef mtype = nullptr;
172-
if (type.isF32())
173-
mtype = "ompi_mpi_float";
174-
else if (type.isF64())
175-
mtype = "ompi_mpi_double";
176-
else if (type.isInteger(64) && !type.isUnsignedInteger())
177-
mtype = "ompi_mpi_int64_t";
178-
else if (type.isInteger(64))
179-
mtype = "ompi_mpi_uint64_t";
180-
else if (type.isInteger(32) && !type.isUnsignedInteger())
181-
mtype = "ompi_mpi_int32_t";
182-
else if (type.isInteger(32))
183-
mtype = "ompi_mpi_uint32_t";
184-
else if (type.isInteger(16) && !type.isUnsignedInteger())
185-
mtype = "ompi_mpi_int16_t";
186-
else if (type.isInteger(16))
187-
mtype = "ompi_mpi_uint16_t";
188-
else if (type.isInteger(8) && !type.isUnsignedInteger())
189-
mtype = "ompi_mpi_int8_t";
190-
else if (type.isInteger(8))
191-
mtype = "ompi_mpi_uint8_t";
192-
else
193-
assert(false && "unsupported type");
194-
195-
auto context = rewriter.getContext();
196-
// get external opaque struct pointer type
197-
auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
198-
"ompi_predefined_datatype_t", context);
199-
// make sure global op definition exists
200-
(void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, commStructT);
201-
// get address of symbol
202-
return rewriter.create<mlir::LLVM::AddressOfOp>(
203-
loc, mlir::LLVM::LLVMPointerType::get(context),
204-
mlir::SymbolRefAttr::get(context, mtype));
205-
}
206-
207-
#else
208-
#error "Unsupported MPI implementation"
209-
#endif
210-
211214
//===----------------------------------------------------------------------===//
212215
// InitOpLowering
213216
//===----------------------------------------------------------------------===//
@@ -427,7 +430,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
427430
MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
428431
Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
429432
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
430-
loc, i64, reinterpret_cast<int64_t>(MPI_STATUS_IGNORE));
433+
loc, i64, MPIImplTraits::getStatusIgnore());
431434
statusIgnore =
432435
rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
433436

mlir/test/Conversion/MPIToLLVM/ops.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
22

3+
module attributes { mpi.dlti = #dlti.map<"MPI:Implemention" = "Intel"> } {
4+
35
// CHECK: llvm.func @MPI_Finalize() -> i32
46
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}, !llvm.ptr) -> i32
57
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}) -> i32
@@ -75,3 +77,4 @@ func.func @mpi_test(%arg0: memref<100xf32>) {
7577

7678
return
7779
}
80+
}

0 commit comments

Comments
 (0)