Skip to content

Commit 80179ce

Browse files
committed
lowering MPI_Allreduce (MPICH)
1 parent 4cabee3 commit 80179ce

File tree

5 files changed

+174
-33
lines changed

5 files changed

+174
-33
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPI.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,12 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
246246
MPI_OpMaxloc,
247247
MPI_OpReplace
248248
]> {
249-
let genSpecializedAttr = 0;
249+
// let genSpecializedAttr = 0;
250250
let cppNamespace = "::mlir::mpi";
251251
}
252252

253-
def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
254-
let assemblyFormat = "`<` $value `>`";
255-
}
253+
// def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
254+
// let assemblyFormat = "`<` $value `>`";
255+
// }
256256

257257
#endif // MLIR_DIALECT_MPI_IR_MPI_TD

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
244244
let arguments = (
245245
ins AnyMemRef : $sendbuf,
246246
AnyMemRef : $recvbuf,
247-
MPI_OpClassAttr : $op
247+
MPI_OpClassEnum : $op
248248
);
249249

250250
let results = (outs Optional<MPI_Retval>:$retval);

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 146 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
4747
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
4848
}
4949

50+
std::pair<Value, Value> getRawPtrAndSize(const Location loc,
51+
ConversionPatternRewriter &rewriter,
52+
Value memRef, Type elType) {
53+
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
54+
Value dataPtr =
55+
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
56+
Value offset = rewriter.create<LLVM::ExtractValueOp>(
57+
loc, rewriter.getI64Type(), memRef, 2);
58+
Value resPtr =
59+
rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
60+
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
61+
ArrayRef<int64_t>{3, 0});
62+
size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
63+
return {resPtr, size};
64+
}
65+
5066
/// When lowering the mpi dialect to functions calls certain details
5167
/// differ between various MPI implementations. This class will provide
5268
/// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
7793
/// type.
7894
virtual Value getDataType(const Location loc,
7995
ConversionPatternRewriter &rewriter, Type type) = 0;
96+
97+
/// Gets or creates an MPI_Op value which corresponds to the given
98+
/// enum value.
99+
virtual Value getMPIOp(const Location loc,
100+
ConversionPatternRewriter &rewriter,
101+
mpi::MPI_OpClassEnum opAttr) = 0;
80102
};
81103

82104
//===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
94116
static constexpr int MPI_UINT16_T = 0x4c00023c;
95117
static constexpr int MPI_UINT32_T = 0x4c00043d;
96118
static constexpr int MPI_UINT64_T = 0x4c00083e;
119+
static constexpr int MPI_MAX = 0x58000001;
120+
static constexpr int MPI_MIN = 0x58000002;
121+
static constexpr int MPI_SUM = 0x58000003;
122+
static constexpr int MPI_PROD = 0x58000004;
123+
static constexpr int MPI_LAND = 0x58000005;
124+
static constexpr int MPI_BAND = 0x58000006;
125+
static constexpr int MPI_LOR = 0x58000007;
126+
static constexpr int MPI_BOR = 0x58000008;
127+
static constexpr int MPI_LXOR = 0x58000009;
128+
static constexpr int MPI_BXOR = 0x5800000a;
129+
static constexpr int MPI_MINLOC = 0x5800000b;
130+
static constexpr int MPI_MAXLOC = 0x5800000c;
131+
static constexpr int MPI_REPLACE = 0x5800000d;
132+
static constexpr int MPI_NO_OP = 0x5800000e;
97133

98134
public:
99135
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
136172
assert(false && "unsupported type");
137173
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
138174
}
175+
176+
/// Materializes the MPICH handle for the requested reduction operation as an
/// i32 LLVM constant.
///
/// \param loc      Location attached to the created constant.
/// \param rewriter Rewriter used to create the constant op.
/// \param opAttr   Dialect-level reduction operation to translate.
/// \returns The value of a new `llvm.mlir.constant` of type i32.
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
               mpi::MPI_OpClassEnum opAttr) override {
  // Translate the dialect enum into the matching MPICH integer handle.
  auto toMpichHandle = [](mpi::MPI_OpClassEnum kind) -> int32_t {
    switch (kind) {
    // NOTE(review): MPI_OP_NULL is mapped onto the MPI_NO_OP handle here;
    // confirm that this is the intended MPICH value.
    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
      return MPI_NO_OP;
    case mpi::MPI_OpClassEnum::MPI_MAX:
      return MPI_MAX;
    case mpi::MPI_OpClassEnum::MPI_MIN:
      return MPI_MIN;
    case mpi::MPI_OpClassEnum::MPI_SUM:
      return MPI_SUM;
    case mpi::MPI_OpClassEnum::MPI_PROD:
      return MPI_PROD;
    case mpi::MPI_OpClassEnum::MPI_LAND:
      return MPI_LAND;
    case mpi::MPI_OpClassEnum::MPI_BAND:
      return MPI_BAND;
    case mpi::MPI_OpClassEnum::MPI_LOR:
      return MPI_LOR;
    case mpi::MPI_OpClassEnum::MPI_BOR:
      return MPI_BOR;
    case mpi::MPI_OpClassEnum::MPI_LXOR:
      return MPI_LXOR;
    case mpi::MPI_OpClassEnum::MPI_BXOR:
      return MPI_BXOR;
    case mpi::MPI_OpClassEnum::MPI_MINLOC:
      return MPI_MINLOC;
    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
      return MPI_MAXLOC;
    case mpi::MPI_OpClassEnum::MPI_REPLACE:
      return MPI_REPLACE;
    }
    // Any value not listed above falls back to MPI_NO_OP.
    return MPI_NO_OP;
  };

  int32_t handle = toMpichHandle(opAttr);
  return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), handle);
}
139225
};
140226

141227
//===----------------------------------------------------------------------===//
@@ -214,6 +300,12 @@ class OMPIImplTraits : public MPIImplTraits {
214300
loc, LLVM::LLVMPointerType::get(context),
215301
SymbolRefAttr::get(context, mtype));
216302
}
303+
304+
/// OpenMPI counterpart of the MPI_Op lookup. Not implemented yet: calling it
/// always terminates the process with a diagnostic.
///
/// \param loc      Unused.
/// \param rewriter Unused.
/// \param opAttr   Unused.
/// \returns Never returns.
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
               mpi::MPI_OpClassEnum opAttr) override {
  // AllReduceOpLowering calls getMPIOp on whichever traits object
  // MPIImplTraits::get returns, so this path can be reached from user IR.
  // llvm_unreachable is only a hard failure in assert-enabled builds (it is
  // undefined behaviour otherwise), so report a fatal error that fires in
  // every build mode instead.
  llvm::report_fatal_error("getMPIOp not implemented for OpenMPI");
}
217309
};
218310

219311
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +457,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
365457
Location loc = op.getLoc();
366458
MLIRContext *context = rewriter.getContext();
367459
Type i32 = rewriter.getI32Type();
368-
Type i64 = rewriter.getI64Type();
369-
Value memRef = adaptor.getRef();
370460
Type elemType = op.getRef().getType().getElementType();
371461

372462
// ptrType `!llvm.ptr`
@@ -376,14 +466,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
376466
auto moduleOp = op->getParentOfType<ModuleOp>();
377467

378468
// get MPI_COMM_WORLD, dataType and pointer
379-
Value dataPtr =
380-
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
381-
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
382-
dataPtr =
383-
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
384-
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
385-
ArrayRef<int64_t>{3, 0});
386-
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
469+
auto [dataPtr, size] =
470+
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
387471
auto mpiTraits = MPIImplTraits::get(moduleOp);
388472
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
389473
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +509,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
425509
MLIRContext *context = rewriter.getContext();
426510
Type i32 = rewriter.getI32Type();
427511
Type i64 = rewriter.getI64Type();
428-
Value memRef = adaptor.getRef();
429512
Type elemType = op.getRef().getType().getElementType();
430513

431514
// ptrType `!llvm.ptr`
@@ -435,14 +518,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
435518
auto moduleOp = op->getParentOfType<ModuleOp>();
436519

437520
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
438-
Value dataPtr =
439-
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
440-
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
441-
dataPtr =
442-
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
443-
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
444-
ArrayRef<int64_t>{3, 0});
445-
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
521+
auto [dataPtr, size] =
522+
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
446523
auto mpiTraits = MPIImplTraits::get(moduleOp);
447524
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
448525
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +551,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
474551
}
475552
};
476553

554+
//===----------------------------------------------------------------------===//
555+
// AllReduceOpLowering
556+
//===----------------------------------------------------------------------===//
557+
558+
/// Rewrites `mpi.allreduce` into a call to the external C function
/// `MPI_Allreduce`, declaring that function in the enclosing module when
/// needed.
struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Builds the call arguments (buffers, count, datatype, reduction op and
  /// communicator) and replaces or erases `op` depending on whether it has a
  /// result. Always succeeds.
  LogicalResult
  matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto parentModule = op->getParentOfType<ModuleOp>();
    auto traits = MPIImplTraits::get(parentModule);

    // The element type of the send buffer is used for both buffers.
    Type elementTy = op.getSendbuf().getType().getElementType();

    // Raw pointers and element counts of both buffers. The receive-side
    // count is computed but not passed to the call; the send-side count is
    // used as the `count` argument.
    auto [sendData, sendCount] =
        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elementTy);
    auto [recvData, recvCount] =
        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elementTy);

    // Implementation-specific handles for datatype, reduction op and
    // communicator.
    Value mpiDataType = traits->getDataType(loc, rewriter, elementTy);
    Value reduceOp = traits->getMPIOp(loc, rewriter, op.getOp());
    Value comm = traits->getCommWorld(loc, rewriter);

    // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
    //                   MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
    Type int32Ty = rewriter.getI32Type();
    Type llvmPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
    auto allreduceTy = LLVM::LLVMFunctionType::get(
        int32Ty, {llvmPtrTy, llvmPtrTy, int32Ty, mpiDataType.getType(),
                  reduceOp.getType(), comm.getType()});

    // Look up the declaration of MPI_Allreduce, creating it if necessary.
    LLVM::LLVMFuncOp allreduceFn = getOrDefineFunction(
        parentModule, loc, rewriter, "MPI_Allreduce", allreduceTy);

    // Emit the call that takes the place of the op.
    auto call = rewriter.create<LLVM::CallOp>(
        loc, allreduceFn,
        ValueRange{sendData, recvData, sendCount, mpiDataType, reduceOp, comm});

    // Without a result there is nothing to forward, so just drop the op.
    if (!op.getRetval()) {
      rewriter.eraseOp(op);
      return success();
    }

    rewriter.replaceOp(op, call.getResult());
    return success();
  }
};
602+
477603
//===----------------------------------------------------------------------===//
478604
// ConvertToLLVMPatternInterface implementation
479605
//===----------------------------------------------------------------------===//
@@ -498,7 +624,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498624
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
499625
RewritePatternSet &patterns) {
500626
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
501-
SendOpLowering, RecvOpLowering>(converter);
627+
SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
502628
}
503629

504630
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

mlir/test/Conversion/MPIToLLVM/ops.mlir

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
22

33
// COM: Test MPICH ABI
4-
// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
4+
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
55
// CHECK: llvm.func @MPI_Finalize() -> i32
66
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
77
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
88
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
99
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
10-
module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
10+
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
1111

1212
// CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
1313
func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
@@ -72,8 +72,23 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
7272
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
7373
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
7474
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
75-
76-
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
75+
76+
// CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
77+
// CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
78+
// CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
79+
// CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
80+
// CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
81+
// CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
82+
// CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
83+
// CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
84+
// CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
85+
// CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
86+
// CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
87+
// CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
88+
// CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
89+
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
90+
91+
// CHECK: llvm.call @MPI_Finalize() : () -> i32
7792
%3 = mpi.finalize : !mpi.retval
7893

7994
return
@@ -83,15 +98,15 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
8398
// -----
8499

85100
// COM: Test OpenMPI ABI
86-
// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
101+
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
87102
// CHECK: llvm.func @MPI_Finalize() -> i32
88103
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
89104
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
90105
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
91106
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
92107
// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
93108
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
94-
module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
109+
module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
95110

96111
// CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
97112
func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {

mlir/test/Dialect/MPI/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
4949
%err7 = mpi.barrier : !mpi.retval
5050

5151
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
52-
mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
52+
mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
5353

5454
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
55-
%err8 = mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
55+
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
5656

5757
// CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
5858
%rval = mpi.finalize : !mpi.retval

0 commit comments

Comments
 (0)