|
15 | 15 | #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
|
16 | 16 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
17 | 17 | #include "mlir/Conversion/LLVMCommon/Pattern.h"
|
| 18 | +#include "mlir/Dialect/DLTI/DLTI.h" |
18 | 19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
19 | 20 | #include "mlir/Dialect/MPI/IR/MPI.h"
|
20 | 21 | #include "mlir/Transforms/DialectConversion.h"
|
@@ -184,29 +185,53 @@ struct OMPIImplTraits {
|
184 | 185 | //===----------------------------------------------------------------------===//
|
185 | 186 | // When lowering the mpi dialect to functions calls certain details
|
186 | 187 | // differ between various MPI implementations. This class will provide
|
187 |
| -// these in a gnereic way, depending on the MPI implementation that got |
188 |
| -// included. |
| 188 | +// these in a generic way, depending on the MPI implementation that got |
| 189 | +// selected by the DLTI attribute on the module. |
189 | 190 | //===----------------------------------------------------------------------===//
|
190 | 191 | struct MPIImplTraits {
|
| 192 | + enum MPIImpl { MPICH, OMPI }; |
| 193 | + |
| 194 | + // Get the MPI implementation from a DLTI attribute on the module. |
| 195 | + // Default to MPICH (and ABI compatible). |
| 196 | + static MPIImpl getMPIImpl(mlir::ModuleOp &moduleOp) { |
| 197 | + auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true); |
| 198 | + if (failed(attr)) { |
| 199 | + return MPICH; |
| 200 | + } |
| 201 | + auto strAttr = dyn_cast<StringAttr>(attr.value()); |
| 202 | + if (strAttr && strAttr.getValue() == "OpenMPI") { |
| 203 | + return OMPI; |
| 204 | + } |
| 205 | + return MPICH; |
| 206 | + } |
| 207 | + |
191 | 208 | // get/create MPI_COMM_WORLD as a mlir::Value
|
192 | 209 | static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
|
193 | 210 | const mlir::Location loc,
|
194 | 211 | mlir::ConversionPatternRewriter &rewriter) {
|
195 |
| - // TODO: dispatch based on the MPI implementation |
| 212 | + if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) { |
| 213 | + return OMPIImplTraits::getCommWorld(moduleOp, loc, rewriter); |
| 214 | + } |
196 | 215 | return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
|
197 | 216 | }
|
| 217 | + |
198 | 218 | // Get the MPI_STATUS_IGNORE value (typically a pointer type).
|
199 |
| - static intptr_t getStatusIgnore() { |
200 |
| - // TODO: dispatch based on the MPI implementation |
| 219 | + static intptr_t getStatusIgnore(mlir::ModuleOp &moduleOp) { |
| 220 | + if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) { |
| 221 | + return OMPIImplTraits::getStatusIgnore(); |
| 222 | + } |
201 | 223 | return MPICHImplTraits::getStatusIgnore();
|
202 | 224 | }
|
| 225 | + |
203 | 226 | // get/create MPI datatype as a mlir::Value which corresponds to the given
|
204 | 227 | // mlir::Type
|
205 | 228 | static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
|
206 | 229 | const mlir::Location loc,
|
207 | 230 | mlir::ConversionPatternRewriter &rewriter,
|
208 | 231 | mlir::Type type) {
|
209 |
| - // TODO: dispatch based on the MPI implementation |
| 232 | + if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) { |
| 233 | + return OMPIImplTraits::getDataType(moduleOp, loc, rewriter, type); |
| 234 | + } |
210 | 235 | return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
|
211 | 236 | }
|
212 | 237 | };
|
@@ -430,7 +455,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
|
430 | 455 | MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
|
431 | 456 | Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
|
432 | 457 | Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
|
433 |
| - loc, i64, MPIImplTraits::getStatusIgnore()); |
| 458 | + loc, i64, MPIImplTraits::getStatusIgnore(moduleOp)); |
434 | 459 | statusIgnore =
|
435 | 460 | rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
|
436 | 461 |
|
|
0 commit comments