Skip to content

[MLIR][MPI] Add LLVM lowering patterns for some MPI operations #95524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

AntonLydike
Copy link
Contributor

The first set of patterns to convert the MPI dialect to LLVM.

Further conversion pattern will be added in future PRs.

@llvmbot llvmbot added the mlir label Jun 14, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 14, 2024

@llvm/pr-subscribers-mlir

Author: Anton Lydike (AntonLydike)

Changes

The first set of patterns to convert the MPI dialect to LLVM.

Further conversion pattern will be added in future PRs.


Full diff: https://github.com/llvm/llvm-project/pull/95524.diff

8 Files Affected:

  • (added) mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h (+28)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+18)
  • (modified) mlir/include/mlir/Dialect/MPI/IR/MPITypes.td (+1-1)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt (+17)
  • (added) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+249)
  • (added) mlir/test/Conversion/MPIToLLVM/ops.mlir (+39)
diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
new file mode 100644
index 0000000000000..181e3c3e72b3f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -0,0 +1,28 @@
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MPITOLLVM_H
+#define MLIR_CONVERSION_MPITOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mpi {
+void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MPITOLLVM_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 7700299b3a4f3..4b4e40d4f7463 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -42,6 +42,7 @@
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eb58f4adc31d3..e947c9fc49d8c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -833,6 +833,24 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MPItoLLVM
+//===----------------------------------------------------------------------===//
+
+def MPIToLLVMConversionPass : Pass<"convert-mpi-to-llvm"> {
+  let summary = "Convert MPI dialect operations to LLVM dialect function calls";
+  let description = [{
+    This pass converts MPI dialect operatoins to functions calls in the LLVM
+    dialect targeting the MPI stable ABI.
+  }];
+  let dependentDialects = ["LLVM::LLVMDialect"];
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index 87eefa719d45c..57ac512642829 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
 //===----------------------------------------------------------------------===//
 
 def MPI_Retval : MPI_Type<"Retval", "retval"> {
-  let summary = "MPI function call return value";
+  let summary = "MPI function call return value (!mpi.retval)";
   let description = [{
     This type represents a return value from an MPI function call.
     This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0a03a2e133db1..46e3768801560 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MPIToLLVM)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..f81fb25e56840
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRMPIToLLVM
+  MPIToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRMPIDialect
+  )
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
new file mode 100644
index 0000000000000..c4581dfbf3656
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -0,0 +1,249 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+// TODO: this was copied from GPUOpsLowering.cpp:288
+// is this okay, or should this be moved to some common file?
+LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     StringRef name,
+                                     LLVM::LLVMFunctionType type) {
+  LLVM::LLVMFuncOp ret;
+  if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+                                            LLVM::Linkage::External);
+  }
+  return ret;
+}
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
+LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         StringRef name,
+                                         LLVM::LLVMStructType type) {
+  LLVM::GlobalOp ret;
+  if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    ret = rewriter.create<LLVM::GlobalOp>(
+        loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
+        /*value=*/Attribute(), /*alignment=*/0, 0);
+  }
+  return ret;
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// InitOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  // get loc
+  auto loc = op.getLoc();
+
+  // ptrType `!llvm.ptr`
+  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+  // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
+  auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+  Value llvmnull = nullPtrOp.getRes();
+
+  // grab a reference to the global module op:
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
+  auto initFuncType =
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
+  // get or create function declaration:
+  LLVM::LLVMFuncOp initDecl =
+      getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
+
+  // replace init with function call
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
+                                            ValueRange{llvmnull, llvmnull});
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FinalizeOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // get loc
+  auto loc = op.getLoc();
+
+  // grab a reference to the global module op:
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // LLVM Function type representing `i32 MPI_Finalize()`
+  auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
+  // get or create function declaration:
+  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
+                                                  "MPI_Finalize", initFuncType);
+
+  // replace init with function call
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CommRankLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // get some helper vars
+  auto loc = op.getLoc();
+  auto context = rewriter.getContext();
+  auto i32 = rewriter.getI32Type();
+
+  // ptrType `!llvm.ptr`
+  Type ptrType = LLVM::LLVMPointerType::get(context);
+
+  // get external opaque struct pointer type
+  auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+
+  // grab a reference to the global module op:
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // make sure global op definition exists
+  getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
+                            commStructT);
+
+  // get address of @MPI_COMM_WORLD
+  auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+  auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+  auto commWorld = rewriter.create<LLVM::AddressOfOp>(
+      loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+
+  // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
+  auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
+  // get or create function declaration:
+  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+      moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
+
+  // replace init with function call
+  auto callOp = rewriter.create<LLVM::CallOp>(
+      loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});
+
+  // load the rank into a register
+  auto loadedRank =
+      rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+
+  // if retval is checked, replace uses of retval with the results from the call
+  // op
+  SmallVector<Value> replacements;
+  if (op.getRetval()) {
+    replacements.push_back(callOp.getResult());
+  }
+  // replace all uses, then erase op
+  replacements.push_back(loadedRank.getRes());
+  rewriter.replaceOp(op, replacements);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct MPIToLLVMConversionPass
+    : public impl::MPIToLLVMConversionPassBase<MPIToLLVMConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    LLVMConversionTarget target(getContext());
+    RewritePatternSet patterns(&getContext());
+
+    LowerToLLVMOptions options(&getContext());
+    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+      options.overrideIndexBitwidth(indexBitwidth);
+
+    target.addIllegalDialect<mpi::MPIDialect>();
+
+    // not yet implemented, will be added in future patches:
+    target.addLegalOp<mpi::RecvOp>();
+    target.addLegalOp<mpi::SendOp>();
+    target.addLegalOp<mpi::ErrorClassOp>();
+    target.addLegalOp<mpi::RetvalCheckOp>();
+
+    LLVMTypeConverter converter(&getContext(), options);
+
+    converter.addConversion(
+        [&](mpi::RetvalType) { return IntegerType::get(&getContext(), 32); });
+
+    mpi::populateMPIToLLVMConversionPatterns(converter, patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns) {
+  patterns.add<InitOpLowering>(converter);
+  patterns.add<CommRankOpLowering>(converter);
+  patterns.add<FinalizeOpLowering>(converter);
+}
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
new file mode 100644
index 0000000000000..71bd7ba464e67
--- /dev/null
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -convert-mpi-to-llvm %s | FileCheck %s
+
+module {
+// CHECK:  llvm.func @MPI_Finalize() -> i32
+// CHECK:  llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK:  llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+// CHECK:  llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+
+  func.func @mpi_test(%arg0: memref<100xf32>) {
+    %0 = mpi.init : !mpi.retval
+// CHECK: %0 = llvm.mlir.zero : !llvm.ptr
+// CHECK: %1 = llvm.call @MPI_Init(%0, %0) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: %2 = builtin.unrealized_conversion_cast %1 : i32 to !mpi.retval
+
+    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %4 = llvm.alloca %3 x i32 : (i32) -> !llvm.ptr
+// CHECK: %5 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CHECK: %6 = llvm.call @MPI_Comm_rank(%5, %4) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: %7 = llvm.load %4 : !llvm.ptr -> i32
+// CHECK: %8 = builtin.unrealized_conversion_cast %6 : i32 to !mpi.retval
+
+    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    %3 = mpi.finalize : !mpi.retval
+// CHECK: %11 = llvm.call @MPI_Finalize() : () -> i32
+
+    %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+
+    %5 = mpi.error_class %0 : !mpi.retval
+    return
+  }
+}

@AntonLydike
Copy link
Contributor Author

Tagging @Groverkss @sjw36 and @joker-eph for review requests, as you helped with the previous MPI patch. Also tagging @zero9178 for review.

@tschuett tschuett requested a review from joker-eph June 15, 2024 15:40
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this simply hook into the convert-to-llvm pass?

@AntonLydike
Copy link
Contributor Author

Can this simply hook into the convert-to-llvm pass?

Yes! I thought that that was what I was already doing. I got some advice on how to properly hook into that, will refactor this PR!

@AntonLydike AntonLydike changed the title [MLIR][MPI] Add first part of an convert-mpi-to-llvm lowering [MLIR][MPI] Add LLVM lowering patterns for some MPI operations Jun 28, 2024
@AntonLydike
Copy link
Contributor Author

@joker-eph Updated to be part of convert-to-llvm infrastructure, is this better?

@AntonLydike AntonLydike requested a review from zero9178 July 30, 2024 16:52
@@ -14,6 +14,7 @@
#ifndef MLIR_INITALLEXTENSIONS_H_
#define MLIR_INITALLEXTENSIONS_H_

#include "Conversion/MPIToLLVM/MPIToLLVM.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include "Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"

#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Pass/Pass.h"

#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"


// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, Location loc,

Comment on lines +52 to +53
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
auto ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (ret)
return ret;
...

ulta nit: Looks nicer as an early return in my opinion, Ditto below

InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// get loc
auto loc = op.getLoc();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto loc = op.getLoc();
Location loc = op.getLoc();

Ditto in other places where the type of the variable does not appear in the right-hand side expression

Comment on lines +187 to +189
if (op.getRetval()) {
replacements.push_back(callOp.getResult());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (op.getRetval()) {
replacements.push_back(callOp.getResult());
}
if (op.getRetval())
replacements.push_back(callOp.getResult());

Comment on lines +11 to +22
// CHECK: %7 = llvm.mlir.zero : !llvm.ptr
// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval


%retval, %rank = mpi.comm_rank : !mpi.retval, i32
// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests should be using filecheck variables for the SSA values. It is also convention that the // of the comments have the same indentation as the body.

You can also try (but don't see this as blocking) to slim down the tests a bit to not test the syntax of the LLVM operations so much, but rather just the lowering.
Just

%[[NULL_PTR:.*]] = llvm.mlir.zero
%[[INIT_RAW:.*]] = llvm.call @MPI_INIT(%[[NULL_PTR, %[[NULL_PTR]])
%[[INIT:.*]] = builtin.unrealized_conversion_cast %[[INIT_RAW]] : i32 to !mpi.retval

is fine (maybe splitting some of the type signatures with // CHECK-SAME: on the next lines if you feel they are very relevant.


func.func @mpi_test(%arg0: memref<100xf32>) {
%0 = mpi.init : !mpi.retval
// CHECK: %7 = llvm.mlir.zero : !llvm.ptr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it has always been convention that the check lines for a particular ops lowering are before the op being lowered.

namespace {
/// Implement the interface to convert Func to LLVM.
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;

Comment on lines +45 to +48

// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
} // namespace
// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,

LLVM convention is to mark functions as static rather than having them in anonymous namespaces. Ditto the other function

@fschlimb
Copy link
Contributor

Will this get finalized/merged in the foreseeable future?

@fschlimb
Copy link
Contributor

@AntonLydike I extended it to also lower send/recv here: fschlimb@0300e63

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants