[mlir][mpi] Lowering MPI_Allreduce #133133

Merged
5 commits, merged Mar 31, 2025
5 changes: 0 additions & 5 deletions mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -246,12 +246,7 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpMaxloc,
MPI_OpReplace
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mpi";
}

def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // MLIR_DIALECT_MPI_IR_MPI_TD
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
MPI_OpClassAttr : $op
MPI_OpClassEnum : $op
);

let results = (outs Optional<MPI_Retval>:$retval);
223 changes: 201 additions & 22 deletions mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
}

/// Extract the raw data pointer (aligned pointer advanced by the offset) and
/// the first-dimension size, truncated to i32, from a converted memref
/// descriptor.
std::pair<Value, Value> getRawPtrAndSize(const Location loc,
ConversionPatternRewriter &rewriter,
Value memRef, Type elType) {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value dataPtr =
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
Value offset = rewriter.create<LLVM::ExtractValueOp>(
loc, rewriter.getI64Type(), memRef, 2);
Value resPtr =
rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
ArrayRef<int64_t>{3, 0});
size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
return {resPtr, size};
}
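
// For orientation (an illustration, not part of this patch): the memref value
// handled above is the descriptor produced by the standard memref-to-LLVM
// lowering. For a 1-D memref of element type T it is laid out roughly as
//
//   struct MemRefDescriptor1D {
//     T *allocated;       // field 0: allocated pointer
//     T *aligned;         // field 1: aligned data pointer (-> dataPtr)
//     int64_t offset;     // field 2: offset in elements, applied via GEP
//     int64_t sizes[1];   // field {3, 0}: element count, truncated to i32
//     int64_t strides[1]; // field 4: strides in elements
//   };
//
// which is what the ExtractValueOp/GEPOp indices in getRawPtrAndSize refer to.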

/// When lowering the MPI dialect to function calls, certain details
/// differ between various MPI implementations. This class will provide
/// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
/// type.
virtual Value getDataType(const Location loc,
ConversionPatternRewriter &rewriter, Type type) = 0;

/// Gets or creates an MPI_Op value which corresponds to the given
/// enum value.
virtual Value getMPIOp(const Location loc,
ConversionPatternRewriter &rewriter,
mpi::MPI_OpClassEnum opAttr) = 0;
};
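
// Typical use from a lowering pattern (a sketch; it mirrors the patterns
// further down in this file):
//
//   auto mpiTraits = MPIImplTraits::get(moduleOp);
//   Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
//   Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
//   Value commWorld = mpiTraits->getCommWorld(loc, rewriter);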

//===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
static constexpr int MPI_UINT16_T = 0x4c00023c;
static constexpr int MPI_UINT32_T = 0x4c00043d;
static constexpr int MPI_UINT64_T = 0x4c00083e;
static constexpr int MPI_MAX = 0x58000001;
static constexpr int MPI_MIN = 0x58000002;
static constexpr int MPI_SUM = 0x58000003;
static constexpr int MPI_PROD = 0x58000004;
static constexpr int MPI_LAND = 0x58000005;
static constexpr int MPI_BAND = 0x58000006;
static constexpr int MPI_LOR = 0x58000007;
static constexpr int MPI_BOR = 0x58000008;
static constexpr int MPI_LXOR = 0x58000009;
static constexpr int MPI_BXOR = 0x5800000a;
static constexpr int MPI_MINLOC = 0x5800000b;
static constexpr int MPI_MAXLOC = 0x5800000c;
static constexpr int MPI_REPLACE = 0x5800000d;
static constexpr int MPI_NO_OP = 0x5800000e;
Contributor

what do you think about reusing the same values as in the incoming MPI 5.0 ABI?

https://github.com/mpi-forum/mpi-abi-stubs/blob/e89a80017a3fe9a05d903ced2564c6342d678165/mpi.h#L47-L62

Contributor Author

Are current MPI implementations supporting this yet?

Contributor

no, each implementation uses their own values. you can check some of them in here https://github.com/JuliaParallel/MPI.jl/tree/master/src/api

i guess that in the future they will change to the ones in MPI 5, that's why i suggested it, but it's not mandatory for now.

Contributor Author

Yes, there is agreement that we'll go for explicit specialization for now until the ABI is broadly available.

Member

@mofeing As I said already, MPICH supports the MPI-5 ABI already (you must enable it with configure, as it is not the default). I will sync Mukautuva to the final MPI-5 ABI in a few days.
#133280 (comment)

Contributor

That's great to know!

Contributor Author

Great. Once there are at least two popular implementations supporting it, we should switch to the ABI.


public:
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
assert(false && "unsupported type");
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
}

Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
mpi::MPI_OpClassEnum opAttr) override {
int32_t op = MPI_NO_OP;
switch (opAttr) {
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
op = MPI_NO_OP;
break;
case mpi::MPI_OpClassEnum::MPI_MAX:
op = MPI_MAX;
break;
case mpi::MPI_OpClassEnum::MPI_MIN:
op = MPI_MIN;
break;
case mpi::MPI_OpClassEnum::MPI_SUM:
op = MPI_SUM;
break;
case mpi::MPI_OpClassEnum::MPI_PROD:
op = MPI_PROD;
break;
case mpi::MPI_OpClassEnum::MPI_LAND:
op = MPI_LAND;
break;
case mpi::MPI_OpClassEnum::MPI_BAND:
op = MPI_BAND;
break;
case mpi::MPI_OpClassEnum::MPI_LOR:
op = MPI_LOR;
break;
case mpi::MPI_OpClassEnum::MPI_BOR:
op = MPI_BOR;
break;
case mpi::MPI_OpClassEnum::MPI_LXOR:
op = MPI_LXOR;
break;
case mpi::MPI_OpClassEnum::MPI_BXOR:
op = MPI_BXOR;
break;
case mpi::MPI_OpClassEnum::MPI_MINLOC:
op = MPI_MINLOC;
break;
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
op = MPI_MAXLOC;
break;
case mpi::MPI_OpClassEnum::MPI_REPLACE:
op = MPI_REPLACE;
break;
}
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
}
};

//===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {

auto context = rewriter.getContext();
// get external opaque struct pointer type
auto commStructT =
auto typeStructT =
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
// get address of symbol
return rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context),
SymbolRefAttr::get(context, mtype));
}

Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
mpi::MPI_OpClassEnum opAttr) override {
StringRef op;
switch (opAttr) {
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
op = "ompi_mpi_no_op";
break;
case mpi::MPI_OpClassEnum::MPI_MAX:
op = "ompi_mpi_max";
break;
case mpi::MPI_OpClassEnum::MPI_MIN:
op = "ompi_mpi_min";
break;
case mpi::MPI_OpClassEnum::MPI_SUM:
op = "ompi_mpi_sum";
break;
case mpi::MPI_OpClassEnum::MPI_PROD:
op = "ompi_mpi_prod";
break;
case mpi::MPI_OpClassEnum::MPI_LAND:
op = "ompi_mpi_land";
break;
case mpi::MPI_OpClassEnum::MPI_BAND:
op = "ompi_mpi_band";
break;
case mpi::MPI_OpClassEnum::MPI_LOR:
op = "ompi_mpi_lor";
break;
case mpi::MPI_OpClassEnum::MPI_BOR:
op = "ompi_mpi_bor";
break;
case mpi::MPI_OpClassEnum::MPI_LXOR:
op = "ompi_mpi_lxor";
break;
case mpi::MPI_OpClassEnum::MPI_BXOR:
op = "ompi_mpi_bxor";
break;
case mpi::MPI_OpClassEnum::MPI_MINLOC:
op = "ompi_mpi_minloc";
break;
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
op = "ompi_mpi_maxloc";
break;
case mpi::MPI_OpClassEnum::MPI_REPLACE:
op = "ompi_mpi_replace";
break;
}
auto context = rewriter.getContext();
// get external opaque struct pointer type
auto opStructT =
LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, op, opStructT);
// get address of symbol
return rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context),
SymbolRefAttr::get(context, op));
}
};
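
// Net effect of the two traits (illustration, not part of the patch): with
// MPICH an MPI_Op lowers to a plain i32 constant holding the implementation's
// handle value (e.g. 0x58000003 for MPI_SUM), while with Open MPI it lowers
// to the address of a predefined global symbol (e.g.
// `llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr`).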

std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();

// ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();

// get MPI_COMM_WORLD, dataType and pointer
Value dataPtr =
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
dataPtr =
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
ArrayRef<int64_t>{3, 0});
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
auto [dataPtr, size] =
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();

// ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();

// get MPI_COMM_WORLD, dataType, status_ignore and pointer
Value dataPtr =
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
dataPtr =
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
ArrayRef<int64_t>{3, 0});
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
auto [dataPtr, size] =
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +604,55 @@
}
};

//===----------------------------------------------------------------------===//
// AllReduceOpLowering
//===----------------------------------------------------------------------===//

struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type elemType = op.getSendbuf().getType().getElementType();

// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
auto [sendPtr, sendSize] =
getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
auto [recvPtr, recvSize] =
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
// 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
// MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
auto funcType = LLVM::LLVMFunctionType::get(
i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
commWorld.getType()});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);

// replace op with function call
auto funcCall = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});

if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
rewriter.eraseOp(op);

return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
SendOpLowering, RecvOpLowering>(converter);
SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
}
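
// A minimal driver for these patterns from a custom pass could look like this
// (assumed boilerplate, not part of this patch):
//
//   LLVMTypeConverter converter(module.getContext());
//   RewritePatternSet patterns(module.getContext());
//   LLVMConversionTarget target(*module.getContext());
//   mpi::populateMPIToLLVMConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();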

void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {