@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
47
47
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
48
48
}
49
49
50
+ std::pair<Value, Value> getRawPtrAndSize (const Location loc,
51
+ ConversionPatternRewriter &rewriter,
52
+ Value memRef, Type elType) {
53
+ Type ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
54
+ Value dataPtr =
55
+ rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
56
+ Value offset = rewriter.create <LLVM::ExtractValueOp>(
57
+ loc, rewriter.getI64Type (), memRef, 2 );
58
+ Value resPtr =
59
+ rewriter.create <LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
60
+ Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
61
+ ArrayRef<int64_t >{3 , 0 });
62
+ size = rewriter.create <LLVM::TruncOp>(loc, rewriter.getI32Type (), size);
63
+ return {resPtr, size};
64
+ }
65
+
50
66
// / When lowering the mpi dialect to function calls, certain details
51
67
// / differ between various MPI implementations. This class will provide
52
68
// / these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
77
93
// / type.
78
94
virtual Value getDataType (const Location loc,
79
95
ConversionPatternRewriter &rewriter, Type type) = 0;
96
+
97
+ // / Gets or creates an MPI_Op value which corresponds to the given
98
+ // / enum value.
99
+ virtual Value getMPIOp (const Location loc,
100
+ ConversionPatternRewriter &rewriter,
101
+ mpi::MPI_OpClassEnum opAttr) = 0;
80
102
};
81
103
82
104
// ===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
94
116
static constexpr int MPI_UINT16_T = 0x4c00023c ;
95
117
static constexpr int MPI_UINT32_T = 0x4c00043d ;
96
118
static constexpr int MPI_UINT64_T = 0x4c00083e ;
119
+ static constexpr int MPI_MAX = 0x58000001 ;
120
+ static constexpr int MPI_MIN = 0x58000002 ;
121
+ static constexpr int MPI_SUM = 0x58000003 ;
122
+ static constexpr int MPI_PROD = 0x58000004 ;
123
+ static constexpr int MPI_LAND = 0x58000005 ;
124
+ static constexpr int MPI_BAND = 0x58000006 ;
125
+ static constexpr int MPI_LOR = 0x58000007 ;
126
+ static constexpr int MPI_BOR = 0x58000008 ;
127
+ static constexpr int MPI_LXOR = 0x58000009 ;
128
+ static constexpr int MPI_BXOR = 0x5800000a ;
129
+ static constexpr int MPI_MINLOC = 0x5800000b ;
130
+ static constexpr int MPI_MAXLOC = 0x5800000c ;
131
+ static constexpr int MPI_REPLACE = 0x5800000d ;
132
+ static constexpr int MPI_NO_OP = 0x5800000e ;
97
133
98
134
public:
99
135
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
136
172
assert (false && " unsupported type" );
137
173
return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (), mtype);
138
174
}
175
+
176
+ Value getMPIOp (const Location loc, ConversionPatternRewriter &rewriter,
177
+ mpi::MPI_OpClassEnum opAttr) override {
178
+ int32_t op = MPI_NO_OP;
179
+ switch (opAttr) {
180
+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
181
+ op = MPI_NO_OP;
182
+ break ;
183
+ case mpi::MPI_OpClassEnum::MPI_MAX:
184
+ op = MPI_MAX;
185
+ break ;
186
+ case mpi::MPI_OpClassEnum::MPI_MIN:
187
+ op = MPI_MIN;
188
+ break ;
189
+ case mpi::MPI_OpClassEnum::MPI_SUM:
190
+ op = MPI_SUM;
191
+ break ;
192
+ case mpi::MPI_OpClassEnum::MPI_PROD:
193
+ op = MPI_PROD;
194
+ break ;
195
+ case mpi::MPI_OpClassEnum::MPI_LAND:
196
+ op = MPI_LAND;
197
+ break ;
198
+ case mpi::MPI_OpClassEnum::MPI_BAND:
199
+ op = MPI_BAND;
200
+ break ;
201
+ case mpi::MPI_OpClassEnum::MPI_LOR:
202
+ op = MPI_LOR;
203
+ break ;
204
+ case mpi::MPI_OpClassEnum::MPI_BOR:
205
+ op = MPI_BOR;
206
+ break ;
207
+ case mpi::MPI_OpClassEnum::MPI_LXOR:
208
+ op = MPI_LXOR;
209
+ break ;
210
+ case mpi::MPI_OpClassEnum::MPI_BXOR:
211
+ op = MPI_BXOR;
212
+ break ;
213
+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
214
+ op = MPI_MINLOC;
215
+ break ;
216
+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
217
+ op = MPI_MAXLOC;
218
+ break ;
219
+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
220
+ op = MPI_REPLACE;
221
+ break ;
222
+ }
223
+ return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (), op);
224
+ }
139
225
};
140
226
141
227
// ===----------------------------------------------------------------------===//
@@ -214,6 +300,12 @@ class OMPIImplTraits : public MPIImplTraits {
214
300
loc, LLVM::LLVMPointerType::get (context),
215
301
SymbolRefAttr::get (context, mtype));
216
302
}
303
+
304
+ Value getMPIOp (const Location loc, ConversionPatternRewriter &rewriter,
305
+ mpi::MPI_OpClassEnum opAttr) override {
306
+ llvm_unreachable (" getMPIOp not implemented for OpenMPI" );
307
+ return Value ();
308
+ }
217
309
};
218
310
219
311
std::unique_ptr<MPIImplTraits> MPIImplTraits::get (ModuleOp &moduleOp) {
@@ -365,8 +457,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
365
457
Location loc = op.getLoc ();
366
458
MLIRContext *context = rewriter.getContext ();
367
459
Type i32 = rewriter.getI32Type ();
368
- Type i64 = rewriter.getI64Type ();
369
- Value memRef = adaptor.getRef ();
370
460
Type elemType = op.getRef ().getType ().getElementType ();
371
461
372
462
// ptrType `!llvm.ptr`
@@ -376,14 +466,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
376
466
auto moduleOp = op->getParentOfType <ModuleOp>();
377
467
378
468
// get MPI_COMM_WORLD, dataType and pointer
379
- Value dataPtr =
380
- rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
381
- Value offset = rewriter.create <LLVM::ExtractValueOp>(loc, i64 , memRef, 2 );
382
- dataPtr =
383
- rewriter.create <LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
384
- Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
385
- ArrayRef<int64_t >{3 , 0 });
386
- size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
469
+ auto [dataPtr, size] =
470
+ getRawPtrAndSize (loc, rewriter, adaptor.getRef (), elemType);
387
471
auto mpiTraits = MPIImplTraits::get (moduleOp);
388
472
Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
389
473
Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
@@ -425,7 +509,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
425
509
MLIRContext *context = rewriter.getContext ();
426
510
Type i32 = rewriter.getI32Type ();
427
511
Type i64 = rewriter.getI64Type ();
428
- Value memRef = adaptor.getRef ();
429
512
Type elemType = op.getRef ().getType ().getElementType ();
430
513
431
514
// ptrType `!llvm.ptr`
@@ -435,14 +518,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
435
518
auto moduleOp = op->getParentOfType <ModuleOp>();
436
519
437
520
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
438
- Value dataPtr =
439
- rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
440
- Value offset = rewriter.create <LLVM::ExtractValueOp>(loc, i64 , memRef, 2 );
441
- dataPtr =
442
- rewriter.create <LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
443
- Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
444
- ArrayRef<int64_t >{3 , 0 });
445
- size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
521
+ auto [dataPtr, size] =
522
+ getRawPtrAndSize (loc, rewriter, adaptor.getRef (), elemType);
446
523
auto mpiTraits = MPIImplTraits::get (moduleOp);
447
524
Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
448
525
Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
@@ -474,6 +551,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
474
551
}
475
552
};
476
553
554
+ // ===----------------------------------------------------------------------===//
555
+ // AllReduceOpLowering
556
+ // ===----------------------------------------------------------------------===//
557
+
558
+ struct AllReduceOpLowering : public ConvertOpToLLVMPattern <mpi::AllReduceOp> {
559
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
560
+
561
+ LogicalResult
562
+ matchAndRewrite (mpi::AllReduceOp op, OpAdaptor adaptor,
563
+ ConversionPatternRewriter &rewriter) const override {
564
+ Location loc = op.getLoc ();
565
+ MLIRContext *context = rewriter.getContext ();
566
+ Type i32 = rewriter.getI32Type ();
567
+ Type elemType = op.getSendbuf ().getType ().getElementType ();
568
+
569
+ // ptrType `!llvm.ptr`
570
+ Type ptrType = LLVM::LLVMPointerType::get (context);
571
+ auto moduleOp = op->getParentOfType <ModuleOp>();
572
+ auto mpiTraits = MPIImplTraits::get (moduleOp);
573
+ auto [sendPtr, sendSize] =
574
+ getRawPtrAndSize (loc, rewriter, adaptor.getSendbuf (), elemType);
575
+ auto [recvPtr, recvSize] =
576
+ getRawPtrAndSize (loc, rewriter, adaptor.getRecvbuf (), elemType);
577
+ Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
578
+ Value mpiOp = mpiTraits->getMPIOp (loc, rewriter, op.getOp ());
579
+ Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
580
+ // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
581
+ // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
582
+ auto funcType = LLVM::LLVMFunctionType::get (
583
+ i32 , {ptrType, ptrType, i32 , dataType.getType (), mpiOp.getType (),
584
+ commWorld.getType ()});
585
+ // get or create function declaration:
586
+ LLVM::LLVMFuncOp funcDecl =
587
+ getOrDefineFunction (moduleOp, loc, rewriter, " MPI_Allreduce" , funcType);
588
+
589
+ // replace op with function call
590
+ auto funcCall = rewriter.create <LLVM::CallOp>(
591
+ loc, funcDecl,
592
+ ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
593
+
594
+ if (op.getRetval ())
595
+ rewriter.replaceOp (op, funcCall.getResult ());
596
+ else
597
+ rewriter.eraseOp (op);
598
+
599
+ return success ();
600
+ }
601
+ };
602
+
477
603
// ===----------------------------------------------------------------------===//
478
604
// ConvertToLLVMPatternInterface implementation
479
605
// ===----------------------------------------------------------------------===//
@@ -498,7 +624,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498
624
void mpi::populateMPIToLLVMConversionPatterns (LLVMTypeConverter &converter,
499
625
RewritePatternSet &patterns) {
500
626
patterns.add <CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
501
- SendOpLowering, RecvOpLowering>(converter);
627
+ SendOpLowering, RecvOpLowering, AllReduceOpLowering >(converter);
502
628
}
503
629
504
630
void mpi::registerConvertMPIToLLVMInterface (DialectRegistry ®istry) {
0 commit comments