-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][MPI] Add LLVM lowering patterns for some MPI operations #95524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: Anton Lydike (AntonLydike) ChangesThe first set of patterns to convert the MPI dialect to LLVM. Further conversion pattern will be added in future PRs. Full diff: https://github.com/llvm/llvm-project/pull/95524.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
new file mode 100644
index 0000000000000..181e3c3e72b3f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -0,0 +1,28 @@
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MPITOLLVM_H
+#define MLIR_CONVERSION_MPITOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mpi {
+void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MPITOLLVM_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 7700299b3a4f3..4b4e40d4f7463 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -42,6 +42,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eb58f4adc31d3..e947c9fc49d8c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -833,6 +833,24 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// MPItoLLVM
+//===----------------------------------------------------------------------===//
+
+def MPIToLLVMConversionPass : Pass<"convert-mpi-to-llvm"> {
+ let summary = "Convert MPI dialect operations to LLVM dialect function calls";
+ let description = [{
+ This pass converts MPI dialect operatoins to functions calls in the LLVM
+ dialect targeting the MPI stable ABI.
+ }];
+ let dependentDialects = ["LLVM::LLVMDialect"];
+ let options = [
+ Option<"indexBitwidth", "index-bitwidth", "unsigned",
+ /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+ "Bitwidth of the index type, 0 to use size of machine word">,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// NVVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index 87eefa719d45c..57ac512642829 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
def MPI_Retval : MPI_Type<"Retval", "retval"> {
- let summary = "MPI function call return value";
+ let summary = "MPI function call return value (!mpi.retval)";
let description = [{
This type represents a return value from an MPI function call.
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0a03a2e133db1..46e3768801560 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..f81fb25e56840
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRMPIToLLVM
+ MPIToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMPIDialect
+ )
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
new file mode 100644
index 0000000000000..c4581dfbf3656
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -0,0 +1,249 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+// TODO: this was copied from GPUOpsLowering.cpp:288
+// is this okay, or should this be moved to some common file?
+LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
+ ConversionPatternRewriter &rewriter,
+ StringRef name,
+ LLVM::LLVMFunctionType type) {
+ LLVM::LLVMFuncOp ret;
+ if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(moduleOp.getBody());
+ ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+ LLVM::Linkage::External);
+ }
+ return ret;
+}
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
+LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
+ ConversionPatternRewriter &rewriter,
+ StringRef name,
+ LLVM::LLVMStructType type) {
+ LLVM::GlobalOp ret;
+ if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(moduleOp.getBody());
+ ret = rewriter.create<LLVM::GlobalOp>(
+ loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
+ /*value=*/Attribute(), /*alignment=*/0, 0);
+ }
+ return ret;
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// InitOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // get loc
+ auto loc = op.getLoc();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+ // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
+ auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ Value llvmnull = nullPtrOp.getRes();
+
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+
+ // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
+ auto initFuncType =
+ LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp initDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
+
+ // replace init with function call
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
+ ValueRange{llvmnull, llvmnull});
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FinalizeOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // get loc
+ auto loc = op.getLoc();
+
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+
+ // LLVM Function type representing `i32 MPI_Finalize()`
+ auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
+ "MPI_Finalize", initFuncType);
+
+ // replace init with function call
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CommRankLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // get some helper vars
+ auto loc = op.getLoc();
+ auto context = rewriter.getContext();
+ auto i32 = rewriter.getI32Type();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+
+ // get external opaque struct pointer type
+ auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+
+ // make sure global op definition exists
+ getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
+ commStructT);
+
+ // get address of @MPI_COMM_WORLD
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+ auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+ auto commWorld = rewriter.create<LLVM::AddressOfOp>(
+ loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+
+ // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
+ auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
+
+ // replace init with function call
+ auto callOp = rewriter.create<LLVM::CallOp>(
+ loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});
+
+ // load the rank into a register
+ auto loadedRank =
+ rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+
+ // if retval is checked, replace uses of retval with the results from the call
+ // op
+ SmallVector<Value> replacements;
+ if (op.getRetval()) {
+ replacements.push_back(callOp.getResult());
+ }
+ // replace all uses, then erase op
+ replacements.push_back(loadedRank.getRes());
+ rewriter.replaceOp(op, replacements);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct MPIToLLVMConversionPass
+ : public impl::MPIToLLVMConversionPassBase<MPIToLLVMConversionPass> {
+ using Base::Base;
+
+ void runOnOperation() override {
+ LLVMConversionTarget target(getContext());
+ RewritePatternSet patterns(&getContext());
+
+ LowerToLLVMOptions options(&getContext());
+ if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+ options.overrideIndexBitwidth(indexBitwidth);
+
+ target.addIllegalDialect<mpi::MPIDialect>();
+
+ // not yet implemented, will be added in future patches:
+ target.addLegalOp<mpi::RecvOp>();
+ target.addLegalOp<mpi::SendOp>();
+ target.addLegalOp<mpi::ErrorClassOp>();
+ target.addLegalOp<mpi::RetvalCheckOp>();
+
+ LLVMTypeConverter converter(&getContext(), options);
+
+ converter.addConversion(
+ [&](mpi::RetvalType) { return IntegerType::get(&getContext(), 32); });
+
+ mpi::populateMPIToLLVMConversionPatterns(converter, patterns);
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ patterns.add<InitOpLowering>(converter);
+ patterns.add<CommRankOpLowering>(converter);
+ patterns.add<FinalizeOpLowering>(converter);
+}
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
new file mode 100644
index 0000000000000..71bd7ba464e67
--- /dev/null
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -convert-mpi-to-llvm %s | FileCheck %s
+
+module {
+// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+
+ func.func @mpi_test(%arg0: memref<100xf32>) {
+ %0 = mpi.init : !mpi.retval
+// CHECK: %0 = llvm.mlir.zero : !llvm.ptr
+// CHECK: %1 = llvm.call @MPI_Init(%0, %0) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: %2 = builtin.unrealized_conversion_cast %1 : i32 to !mpi.retval
+
+ %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %4 = llvm.alloca %3 x i32 : (i32) -> !llvm.ptr
+// CHECK: %5 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CHECK: %6 = llvm.call @MPI_Comm_rank(%5, %4) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: %7 = llvm.load %4 : !llvm.ptr -> i32
+// CHECK: %8 = builtin.unrealized_conversion_cast %6 : i32 to !mpi.retval
+
+ mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+ %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+ mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+ %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+ %3 = mpi.finalize : !mpi.retval
+// CHECK: %11 = llvm.call @MPI_Finalize() : () -> i32
+
+ %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+
+ %5 = mpi.error_class %0 : !mpi.retval
+ return
+ }
+}
|
Tagging @Groverkss @sjw36 and @joker-eph for review requests, as you helped with the previous MPI patch. Also tagging @zero9178 for review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this simply hook into the convert-to-llvm pass?
Yes! I thought that that was what I was already doing. I got some advice on how to properly hook into that, will refactor this PR! |
9d98c6b
to
e6d70ab
Compare
convert-mpi-to-llvm
loweringe6d70ab
to
7d74834
Compare
7d74834
to
1e3b942
Compare
@joker-eph Updated to be part of |
@@ -14,6 +14,7 @@ | |||
#ifndef MLIR_INITALLEXTENSIONS_H_ | |||
#define MLIR_INITALLEXTENSIONS_H_ | |||
|
|||
#include "Conversion/MPIToLLVM/MPIToLLVM.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include "Conversion/MPIToLLVM/MPIToLLVM.h" | |
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" |
#include "mlir/Dialect/MPI/IR/MPI.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h> | |
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
|
||
// TODO: this was copied from GPUOpsLowering.cpp:288 | ||
// is this okay, or should this be moved to some common file? | ||
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc, | |
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, Location loc, |
LLVM::LLVMFuncOp ret; | ||
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LLVM::LLVMFuncOp ret; | |
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) { | |
auto ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name); | |
if (ret) | |
return ret; | |
... |
ulta nit: Looks nicer as an early return in my opinion, Ditto below
InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const { | ||
// get loc | ||
auto loc = op.getLoc(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto loc = op.getLoc(); | |
Location loc = op.getLoc(); |
Ditto in other places where the type of the variable does not appear in the right-hand side expression
if (op.getRetval()) { | ||
replacements.push_back(callOp.getResult()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (op.getRetval()) { | |
replacements.push_back(callOp.getResult()); | |
} | |
if (op.getRetval()) | |
replacements.push_back(callOp.getResult()); | |
// CHECK: %7 = llvm.mlir.zero : !llvm.ptr | ||
// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32 | ||
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval | ||
|
||
|
||
%retval, %rank = mpi.comm_rank : !mpi.retval, i32 | ||
// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32 | ||
// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr | ||
// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr | ||
// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32 | ||
// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32 | ||
// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests should be using filecheck variables for the SSA values. It is also convention that the //
of the comments have the same indentation as the body.
You can also try (but don't see this as blocking) to slim down the tests a bit to not test the syntax of the LLVM operations so much, but rather just the lowering.
Just
%[[NULL_PTR:.*]] = llvm.mlir.zero
%[[INIT_RAW:.*]] = llvm.call @MPI_INIT(%[[NULL_PTR, %[[NULL_PTR]])
%[[INIT:.*]] = builtin.unrealized_conversion_cast %[[INIT_RAW]] : i32 to !mpi.retval
is fine (maybe splitting some of the type signatures with // CHECK-SAME:
on the next lines if you feel they are very relevant.
|
||
func.func @mpi_test(%arg0: memref<100xf32>) { | ||
%0 = mpi.init : !mpi.retval | ||
// CHECK: %7 = llvm.mlir.zero : !llvm.ptr |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it has always been convention that the check lines for a particular ops lowering are before the op being lowered.
namespace { | ||
/// Implement the interface to convert Func to LLVM. | ||
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { | ||
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; | |
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; | |
|
||
// TODO: this was copied from GPUOpsLowering.cpp:288 | ||
// is this okay, or should this be moved to some common file? | ||
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// TODO: this was copied from GPUOpsLowering.cpp:288 | |
// is this okay, or should this be moved to some common file? | |
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc, | |
} // namespace | |
// TODO: this was copied from GPUOpsLowering.cpp:288 | |
// is this okay, or should this be moved to some common file? | |
static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc, |
LLVM convention is to mark functions as static
rather than having them in anonymous namespaces. Ditto the other function
Will this get finalized/merged in the foreseeable future? |
@AntonLydike I extended it to also lower send/recv here: fschlimb@0300e63 |
The first set of patterns to convert the MPI dialect to LLVM.
Further conversion pattern will be added in future PRs.