Skip to content

Commit 9d98c6b

Browse files
committed
add initial set of lowerings for MPI dialect
1 parent ac40463 commit 9d98c6b

File tree

8 files changed

+354
-1
lines changed

8 files changed

+354
-1
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//
2+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// See https://llvm.org/LICENSE.txt for license information.
4+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#ifndef MLIR_CONVERSION_MPITOLLVM_H
9+
#define MLIR_CONVERSION_MPITOLLVM_H
10+
11+
#include <memory>
12+
13+
namespace mlir {
14+
15+
class LLVMTypeConverter;
16+
class RewritePatternSet;
17+
class Pass;
18+
19+
#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
20+
#include "mlir/Conversion/Passes.h.inc"
21+
22+
namespace mpi {
23+
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
24+
RewritePatternSet &patterns);
25+
} // namespace mpi
26+
} // namespace mlir
27+
28+
#endif // MLIR_CONVERSION_MPITOLLVM_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
4343
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
4444
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
45+
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
4546
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4647
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4748
#include "mlir/Conversion/MathToLibm/MathToLibm.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,24 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
833833
];
834834
}
835835

836+
//===----------------------------------------------------------------------===//
837+
// MPItoLLVM
838+
//===----------------------------------------------------------------------===//
839+
840+
def MPIToLLVMConversionPass : Pass<"convert-mpi-to-llvm"> {
841+
let summary = "Convert MPI dialect operations to LLVM dialect function calls";
842+
let description = [{
843+
This pass converts MPI dialect operatoins to functions calls in the LLVM
844+
dialect targeting the MPI stable ABI.
845+
}];
846+
let dependentDialects = ["LLVM::LLVMDialect"];
847+
let options = [
848+
Option<"indexBitwidth", "index-bitwidth", "unsigned",
849+
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
850+
"Bitwidth of the index type, 0 to use size of machine word">,
851+
];
852+
}
853+
836854
//===----------------------------------------------------------------------===//
837855
// NVVMToLLVM
838856
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MPI/IR/MPITypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
3030
//===----------------------------------------------------------------------===//
3131

3232
def MPI_Retval : MPI_Type<"Retval", "retval"> {
33-
let summary = "MPI function call return value";
33+
let summary = "MPI function call return value (!mpi.retval)";
3434
let description = [{
3535
This type represents a return value from an MPI function call.
3636
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV)
3939
add_subdirectory(MemRefToEmitC)
4040
add_subdirectory(MemRefToLLVM)
4141
add_subdirectory(MemRefToSPIRV)
42+
add_subdirectory(MPIToLLVM)
4243
add_subdirectory(NVGPUToNVVM)
4344
add_subdirectory(NVVMToLLVM)
4445
add_subdirectory(OpenACCToSCF)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRMPIToLLVM
2+
MPIToLLVM.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRLLVMCommonConversion
15+
MLIRLLVMDialect
16+
MLIRMPIDialect
17+
)
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
10+
11+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
#include "mlir/Dialect/MPI/IR/MPI.h"
15+
#include "mlir/Pass/Pass.h"
16+
17+
namespace mlir {
18+
#define GEN_PASS_DEF_MPITOLLVMCONVERSIONPASS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
} // namespace mlir
21+
22+
using namespace mlir;
23+
24+
namespace {
25+
26+
struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
27+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
28+
29+
LogicalResult
30+
matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
31+
ConversionPatternRewriter &rewriter) const override;
32+
};
33+
34+
struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
35+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
36+
37+
LogicalResult
38+
matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
39+
ConversionPatternRewriter &rewriter) const override;
40+
};
41+
42+
struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
43+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
44+
45+
LogicalResult
46+
matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
47+
ConversionPatternRewriter &rewriter) const override;
48+
};
49+
50+
// TODO: this was copied from GPUOpsLowering.cpp:288
51+
// is this okay, or should this be moved to some common file?
52+
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
53+
ConversionPatternRewriter &rewriter,
54+
StringRef name,
55+
LLVM::LLVMFunctionType type) {
56+
LLVM::LLVMFuncOp ret;
57+
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
58+
ConversionPatternRewriter::InsertionGuard guard(rewriter);
59+
rewriter.setInsertionPointToStart(moduleOp.getBody());
60+
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
61+
LLVM::Linkage::External);
62+
}
63+
return ret;
64+
}
65+
66+
// TODO: this is pretty close to getOrDefineFunction, can probably be factored
67+
LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
68+
ConversionPatternRewriter &rewriter,
69+
StringRef name,
70+
LLVM::LLVMStructType type) {
71+
LLVM::GlobalOp ret;
72+
if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
73+
ConversionPatternRewriter::InsertionGuard guard(rewriter);
74+
rewriter.setInsertionPointToStart(moduleOp.getBody());
75+
ret = rewriter.create<LLVM::GlobalOp>(
76+
loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
77+
/*value=*/Attribute(), /*alignment=*/0, 0);
78+
}
79+
return ret;
80+
}
81+
82+
} // namespace
83+
84+
//===----------------------------------------------------------------------===//
85+
// InitOpLowering
86+
//===----------------------------------------------------------------------===//
87+
88+
LogicalResult
89+
InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
90+
ConversionPatternRewriter &rewriter) const {
91+
// get loc
92+
auto loc = op.getLoc();
93+
94+
// ptrType `!llvm.ptr`
95+
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
96+
97+
// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
98+
auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
99+
Value llvmnull = nullPtrOp.getRes();
100+
101+
// grab a reference to the global module op:
102+
auto moduleOp = op->getParentOfType<ModuleOp>();
103+
104+
// LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
105+
auto initFuncType =
106+
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
107+
// get or create function declaration:
108+
LLVM::LLVMFuncOp initDecl =
109+
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
110+
111+
// replace init with function call
112+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
113+
ValueRange{llvmnull, llvmnull});
114+
115+
return success();
116+
}
117+
118+
//===----------------------------------------------------------------------===//
119+
// FinalizeOpLowering
120+
//===----------------------------------------------------------------------===//
121+
122+
LogicalResult
123+
FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
124+
ConversionPatternRewriter &rewriter) const {
125+
// get loc
126+
auto loc = op.getLoc();
127+
128+
// grab a reference to the global module op:
129+
auto moduleOp = op->getParentOfType<ModuleOp>();
130+
131+
// LLVM Function type representing `i32 MPI_Finalize()`
132+
auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
133+
// get or create function declaration:
134+
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
135+
"MPI_Finalize", initFuncType);
136+
137+
// replace init with function call
138+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
139+
140+
return success();
141+
}
142+
143+
//===----------------------------------------------------------------------===//
144+
// CommRankLowering
145+
//===----------------------------------------------------------------------===//
146+
147+
LogicalResult
148+
CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
149+
ConversionPatternRewriter &rewriter) const {
150+
// get some helper vars
151+
auto loc = op.getLoc();
152+
auto context = rewriter.getContext();
153+
auto i32 = rewriter.getI32Type();
154+
155+
// ptrType `!llvm.ptr`
156+
Type ptrType = LLVM::LLVMPointerType::get(context);
157+
158+
// get external opaque struct pointer type
159+
auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
160+
161+
// grab a reference to the global module op:
162+
auto moduleOp = op->getParentOfType<ModuleOp>();
163+
164+
// make sure global op definition exists
165+
getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
166+
commStructT);
167+
168+
// get address of @MPI_COMM_WORLD
169+
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
170+
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
171+
auto commWorld = rewriter.create<LLVM::AddressOfOp>(
172+
loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
173+
174+
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
175+
auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
176+
// get or create function declaration:
177+
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
178+
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
179+
180+
// replace init with function call
181+
auto callOp = rewriter.create<LLVM::CallOp>(
182+
loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});
183+
184+
// load the rank into a register
185+
auto loadedRank =
186+
rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
187+
188+
// if retval is checked, replace uses of retval with the results from the call
189+
// op
190+
SmallVector<Value> replacements;
191+
if (op.getRetval()) {
192+
replacements.push_back(callOp.getResult());
193+
}
194+
// replace all uses, then erase op
195+
replacements.push_back(loadedRank.getRes());
196+
rewriter.replaceOp(op, replacements);
197+
198+
return success();
199+
}
200+
201+
//===----------------------------------------------------------------------===//
202+
// Pass Definition
203+
//===----------------------------------------------------------------------===//
204+
205+
namespace {
206+
struct MPIToLLVMConversionPass
207+
: public impl::MPIToLLVMConversionPassBase<MPIToLLVMConversionPass> {
208+
using Base::Base;
209+
210+
void runOnOperation() override {
211+
LLVMConversionTarget target(getContext());
212+
RewritePatternSet patterns(&getContext());
213+
214+
LowerToLLVMOptions options(&getContext());
215+
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
216+
options.overrideIndexBitwidth(indexBitwidth);
217+
218+
target.addIllegalDialect<mpi::MPIDialect>();
219+
220+
// not yet implemented, will be added in future patches:
221+
target.addLegalOp<mpi::RecvOp>();
222+
target.addLegalOp<mpi::SendOp>();
223+
target.addLegalOp<mpi::ErrorClassOp>();
224+
target.addLegalOp<mpi::RetvalCheckOp>();
225+
226+
LLVMTypeConverter converter(&getContext(), options);
227+
228+
converter.addConversion(
229+
[&](mpi::RetvalType) { return IntegerType::get(&getContext(), 32); });
230+
231+
mpi::populateMPIToLLVMConversionPatterns(converter, patterns);
232+
233+
if (failed(applyPartialConversion(getOperation(), target,
234+
std::move(patterns))))
235+
signalPassFailure();
236+
}
237+
};
238+
} // namespace
239+
240+
//===----------------------------------------------------------------------===//
241+
// Pattern Population
242+
//===----------------------------------------------------------------------===//
243+
244+
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
245+
RewritePatternSet &patterns) {
246+
patterns.add<InitOpLowering>(converter);
247+
patterns.add<CommRankOpLowering>(converter);
248+
patterns.add<FinalizeOpLowering>(converter);
249+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: mlir-opt -convert-mpi-to-llvm %s | FileCheck %s
2+
3+
module {
4+
// CHECK: llvm.func @MPI_Finalize() -> i32
5+
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
6+
// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
7+
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
8+
9+
func.func @mpi_test(%arg0: memref<100xf32>) {
10+
%0 = mpi.init : !mpi.retval
11+
// CHECK: %0 = llvm.mlir.zero : !llvm.ptr
12+
// CHECK: %1 = llvm.call @MPI_Init(%0, %0) : (!llvm.ptr, !llvm.ptr) -> i32
13+
// CHECK: %2 = builtin.unrealized_conversion_cast %1 : i32 to !mpi.retval
14+
15+
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
16+
// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
17+
// CHECK: %4 = llvm.alloca %3 x i32 : (i32) -> !llvm.ptr
18+
// CHECK: %5 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
19+
// CHECK: %6 = llvm.call @MPI_Comm_rank(%5, %4) : (!llvm.ptr, !llvm.ptr) -> i32
20+
// CHECK: %7 = llvm.load %4 : !llvm.ptr -> i32
21+
// CHECK: %8 = builtin.unrealized_conversion_cast %6 : i32 to !mpi.retval
22+
23+
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
24+
25+
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
26+
27+
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
28+
29+
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
30+
31+
%3 = mpi.finalize : !mpi.retval
32+
// CHECK: %11 = llvm.call @MPI_Finalize() : () -> i32
33+
34+
%4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
35+
36+
%5 = mpi.error_class %0 : !mpi.retval
37+
return
38+
}
39+
}

0 commit comments

Comments
 (0)