Skip to content

[mlir][mpi] fixing in-place and 0d mpi.all_reduce #134225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into llvm:main
Apr 3, 2025

Conversation

fschlimb
Copy link
Contributor

@fschlimb fschlimb commented Apr 3, 2025

  • In-place allreduce needs the special MPI token MPI_IN_PLACE as the send buffer.
  • 0-d tensors (rank-0 memrefs) have no sizes/strides fields in the LLVM memref struct.

@llvmbot llvmbot added the mlir label Apr 3, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 3, 2025

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes
  • In-place allreduce needs the special MPI token MPI_IN_PLACE as the send buffer.
  • 0-d tensors (rank-0 memrefs) have no sizes/strides fields in the LLVM memref struct.

Full diff: https://github.com/llvm/llvm-project/pull/134225.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+29-5)
  • (modified) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+6-2)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9df5e992e8ebd..5575b295ae20a 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <memory>
@@ -57,9 +59,14 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
       loc, rewriter.getI64Type(), memRef, 2);
   Value resPtr =
       rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
-  Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
-                                                     ArrayRef<int64_t>{3, 0});
-  size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+  Value size;
+  if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
+    size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+                                                 ArrayRef<int64_t>{3, 0});
+    size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+  } else {
+    size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
+  }
   return {resPtr, size};
 }
 
@@ -97,6 +104,9 @@ class MPIImplTraits {
   /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
   virtual intptr_t getStatusIgnore() = 0;
 
+  /// Get the MPI_IN_PLACE value (void *).
+  virtual void *getInPlace() = 0;
+
   /// Gets or creates an MPI datatype as a value which corresponds to the given
   /// type.
   virtual Value getDataType(const Location loc,
@@ -158,6 +168,8 @@ class MPICHImplTraits : public MPIImplTraits {
 
   intptr_t getStatusIgnore() override { return 1; }
 
+  void *getInPlace() override { return reinterpret_cast<void *>(-1); }
+
   Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
                     Type type) override {
     int32_t mtype = 0;
@@ -283,6 +295,8 @@ class OMPIImplTraits : public MPIImplTraits {
 
   intptr_t getStatusIgnore() override { return 0; }
 
+  void *getInPlace() override { return reinterpret_cast<void *>(1); }
+
   Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
                     Type type) override {
     StringRef mtype;
@@ -516,7 +530,8 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
                    outPtr.getRes()});
 
     // load the communicator into a register
-    auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+    Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+    res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
 
     // if retval is checked, replace uses of retval with the results from the
     // call op
@@ -525,7 +540,7 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
       replacements.push_back(callOp.getResult());
 
     // replace op
-    replacements.push_back(res.getRes());
+    replacements.push_back(res);
     rewriter.replaceOp(op, replacements);
 
     return success();
@@ -709,6 +724,7 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
     Location loc = op.getLoc();
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
     Type elemType = op.getSendbuf().getType().getElementType();
 
     // ptrType `!llvm.ptr`
@@ -719,6 +735,14 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
         getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
     auto [recvPtr, recvSize] =
         getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+
+    // If input and output are the same, request in-place operation.
+    if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
+      sendPtr = rewriter.create<LLVM::ConstantOp>(
+          loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+      sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
+    }
+
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
     Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 174f7c79b9d50..35fc0f5d2e754 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -98,10 +98,12 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
+    // CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
+    // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
     // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
     // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
-    // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+    // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: llvm.call @MPI_Finalize() : () -> i32
@@ -202,10 +204,12 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
     // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
     // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
     // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
-    // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32

@fschlimb fschlimb requested review from Dinistro and AntonLydike April 3, 2025 10:17
@fschlimb
Copy link
Contributor Author

fschlimb commented Apr 3, 2025

CC @tkarna @mofeing

Copy link
Contributor

@AntonLydike AntonLydike left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch on those edge cases! LGTM from my side!

Copy link
Contributor

@tkarna tkarna left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@fschlimb fschlimb merged commit 586c5e3 into llvm:main Apr 3, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants