Skip to content

Commit 6464e20

Browse files
committed
MPI implementation selection at runtime
1 parent 7706fdf commit 6464e20

File tree

3 files changed

+192
-81
lines changed

3 files changed

+192
-81
lines changed

mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMPIToLLVM
1111
Core
1212

1313
LINK_LIBS PUBLIC
14+
MLIRDLTIDialect
1415
MLIRLLVMCommonConversion
1516
MLIRLLVMDialect
1617
MLIRMPIDialect

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
1616
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1717
#include "mlir/Conversion/LLVMCommon/Pattern.h"
18+
#include "mlir/Dialect/DLTI/DLTI.h"
1819
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1920
#include "mlir/Dialect/MPI/IR/MPI.h"
2021
#include "mlir/Transforms/DialectConversion.h"
@@ -184,29 +185,53 @@ struct OMPIImplTraits {
184185
//===----------------------------------------------------------------------===//
185186
// When lowering the mpi dialect to functions calls certain details
186187
// differ between various MPI implementations. This class will provide
187-
// these in a gnereic way, depending on the MPI implementation that got
188-
// included.
188+
// these in a generic way, depending on the MPI implementation that got
189+
// selected by the DLTI attribute on the module.
189190
//===----------------------------------------------------------------------===//
190191
struct MPIImplTraits {
192+
enum MPIImpl { MPICH, OMPI };
193+
194+
// Get the MPI implementation from a DLTI attribute on the module.
195+
// Default to MPICH (and ABI compatible).
196+
static MPIImpl getMPIImpl(mlir::ModuleOp &moduleOp) {
197+
auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
198+
if (failed(attr)) {
199+
return MPICH;
200+
}
201+
auto strAttr = dyn_cast<StringAttr>(attr.value());
202+
if (strAttr && strAttr.getValue() == "OpenMPI") {
203+
return OMPI;
204+
}
205+
return MPICH;
206+
}
207+
191208
// get/create MPI_COMM_WORLD as a mlir::Value
192209
static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
193210
const mlir::Location loc,
194211
mlir::ConversionPatternRewriter &rewriter) {
195-
// TODO: dispatch based on the MPI implementation
212+
if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
213+
return OMPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
214+
}
196215
return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
197216
}
217+
198218
// Get the MPI_STATUS_IGNORE value (typically a pointer type).
199-
static intptr_t getStatusIgnore() {
200-
// TODO: dispatch based on the MPI implementation
219+
static intptr_t getStatusIgnore(mlir::ModuleOp &moduleOp) {
220+
if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
221+
return OMPIImplTraits::getStatusIgnore();
222+
}
201223
return MPICHImplTraits::getStatusIgnore();
202224
}
225+
203226
// get/create MPI datatype as a mlir::Value which corresponds to the given
204227
// mlir::Type
205228
static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
206229
const mlir::Location loc,
207230
mlir::ConversionPatternRewriter &rewriter,
208231
mlir::Type type) {
209-
// TODO: dispatch based on the MPI implementation
232+
if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
233+
return OMPIImplTraits::getDataType(moduleOp, loc, rewriter, type);
234+
}
210235
return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
211236
}
212237
};
@@ -430,7 +455,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
430455
MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
431456
Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
432457
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
433-
loc, i64, MPIImplTraits::getStatusIgnore());
458+
loc, i64, MPIImplTraits::getStatusIgnore(moduleOp));
434459
statusIgnore =
435460
rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
436461

0 commit comments

Comments
 (0)