Skip to content

[mlir][mpi] Fix in-place and 0-d memref handling for mpi.allreduce #134225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
Expand Down Expand Up @@ -57,9 +59,14 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
loc, rewriter.getI64Type(), memRef, 2);
Value resPtr =
rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
ArrayRef<int64_t>{3, 0});
size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
Value size;
if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
ArrayRef<int64_t>{3, 0});
size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
} else {
size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
}
return {resPtr, size};
}

Expand Down Expand Up @@ -97,6 +104,9 @@ class MPIImplTraits {
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
virtual intptr_t getStatusIgnore() = 0;

/// Get the MPI_IN_PLACE value (void *).
virtual void *getInPlace() = 0;

/// Gets or creates an MPI datatype as a value which corresponds to the given
/// type.
virtual Value getDataType(const Location loc,
Expand Down Expand Up @@ -158,6 +168,8 @@ class MPICHImplTraits : public MPIImplTraits {

/// Returns this implementation's MPI_STATUS_IGNORE value as an integer
/// (here: 1).
intptr_t getStatusIgnore() override {
  const intptr_t statusIgnore = 1;
  return statusIgnore;
}

/// Returns this implementation's MPI_IN_PLACE value: a `void *` whose
/// integer representation is -1.
void *getInPlace() override {
  void *const inPlace = reinterpret_cast<void *>(-1);
  return inPlace;
}

Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
Type type) override {
int32_t mtype = 0;
Expand Down Expand Up @@ -283,6 +295,8 @@ class OMPIImplTraits : public MPIImplTraits {

/// Returns this implementation's MPI_STATUS_IGNORE value as an integer
/// (here: 0).
intptr_t getStatusIgnore() override {
  const intptr_t statusIgnore = 0;
  return statusIgnore;
}

/// Returns this implementation's MPI_IN_PLACE value: a `void *` whose
/// integer representation is 1.
void *getInPlace() override {
  void *const inPlace = reinterpret_cast<void *>(1);
  return inPlace;
}

Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
Type type) override {
StringRef mtype;
Expand Down Expand Up @@ -516,7 +530,8 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
outPtr.getRes()});

// load the communicator into a register
auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);

// if retval is checked, replace uses of retval with the results from the
// call op
Expand All @@ -525,7 +540,7 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
replacements.push_back(callOp.getResult());

// replace op
replacements.push_back(res.getRes());
replacements.push_back(res);
rewriter.replaceOp(op, replacements);

return success();
Expand Down Expand Up @@ -709,6 +724,7 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Type elemType = op.getSendbuf().getType().getElementType();

// ptrType `!llvm.ptr`
Expand All @@ -719,6 +735,14 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
auto [recvPtr, recvSize] =
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);

// If input and output are the same, request in-place operation.
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
sendPtr = rewriter.create<LLVM::ConstantOp>(
loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
}

Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
Expand Down
8 changes: 6 additions & 2 deletions mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
// CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
// CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>

// CHECK: llvm.call @MPI_Finalize() : () -> i32
Expand Down Expand Up @@ -202,10 +204,12 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
// CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
// CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>

// CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
Expand Down