Skip to content

Commit 1dee125

Browse files
authored
[mlir][mpi] Lowering MPI_Allreduce (llvm#133133)
Lowering of mpi.allreduce to an LLVM function call (MPI_Allreduce)
1 parent 0ec9498 commit 1dee125

File tree

5 files changed

+243
-37
lines changed

5 files changed

+243
-37
lines changed

mlir/include/mlir/Dialect/MPI/IR/MPI.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,7 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
246246
MPI_OpMaxloc,
247247
MPI_OpReplace
248248
]> {
249-
let genSpecializedAttr = 0;
250249
let cppNamespace = "::mlir::mpi";
251250
}
252251

253-
def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
254-
let assemblyFormat = "`<` $value `>`";
255-
}
256-
257252
#endif // MLIR_DIALECT_MPI_IR_MPI_TD

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
244244
let arguments = (
245245
ins AnyMemRef : $sendbuf,
246246
AnyMemRef : $recvbuf,
247-
MPI_OpClassAttr : $op
247+
MPI_OpClassEnum : $op
248248
);
249249

250250
let results = (outs Optional<MPI_Retval>:$retval);

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 201 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
4747
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
4848
}
4949

50+
std::pair<Value, Value> getRawPtrAndSize(const Location loc,
51+
ConversionPatternRewriter &rewriter,
52+
Value memRef, Type elType) {
53+
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
54+
Value dataPtr =
55+
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
56+
Value offset = rewriter.create<LLVM::ExtractValueOp>(
57+
loc, rewriter.getI64Type(), memRef, 2);
58+
Value resPtr =
59+
rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
60+
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
61+
ArrayRef<int64_t>{3, 0});
62+
size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
63+
return {resPtr, size};
64+
}
65+
5066
/// When lowering the mpi dialect to functions calls certain details
5167
/// differ between various MPI implementations. This class will provide
5268
/// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
7793
/// type.
7894
virtual Value getDataType(const Location loc,
7995
ConversionPatternRewriter &rewriter, Type type) = 0;
96+
97+
/// Gets or creates a Value holding the implementation-specific MPI_Op
98+
/// handle that corresponds to the reduction-operation enum value `opAttr`.
99+
virtual Value getMPIOp(const Location loc,
100+
ConversionPatternRewriter &rewriter,
101+
mpi::MPI_OpClassEnum opAttr) = 0;
80102
};
81103

82104
//===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
94116
static constexpr int MPI_UINT16_T = 0x4c00023c;
95117
static constexpr int MPI_UINT32_T = 0x4c00043d;
96118
static constexpr int MPI_UINT64_T = 0x4c00083e;
119+
static constexpr int MPI_MAX = 0x58000001;
120+
static constexpr int MPI_MIN = 0x58000002;
121+
static constexpr int MPI_SUM = 0x58000003;
122+
static constexpr int MPI_PROD = 0x58000004;
123+
static constexpr int MPI_LAND = 0x58000005;
124+
static constexpr int MPI_BAND = 0x58000006;
125+
static constexpr int MPI_LOR = 0x58000007;
126+
static constexpr int MPI_BOR = 0x58000008;
127+
static constexpr int MPI_LXOR = 0x58000009;
128+
static constexpr int MPI_BXOR = 0x5800000a;
129+
static constexpr int MPI_MINLOC = 0x5800000b;
130+
static constexpr int MPI_MAXLOC = 0x5800000c;
131+
static constexpr int MPI_REPLACE = 0x5800000d;
132+
static constexpr int MPI_NO_OP = 0x5800000e;
97133

98134
public:
99135
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
136172
assert(false && "unsupported type");
137173
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
138174
}
175+
176+
/// Materializes the MPICH handle for the reduction operation `opAttr` as an
/// i32 LLVM constant.
///
/// MPI_OP_NULL, and any enum value not listed in the switch, yields the
/// MPI_NO_OP handle.
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
               mpi::MPI_OpClassEnum opAttr) override {
  // Translate the dialect enum into the matching MPICH handle constant.
  const auto toHandle = [](mpi::MPI_OpClassEnum kind) -> int32_t {
    switch (kind) {
    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
      return MPI_NO_OP;
    case mpi::MPI_OpClassEnum::MPI_MAX:
      return MPI_MAX;
    case mpi::MPI_OpClassEnum::MPI_MIN:
      return MPI_MIN;
    case mpi::MPI_OpClassEnum::MPI_SUM:
      return MPI_SUM;
    case mpi::MPI_OpClassEnum::MPI_PROD:
      return MPI_PROD;
    case mpi::MPI_OpClassEnum::MPI_LAND:
      return MPI_LAND;
    case mpi::MPI_OpClassEnum::MPI_BAND:
      return MPI_BAND;
    case mpi::MPI_OpClassEnum::MPI_LOR:
      return MPI_LOR;
    case mpi::MPI_OpClassEnum::MPI_BOR:
      return MPI_BOR;
    case mpi::MPI_OpClassEnum::MPI_LXOR:
      return MPI_LXOR;
    case mpi::MPI_OpClassEnum::MPI_BXOR:
      return MPI_BXOR;
    case mpi::MPI_OpClassEnum::MPI_MINLOC:
      return MPI_MINLOC;
    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
      return MPI_MAXLOC;
    case mpi::MPI_OpClassEnum::MPI_REPLACE:
      return MPI_REPLACE;
    }
    // Fallback for values outside the enumerators handled above.
    return MPI_NO_OP;
  };

  int32_t handle = toHandle(opAttr);
  return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), handle);
}
139225
};
140226

141227
//===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {
205291

206292
auto context = rewriter.getContext();
207293
// get external opaque struct pointer type
208-
auto commStructT =
294+
auto typeStructT =
209295
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
210296
// make sure global op definition exists
211-
getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
297+
getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
212298
// get address of symbol
213299
return rewriter.create<LLVM::AddressOfOp>(
214300
loc, LLVM::LLVMPointerType::get(context),
215301
SymbolRefAttr::get(context, mtype));
216302
}
303+
304+
/// Returns the address of the Open MPI global object that backs the
/// predefined reduction operation `opAttr`.
///
/// Open MPI exposes its predefined operations as external globals of type
/// `ompi_predefined_op_t` named `ompi_mpi_op_<name>` (e.g. MPI_SUM is
/// `ompi_mpi_op_sum`). The symbol names below therefore carry the `op_`
/// infix; without it the emitted references cannot be resolved against the
/// Open MPI library. MPI_OP_NULL is mapped to the no-op operation
/// (MPI_NO_OP).
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
               mpi::MPI_OpClassEnum opAttr) override {
  StringRef op;
  switch (opAttr) {
  case mpi::MPI_OpClassEnum::MPI_OP_NULL:
    op = "ompi_mpi_op_no_op";
    break;
  case mpi::MPI_OpClassEnum::MPI_MAX:
    op = "ompi_mpi_op_max";
    break;
  case mpi::MPI_OpClassEnum::MPI_MIN:
    op = "ompi_mpi_op_min";
    break;
  case mpi::MPI_OpClassEnum::MPI_SUM:
    op = "ompi_mpi_op_sum";
    break;
  case mpi::MPI_OpClassEnum::MPI_PROD:
    op = "ompi_mpi_op_prod";
    break;
  case mpi::MPI_OpClassEnum::MPI_LAND:
    op = "ompi_mpi_op_land";
    break;
  case mpi::MPI_OpClassEnum::MPI_BAND:
    op = "ompi_mpi_op_band";
    break;
  case mpi::MPI_OpClassEnum::MPI_LOR:
    op = "ompi_mpi_op_lor";
    break;
  case mpi::MPI_OpClassEnum::MPI_BOR:
    op = "ompi_mpi_op_bor";
    break;
  case mpi::MPI_OpClassEnum::MPI_LXOR:
    op = "ompi_mpi_op_lxor";
    break;
  case mpi::MPI_OpClassEnum::MPI_BXOR:
    op = "ompi_mpi_op_bxor";
    break;
  case mpi::MPI_OpClassEnum::MPI_MINLOC:
    op = "ompi_mpi_op_minloc";
    break;
  case mpi::MPI_OpClassEnum::MPI_MAXLOC:
    op = "ompi_mpi_op_maxloc";
    break;
  case mpi::MPI_OpClassEnum::MPI_REPLACE:
    op = "ompi_mpi_op_replace";
    break;
  }
  auto context = rewriter.getContext();
  // Opaque struct type of Open MPI's predefined-op globals.
  auto opStructT =
      LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
  // Make sure the external global for this op is declared.
  getOrDefineExternalStruct(loc, rewriter, op, opStructT);
  // The MPI_Op handle is the address of that global.
  return rewriter.create<LLVM::AddressOfOp>(
      loc, LLVM::LLVMPointerType::get(context),
      SymbolRefAttr::get(context, op));
}
217362
};
218363

219364
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
365510
Location loc = op.getLoc();
366511
MLIRContext *context = rewriter.getContext();
367512
Type i32 = rewriter.getI32Type();
368-
Type i64 = rewriter.getI64Type();
369-
Value memRef = adaptor.getRef();
370513
Type elemType = op.getRef().getType().getElementType();
371514

372515
// ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
376519
auto moduleOp = op->getParentOfType<ModuleOp>();
377520

378521
// get MPI_COMM_WORLD, dataType and pointer
379-
Value dataPtr =
380-
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
381-
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
382-
dataPtr =
383-
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
384-
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
385-
ArrayRef<int64_t>{3, 0});
386-
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
522+
auto [dataPtr, size] =
523+
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
387524
auto mpiTraits = MPIImplTraits::get(moduleOp);
388525
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
389526
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
425562
MLIRContext *context = rewriter.getContext();
426563
Type i32 = rewriter.getI32Type();
427564
Type i64 = rewriter.getI64Type();
428-
Value memRef = adaptor.getRef();
429565
Type elemType = op.getRef().getType().getElementType();
430566

431567
// ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
435571
auto moduleOp = op->getParentOfType<ModuleOp>();
436572

437573
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
438-
Value dataPtr =
439-
rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
440-
Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
441-
dataPtr =
442-
rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
443-
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
444-
ArrayRef<int64_t>{3, 0});
445-
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
574+
auto [dataPtr, size] =
575+
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
446576
auto mpiTraits = MPIImplTraits::get(moduleOp);
447577
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
448578
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +604,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
474604
}
475605
};
476606

607+
//===----------------------------------------------------------------------===//
608+
// AllReduceOpLowering
609+
//===----------------------------------------------------------------------===//
610+
611+
/// Lowers `mpi.allreduce` to a call of the C function `MPI_Allreduce`, using
/// the MPI implementation traits to obtain the datatype, reduction-op and
/// communicator handles.
struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    // ptrType `!llvm.ptr`
    Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    // The element type of the send buffer is used for both buffers.
    Type elemType = op.getSendbuf().getType().getElementType();

    auto moduleOp = op->getParentOfType<ModuleOp>();
    auto traits = MPIImplTraits::get(moduleOp);

    // Raw pointers into both buffers; the element count passed to MPI is the
    // one taken from the send buffer.
    auto [srcPtr, count] =
        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
    Value dstPtr =
        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType).first;

    // Implementation-specific handles.
    Value mpiType = traits->getDataType(loc, rewriter, elemType);
    Value reduceOp = traits->getMPIOp(loc, rewriter, op.getOp());
    Value comm = traits->getCommWorld(loc, rewriter);

    // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
    //                    MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
    auto calleeType = LLVM::LLVMFunctionType::get(
        i32, {ptrType, ptrType, i32, mpiType.getType(), reduceOp.getType(),
              comm.getType()});
    // Get or create the function declaration.
    LLVM::LLVMFuncOp callee = getOrDefineFunction(moduleOp, loc, rewriter,
                                                  "MPI_Allreduce", calleeType);

    // Emit the call that stands in for the op.
    auto call = rewriter.create<LLVM::CallOp>(
        loc, callee,
        ValueRange{srcPtr, dstPtr, count, mpiType, reduceOp, comm});

    // Without a result there is nothing to forward: just drop the op.
    if (!op.getRetval()) {
      rewriter.eraseOp(op);
      return success();
    }
    rewriter.replaceOp(op, call.getResult());
    return success();
  }
};
655+
477656
//===----------------------------------------------------------------------===//
478657
// ConvertToLLVMPatternInterface implementation
479658
//===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498677
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
499678
RewritePatternSet &patterns) {
500679
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
501-
SendOpLowering, RecvOpLowering>(converter);
680+
SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
502681
}
503682

504683
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

0 commit comments

Comments
 (0)