-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mpi] fixing in-place and 0d mpi.all_reduce #134225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
fschlimb
commented
Apr 3, 2025
- In-place allreduce needs the special MPI token MPI_IN_PLACE as the send buffer
- 0-d tensors have no sizes/strides fields in the LLVM memref struct
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
Changes:
Full diff: https://github.com/llvm/llvm-project/pull/134225.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9df5e992e8ebd..5575b295ae20a 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -15,8 +15,10 @@
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
@@ -57,9 +59,14 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
loc, rewriter.getI64Type(), memRef, 2);
Value resPtr =
rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
- Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+ Value size;
+ if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
+ size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+ ArrayRef<int64_t>{3, 0});
+ size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+ } else {
+ size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
+ }
return {resPtr, size};
}
@@ -97,6 +104,9 @@ class MPIImplTraits {
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
virtual intptr_t getStatusIgnore() = 0;
+ /// Get the MPI_IN_PLACE value (void *).
+ virtual void *getInPlace() = 0;
+
/// Gets or creates an MPI datatype as a value which corresponds to the given
/// type.
virtual Value getDataType(const Location loc,
@@ -158,6 +168,8 @@ class MPICHImplTraits : public MPIImplTraits {
intptr_t getStatusIgnore() override { return 1; }
+ void *getInPlace() override { return reinterpret_cast<void *>(-1); }
+
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
Type type) override {
int32_t mtype = 0;
@@ -283,6 +295,8 @@ class OMPIImplTraits : public MPIImplTraits {
intptr_t getStatusIgnore() override { return 0; }
+ void *getInPlace() override { return reinterpret_cast<void *>(1); }
+
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
Type type) override {
StringRef mtype;
@@ -516,7 +530,8 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
outPtr.getRes()});
// load the communicator into a register
- auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+ Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+ res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
// if retval is checked, replace uses of retval with the results from the
// call op
@@ -525,7 +540,7 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
replacements.push_back(callOp.getResult());
// replace op
- replacements.push_back(res.getRes());
+ replacements.push_back(res);
rewriter.replaceOp(op, replacements);
return success();
@@ -709,6 +724,7 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
+ Type i64 = rewriter.getI64Type();
Type elemType = op.getSendbuf().getType().getElementType();
// ptrType `!llvm.ptr`
@@ -719,6 +735,14 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
auto [recvPtr, recvSize] =
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+
+ // If input and output are the same, request in-place operation.
+ if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
+ sendPtr = rewriter.create<LLVM::ConstantOp>(
+ loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+ sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
+ }
+
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 174f7c79b9d50..35fc0f5d2e754 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -98,10 +98,12 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
+ // CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
+ // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
// CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+ // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: llvm.call @MPI_Finalize() : () -> i32
@@ -202,10 +204,12 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
// CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch on those edge cases! LGTM from my side!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM