|
6 | 6 | //
|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
9 |
| -#define MPICH_SKIP_MPICXX 1 |
10 |
| -#define OMPI_SKIP_MPICXX 1 |
11 |
| -#ifdef FOUND_MPI_C_HEADER |
12 |
| -// This must go first (MPI gets confused otherwise) |
13 |
| -#include <mpi.h> |
14 |
| -#else // not FOUND_MPI_C_HEADER |
15 | 9 | //
|
16 | 10 | // Copyright (C) by Argonne National Laboratory
|
17 | 11 | // See COPYRIGHT in top-level directory
|
18 | 12 | // of MPICH source repository.
|
19 | 13 | //
|
20 |
| -typedef int MPI_Comm; |
21 |
| -#define MPI_COMM_WORLD ((MPI_Comm)0x44000000) |
22 |
| - |
23 |
| -typedef int MPI_Datatype; |
24 |
| -#define MPI_FLOAT ((MPI_Datatype)0x4c00040a) |
25 |
| -#define MPI_DOUBLE ((MPI_Datatype)0x4c00080b) |
26 |
| -#define MPI_INT8_T ((MPI_Datatype)0x4c000137) |
27 |
| -#define MPI_INT16_T ((MPI_Datatype)0x4c000238) |
28 |
| -#define MPI_INT32_T ((MPI_Datatype)0x4c000439) |
29 |
| -#define MPI_INT64_T ((MPI_Datatype)0x4c00083a) |
30 |
| -#define MPI_UINT8_T ((MPI_Datatype)0x4c00013b) |
31 |
| -#define MPI_UINT16_T ((MPI_Datatype)0x4c00023c) |
32 |
| -#define MPI_UINT32_T ((MPI_Datatype)0x4c00043d) |
33 |
| -#define MPI_UINT64_T ((MPI_Datatype)0x4c00083e) |
34 |
| - |
35 |
| -typedef struct MPI_Status; |
36 |
| -#define MPI_STATUS_IGNORE (MPI_Status *)1 |
37 |
| - |
38 |
| -#define _MPI_FALLBACK_DEFS 1 |
39 |
| -#endif // FOUND_MPI_C_HEADER |
40 | 14 |
|
41 | 15 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
42 | 16 | #include "mlir/Conversion/LLVMCommon/Pattern.h"
|
@@ -71,143 +45,172 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
|
71 | 45 | moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
|
72 | 46 | }
|
73 | 47 |
|
74 |
| -// **************************************************************************** |
| 48 | +//===----------------------------------------------------------------------===// |
| 49 | +// Implementation details for MPICH ABI compatible MPI implementations |
| 50 | +//===----------------------------------------------------------------------===// |
| 51 | +struct MPICHImplTraits { |
| 52 | + static const int MPI_FLOAT = 0x4c00040a; |
| 53 | + static const int MPI_DOUBLE = 0x4c00080b; |
| 54 | + static const int MPI_INT8_T = 0x4c000137; |
| 55 | + static const int MPI_INT16_T = 0x4c000238; |
| 56 | + static const int MPI_INT32_T = 0x4c000439; |
| 57 | + static const int MPI_INT64_T = 0x4c00083a; |
| 58 | + static const int MPI_UINT8_T = 0x4c00013b; |
| 59 | + static const int MPI_UINT16_T = 0x4c00023c; |
| 60 | + static const int MPI_UINT32_T = 0x4c00043d; |
| 61 | + static const int MPI_UINT64_T = 0x4c00083e; |
| 62 | + |
| 63 | + static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp, |
| 64 | + const mlir::Location loc, |
| 65 | + mlir::ConversionPatternRewriter &rewriter) { |
| 66 | + static const int MPI_COMM_WORLD = 0x44000000; |
| 67 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
| 68 | + MPI_COMM_WORLD); |
| 69 | + } |
| 70 | + |
| 71 | + static intptr_t getStatusIgnore() { return 1; } |
| 72 | + |
| 73 | + static mlir::Value getDataType(mlir::ModuleOp &moduleOp, |
| 74 | + const mlir::Location loc, |
| 75 | + mlir::ConversionPatternRewriter &rewriter, |
| 76 | + mlir::Type type) { |
| 77 | + int32_t mtype = 0; |
| 78 | + if (type.isF32()) |
| 79 | + mtype = MPI_FLOAT; |
| 80 | + else if (type.isF64()) |
| 81 | + mtype = MPI_DOUBLE; |
| 82 | + else if (type.isInteger(64) && !type.isUnsignedInteger()) |
| 83 | + mtype = MPI_INT64_T; |
| 84 | + else if (type.isInteger(64)) |
| 85 | + mtype = MPI_UINT64_T; |
| 86 | + else if (type.isInteger(32) && !type.isUnsignedInteger()) |
| 87 | + mtype = MPI_INT32_T; |
| 88 | + else if (type.isInteger(32)) |
| 89 | + mtype = MPI_UINT32_T; |
| 90 | + else if (type.isInteger(16) && !type.isUnsignedInteger()) |
| 91 | + mtype = MPI_INT16_T; |
| 92 | + else if (type.isInteger(16)) |
| 93 | + mtype = MPI_UINT16_T; |
| 94 | + else if (type.isInteger(8) && !type.isUnsignedInteger()) |
| 95 | + mtype = MPI_INT8_T; |
| 96 | + else if (type.isInteger(8)) |
| 97 | + mtype = MPI_UINT8_T; |
| 98 | + else |
| 99 | + assert(false && "unsupported type"); |
| 100 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
| 101 | + mtype); |
| 102 | + } |
| 103 | +}; |
| 104 | + |
| 105 | +//===----------------------------------------------------------------------===// |
| 106 | +// Implementation details for OpenMPI |
| 107 | +//===----------------------------------------------------------------------===// |
| 108 | +struct OMPIImplTraits { |
| 109 | + |
| 110 | + static mlir::LLVM::GlobalOp |
| 111 | + getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc, |
| 112 | + mlir::ConversionPatternRewriter &rewriter, |
| 113 | + mlir::StringRef name, |
| 114 | + mlir::LLVM::LLVMStructType type) { |
| 115 | + |
| 116 | + return getOrDefineGlobal<mlir::LLVM::GlobalOp>( |
| 117 | + moduleOp, loc, rewriter, name, type, /*isConstant=*/false, |
| 118 | + mlir::LLVM::Linkage::External, name, |
| 119 | + /*value=*/mlir::Attribute(), /*alignment=*/0, 0); |
| 120 | + } |
| 121 | + |
| 122 | + static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp, |
| 123 | + const mlir::Location loc, |
| 124 | + mlir::ConversionPatternRewriter &rewriter) { |
| 125 | + auto context = rewriter.getContext(); |
| 126 | + // get external opaque struct pointer type |
| 127 | + auto commStructT = |
| 128 | + mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context); |
| 129 | + mlir::StringRef name = "ompi_mpi_comm_world"; |
| 130 | + |
| 131 | + // make sure global op definition exists |
| 132 | + (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT); |
| 133 | + |
| 134 | + // get address of symbol |
| 135 | + return rewriter.create<mlir::LLVM::AddressOfOp>( |
| 136 | + loc, mlir::LLVM::LLVMPointerType::get(context), |
| 137 | + mlir::SymbolRefAttr::get(context, name)); |
| 138 | + } |
| 139 | + |
| 140 | + static intptr_t getStatusIgnore() { return 0; } |
| 141 | + |
| 142 | + static mlir::Value getDataType(mlir::ModuleOp &moduleOp, |
| 143 | + const mlir::Location loc, |
| 144 | + mlir::ConversionPatternRewriter &rewriter, |
| 145 | + mlir::Type type) { |
| 146 | + mlir::StringRef mtype; |
| 147 | + if (type.isF32()) |
| 148 | + mtype = "ompi_mpi_float"; |
| 149 | + else if (type.isF64()) |
| 150 | + mtype = "ompi_mpi_double"; |
| 151 | + else if (type.isInteger(64) && !type.isUnsignedInteger()) |
| 152 | + mtype = "ompi_mpi_int64_t"; |
| 153 | + else if (type.isInteger(64)) |
| 154 | + mtype = "ompi_mpi_uint64_t"; |
| 155 | + else if (type.isInteger(32) && !type.isUnsignedInteger()) |
| 156 | + mtype = "ompi_mpi_int32_t"; |
| 157 | + else if (type.isInteger(32)) |
| 158 | + mtype = "ompi_mpi_uint32_t"; |
| 159 | + else if (type.isInteger(16) && !type.isUnsignedInteger()) |
| 160 | + mtype = "ompi_mpi_int16_t"; |
| 161 | + else if (type.isInteger(16)) |
| 162 | + mtype = "ompi_mpi_uint16_t"; |
| 163 | + else if (type.isInteger(8) && !type.isUnsignedInteger()) |
| 164 | + mtype = "ompi_mpi_int8_t"; |
| 165 | + else if (type.isInteger(8)) |
| 166 | + mtype = "ompi_mpi_uint8_t"; |
| 167 | + else |
| 168 | + assert(false && "unsupported type"); |
| 169 | + |
| 170 | + auto context = rewriter.getContext(); |
| 171 | + // get external opaque struct pointer type |
| 172 | + auto commStructT = mlir::LLVM::LLVMStructType::getOpaque( |
| 173 | + "ompi_predefined_datatype_t", context); |
| 174 | + // make sure global op definition exists |
| 175 | + (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, |
| 176 | + commStructT); |
| 177 | + // get address of symbol |
| 178 | + return rewriter.create<mlir::LLVM::AddressOfOp>( |
| 179 | + loc, mlir::LLVM::LLVMPointerType::get(context), |
| 180 | + mlir::SymbolRefAttr::get(context, mtype)); |
| 181 | + } |
| 182 | +}; |
| 183 | + |
| 184 | +//===----------------------------------------------------------------------===// |
75 | 185 | // When lowering the mpi dialect to functions calls certain details
|
76 | 186 | // differ between various MPI implementations. This class will provide
|
77 |
| -// these depending on the MPI implementation that got included. |
| 187 | +// these in a gnereic way, depending on the MPI implementation that got |
| 188 | +// included. |
| 189 | +//===----------------------------------------------------------------------===// |
78 | 190 | struct MPIImplTraits {
|
79 | 191 | // get/create MPI_COMM_WORLD as a mlir::Value
|
80 | 192 | static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
|
81 | 193 | const mlir::Location loc,
|
82 |
| - mlir::ConversionPatternRewriter &rewriter); |
| 194 | + mlir::ConversionPatternRewriter &rewriter) { |
| 195 | + // TODO: dispatch based on the MPI implementation |
| 196 | + return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter); |
| 197 | + } |
| 198 | + // Get the MPI_STATUS_IGNORE value (typically a pointer type). |
| 199 | + static intptr_t getStatusIgnore() { |
| 200 | + // TODO: dispatch based on the MPI implementation |
| 201 | + return MPICHImplTraits::getStatusIgnore(); |
| 202 | + } |
83 | 203 | // get/create MPI datatype as a mlir::Value which corresponds to the given
|
84 | 204 | // mlir::Type
|
85 | 205 | static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
|
86 | 206 | const mlir::Location loc,
|
87 | 207 | mlir::ConversionPatternRewriter &rewriter,
|
88 |
| - mlir::Type type); |
| 208 | + mlir::Type type) { |
| 209 | + // TODO: dispatch based on the MPI implementation |
| 210 | + return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type); |
| 211 | + } |
89 | 212 | };
|
90 | 213 |
|
91 |
| -// **************************************************************************** |
92 |
| -// Intel MPI/MPICH |
93 |
| -#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS) |
94 |
| - |
95 |
| -mlir::Value |
96 |
| -MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc, |
97 |
| - mlir::ConversionPatternRewriter &rewriter) { |
98 |
| - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
99 |
| - MPI_COMM_WORLD); |
100 |
| -} |
101 |
| - |
102 |
| -mlir::Value |
103 |
| -MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc, |
104 |
| - mlir::ConversionPatternRewriter &rewriter, |
105 |
| - mlir::Type type) { |
106 |
| - int32_t mtype = 0; |
107 |
| - if (type.isF32()) |
108 |
| - mtype = MPI_FLOAT; |
109 |
| - else if (type.isF64()) |
110 |
| - mtype = MPI_DOUBLE; |
111 |
| - else if (type.isInteger(64) && !type.isUnsignedInteger()) |
112 |
| - mtype = MPI_INT64_T; |
113 |
| - else if (type.isInteger(64)) |
114 |
| - mtype = MPI_UINT64_T; |
115 |
| - else if (type.isInteger(32) && !type.isUnsignedInteger()) |
116 |
| - mtype = MPI_INT32_T; |
117 |
| - else if (type.isInteger(32)) |
118 |
| - mtype = MPI_UINT32_T; |
119 |
| - else if (type.isInteger(16) && !type.isUnsignedInteger()) |
120 |
| - mtype = MPI_INT16_T; |
121 |
| - else if (type.isInteger(16)) |
122 |
| - mtype = MPI_UINT16_T; |
123 |
| - else if (type.isInteger(8) && !type.isUnsignedInteger()) |
124 |
| - mtype = MPI_INT8_T; |
125 |
| - else if (type.isInteger(8)) |
126 |
| - mtype = MPI_UINT8_T; |
127 |
| - else |
128 |
| - assert(false && "unsupported type"); |
129 |
| - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
130 |
| - mtype); |
131 |
| -} |
132 |
| - |
133 |
| -// **************************************************************************** |
134 |
| -// OpenMPI |
135 |
| -#elif defined(OPEN_MPI) && OPEN_MPI == 1 |
136 |
| - |
137 |
| -static mlir::LLVM::GlobalOp |
138 |
| -getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc, |
139 |
| - mlir::ConversionPatternRewriter &rewriter, |
140 |
| - mlir::StringRef name, |
141 |
| - mlir::LLVM::LLVMStructType type) { |
142 |
| - |
143 |
| - return getOrDefineGlobal<mlir::LLVM::GlobalOp>( |
144 |
| - moduleOp, loc, rewriter, name, type, /*isConstant=*/false, |
145 |
| - mlir::LLVM::Linkage::External, name, |
146 |
| - /*value=*/mlir::Attribute(), /*alignment=*/0, 0); |
147 |
| -} |
148 |
| - |
149 |
| -mlir::Value |
150 |
| -MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc, |
151 |
| - mlir::ConversionPatternRewriter &rewriter) { |
152 |
| - auto context = rewriter.getContext(); |
153 |
| - // get external opaque struct pointer type |
154 |
| - auto commStructT = |
155 |
| - mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context); |
156 |
| - mlir::StringRef name = "ompi_mpi_comm_world"; |
157 |
| - |
158 |
| - // make sure global op definition exists |
159 |
| - (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT); |
160 |
| - |
161 |
| - // get address of symbol |
162 |
| - return rewriter.create<mlir::LLVM::AddressOfOp>( |
163 |
| - loc, mlir::LLVM::LLVMPointerType::get(context), |
164 |
| - mlir::SymbolRefAttr::get(context, name)); |
165 |
| -} |
166 |
| - |
167 |
| -mlir::Value |
168 |
| -MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc, |
169 |
| - mlir::ConversionPatternRewriter &rewriter, |
170 |
| - mlir::Type type) { |
171 |
| - mlir::StringRef mtype = nullptr; |
172 |
| - if (type.isF32()) |
173 |
| - mtype = "ompi_mpi_float"; |
174 |
| - else if (type.isF64()) |
175 |
| - mtype = "ompi_mpi_double"; |
176 |
| - else if (type.isInteger(64) && !type.isUnsignedInteger()) |
177 |
| - mtype = "ompi_mpi_int64_t"; |
178 |
| - else if (type.isInteger(64)) |
179 |
| - mtype = "ompi_mpi_uint64_t"; |
180 |
| - else if (type.isInteger(32) && !type.isUnsignedInteger()) |
181 |
| - mtype = "ompi_mpi_int32_t"; |
182 |
| - else if (type.isInteger(32)) |
183 |
| - mtype = "ompi_mpi_uint32_t"; |
184 |
| - else if (type.isInteger(16) && !type.isUnsignedInteger()) |
185 |
| - mtype = "ompi_mpi_int16_t"; |
186 |
| - else if (type.isInteger(16)) |
187 |
| - mtype = "ompi_mpi_uint16_t"; |
188 |
| - else if (type.isInteger(8) && !type.isUnsignedInteger()) |
189 |
| - mtype = "ompi_mpi_int8_t"; |
190 |
| - else if (type.isInteger(8)) |
191 |
| - mtype = "ompi_mpi_uint8_t"; |
192 |
| - else |
193 |
| - assert(false && "unsupported type"); |
194 |
| - |
195 |
| - auto context = rewriter.getContext(); |
196 |
| - // get external opaque struct pointer type |
197 |
| - auto commStructT = mlir::LLVM::LLVMStructType::getOpaque( |
198 |
| - "ompi_predefined_datatype_t", context); |
199 |
| - // make sure global op definition exists |
200 |
| - (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, commStructT); |
201 |
| - // get address of symbol |
202 |
| - return rewriter.create<mlir::LLVM::AddressOfOp>( |
203 |
| - loc, mlir::LLVM::LLVMPointerType::get(context), |
204 |
| - mlir::SymbolRefAttr::get(context, mtype)); |
205 |
| -} |
206 |
| - |
207 |
| -#else |
208 |
| -#error "Unsupported MPI implementation" |
209 |
| -#endif |
210 |
| - |
211 | 214 | //===----------------------------------------------------------------------===//
|
212 | 215 | // InitOpLowering
|
213 | 216 | //===----------------------------------------------------------------------===//
|
@@ -427,7 +430,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
|
427 | 430 | MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
|
428 | 431 | Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
|
429 | 432 | Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
|
430 |
| - loc, i64, reinterpret_cast<int64_t>(MPI_STATUS_IGNORE)); |
| 433 | + loc, i64, MPIImplTraits::getStatusIgnore()); |
431 | 434 | statusIgnore =
|
432 | 435 | rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
|
433 | 436 |
|
|
0 commit comments