@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
47
47
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
48
48
}
49
49
50
+ std::pair<Value, Value> getRawPtrAndSize (const Location loc,
51
+ ConversionPatternRewriter &rewriter,
52
+ Value memRef, Type elType) {
53
+ Type ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
54
+ Value dataPtr =
55
+ rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
56
+ Value offset = rewriter.create <LLVM::ExtractValueOp>(
57
+ loc, rewriter.getI64Type (), memRef, 2 );
58
+ Value resPtr =
59
+ rewriter.create <LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
60
+ Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
61
+ ArrayRef<int64_t >{3 , 0 });
62
+ size = rewriter.create <LLVM::TruncOp>(loc, rewriter.getI32Type (), size);
63
+ return {resPtr, size};
64
+ }
65
+
50
66
// / When lowering the mpi dialect to function calls, certain details
51
67
// / differ between various MPI implementations. This class will provide
52
68
// / these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
77
93
// / type.
78
94
virtual Value getDataType (const Location loc,
79
95
ConversionPatternRewriter &rewriter, Type type) = 0;
96
+
97
+ // / Gets or creates an MPI_Op value which corresponds to the given
98
+ // / enum value.
99
+ virtual Value getMPIOp (const Location loc,
100
+ ConversionPatternRewriter &rewriter,
101
+ mpi::MPI_OpClassEnum opAttr) = 0;
80
102
};
81
103
82
104
// ===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
94
116
static constexpr int MPI_UINT16_T = 0x4c00023c ;
95
117
static constexpr int MPI_UINT32_T = 0x4c00043d ;
96
118
static constexpr int MPI_UINT64_T = 0x4c00083e ;
119
+ static constexpr int MPI_MAX = 0x58000001 ;
120
+ static constexpr int MPI_MIN = 0x58000002 ;
121
+ static constexpr int MPI_SUM = 0x58000003 ;
122
+ static constexpr int MPI_PROD = 0x58000004 ;
123
+ static constexpr int MPI_LAND = 0x58000005 ;
124
+ static constexpr int MPI_BAND = 0x58000006 ;
125
+ static constexpr int MPI_LOR = 0x58000007 ;
126
+ static constexpr int MPI_BOR = 0x58000008 ;
127
+ static constexpr int MPI_LXOR = 0x58000009 ;
128
+ static constexpr int MPI_BXOR = 0x5800000a ;
129
+ static constexpr int MPI_MINLOC = 0x5800000b ;
130
+ static constexpr int MPI_MAXLOC = 0x5800000c ;
131
+ static constexpr int MPI_REPLACE = 0x5800000d ;
132
+ static constexpr int MPI_NO_OP = 0x5800000e ;
97
133
98
134
public:
99
135
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
136
172
assert (false && " unsupported type" );
137
173
return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (), mtype);
138
174
}
175
+
176
+ Value getMPIOp (const Location loc, ConversionPatternRewriter &rewriter,
177
+ mpi::MPI_OpClassEnum opAttr) override {
178
+ int32_t op = MPI_NO_OP;
179
+ switch (opAttr) {
180
+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
181
+ op = MPI_NO_OP;
182
+ break ;
183
+ case mpi::MPI_OpClassEnum::MPI_MAX:
184
+ op = MPI_MAX;
185
+ break ;
186
+ case mpi::MPI_OpClassEnum::MPI_MIN:
187
+ op = MPI_MIN;
188
+ break ;
189
+ case mpi::MPI_OpClassEnum::MPI_SUM:
190
+ op = MPI_SUM;
191
+ break ;
192
+ case mpi::MPI_OpClassEnum::MPI_PROD:
193
+ op = MPI_PROD;
194
+ break ;
195
+ case mpi::MPI_OpClassEnum::MPI_LAND:
196
+ op = MPI_LAND;
197
+ break ;
198
+ case mpi::MPI_OpClassEnum::MPI_BAND:
199
+ op = MPI_BAND;
200
+ break ;
201
+ case mpi::MPI_OpClassEnum::MPI_LOR:
202
+ op = MPI_LOR;
203
+ break ;
204
+ case mpi::MPI_OpClassEnum::MPI_BOR:
205
+ op = MPI_BOR;
206
+ break ;
207
+ case mpi::MPI_OpClassEnum::MPI_LXOR:
208
+ op = MPI_LXOR;
209
+ break ;
210
+ case mpi::MPI_OpClassEnum::MPI_BXOR:
211
+ op = MPI_BXOR;
212
+ break ;
213
+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
214
+ op = MPI_MINLOC;
215
+ break ;
216
+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
217
+ op = MPI_MAXLOC;
218
+ break ;
219
+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
220
+ op = MPI_REPLACE;
221
+ break ;
222
+ }
223
+ return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (), op);
224
+ }
139
225
};
140
226
141
227
// ===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {
205
291
206
292
auto context = rewriter.getContext ();
207
293
// get external opaque struct pointer type
208
- auto commStructT =
294
+ auto typeStructT =
209
295
LLVM::LLVMStructType::getOpaque (" ompi_predefined_datatype_t" , context);
210
296
// make sure global op definition exists
211
- getOrDefineExternalStruct (loc, rewriter, mtype, commStructT );
297
+ getOrDefineExternalStruct (loc, rewriter, mtype, typeStructT );
212
298
// get address of symbol
213
299
return rewriter.create <LLVM::AddressOfOp>(
214
300
loc, LLVM::LLVMPointerType::get (context),
215
301
SymbolRefAttr::get (context, mtype));
216
302
}
303
+
304
+ Value getMPIOp (const Location loc, ConversionPatternRewriter &rewriter,
305
+ mpi::MPI_OpClassEnum opAttr) override {
306
+ StringRef op;
307
+ switch (opAttr) {
308
+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
309
+ op = " ompi_mpi_no_op" ;
310
+ break ;
311
+ case mpi::MPI_OpClassEnum::MPI_MAX:
312
+ op = " ompi_mpi_max" ;
313
+ break ;
314
+ case mpi::MPI_OpClassEnum::MPI_MIN:
315
+ op = " ompi_mpi_min" ;
316
+ break ;
317
+ case mpi::MPI_OpClassEnum::MPI_SUM:
318
+ op = " ompi_mpi_sum" ;
319
+ break ;
320
+ case mpi::MPI_OpClassEnum::MPI_PROD:
321
+ op = " ompi_mpi_prod" ;
322
+ break ;
323
+ case mpi::MPI_OpClassEnum::MPI_LAND:
324
+ op = " ompi_mpi_land" ;
325
+ break ;
326
+ case mpi::MPI_OpClassEnum::MPI_BAND:
327
+ op = " ompi_mpi_band" ;
328
+ break ;
329
+ case mpi::MPI_OpClassEnum::MPI_LOR:
330
+ op = " ompi_mpi_lor" ;
331
+ break ;
332
+ case mpi::MPI_OpClassEnum::MPI_BOR:
333
+ op = " ompi_mpi_bor" ;
334
+ break ;
335
+ case mpi::MPI_OpClassEnum::MPI_LXOR:
336
+ op = " ompi_mpi_lxor" ;
337
+ break ;
338
+ case mpi::MPI_OpClassEnum::MPI_BXOR:
339
+ op = " ompi_mpi_bxor" ;
340
+ break ;
341
+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
342
+ op = " ompi_mpi_minloc" ;
343
+ break ;
344
+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
345
+ op = " ompi_mpi_maxloc" ;
346
+ break ;
347
+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
348
+ op = " ompi_mpi_replace" ;
349
+ break ;
350
+ }
351
+ auto context = rewriter.getContext ();
352
+ // get external opaque struct pointer type
353
+ auto opStructT =
354
+ LLVM::LLVMStructType::getOpaque (" ompi_predefined_op_t" , context);
355
+ // make sure global op definition exists
356
+ getOrDefineExternalStruct (loc, rewriter, op, opStructT);
357
+ // get address of symbol
358
+ return rewriter.create <LLVM::AddressOfOp>(
359
+ loc, LLVM::LLVMPointerType::get (context),
360
+ SymbolRefAttr::get (context, op));
361
+ }
217
362
};
218
363
219
364
std::unique_ptr<MPIImplTraits> MPIImplTraits::get (ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
365
510
Location loc = op.getLoc ();
366
511
MLIRContext *context = rewriter.getContext ();
367
512
Type i32 = rewriter.getI32Type ();
368
- Type i64 = rewriter.getI64Type ();
369
- Value memRef = adaptor.getRef ();
370
513
Type elemType = op.getRef ().getType ().getElementType ();
371
514
372
515
// ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
376
519
auto moduleOp = op->getParentOfType <ModuleOp>();
377
520
378
521
// get MPI_COMM_WORLD, dataType and pointer
379
- Value dataPtr =
380
- rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
381
- Value offset = rewriter.create <LLVM::ExtractValueOp>(loc, i64 , memRef, 2 );
382
- dataPtr =
383
- rewriter.create <LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
384
- Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
385
- ArrayRef<int64_t >{3 , 0 });
386
- size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
522
+ auto [dataPtr, size] =
523
+ getRawPtrAndSize (loc, rewriter, adaptor.getRef (), elemType);
387
524
auto mpiTraits = MPIImplTraits::get (moduleOp);
388
525
Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
389
526
Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
425
562
MLIRContext *context = rewriter.getContext ();
426
563
Type i32 = rewriter.getI32Type ();
427
564
Type i64 = rewriter.getI64Type ();
428
- Value memRef = adaptor.getRef ();
429
565
Type elemType = op.getRef ().getType ().getElementType ();
430
566
431
567
// ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
435
571
auto moduleOp = op->getParentOfType <ModuleOp>();
436
572
437
573
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
438
- Value dataPtr =
439
- rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
440
- Value offset = rewriter.create <LLVM::ExtractValueOp>(loc, i64 , memRef, 2 );
441
- dataPtr =
442
- rewriter.create <LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
443
- Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
444
- ArrayRef<int64_t >{3 , 0 });
445
- size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
574
+ auto [dataPtr, size] =
575
+ getRawPtrAndSize (loc, rewriter, adaptor.getRef (), elemType);
446
576
auto mpiTraits = MPIImplTraits::get (moduleOp);
447
577
Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
448
578
Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
@@ -474,6 +604,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
474
604
}
475
605
};
476
606
607
+ // ===----------------------------------------------------------------------===//
608
+ // AllReduceOpLowering
609
+ // ===----------------------------------------------------------------------===//
610
+
611
+ struct AllReduceOpLowering : public ConvertOpToLLVMPattern <mpi::AllReduceOp> {
612
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
613
+
614
+ LogicalResult
615
+ matchAndRewrite (mpi::AllReduceOp op, OpAdaptor adaptor,
616
+ ConversionPatternRewriter &rewriter) const override {
617
+ Location loc = op.getLoc ();
618
+ MLIRContext *context = rewriter.getContext ();
619
+ Type i32 = rewriter.getI32Type ();
620
+ Type elemType = op.getSendbuf ().getType ().getElementType ();
621
+
622
+ // ptrType `!llvm.ptr`
623
+ Type ptrType = LLVM::LLVMPointerType::get (context);
624
+ auto moduleOp = op->getParentOfType <ModuleOp>();
625
+ auto mpiTraits = MPIImplTraits::get (moduleOp);
626
+ auto [sendPtr, sendSize] =
627
+ getRawPtrAndSize (loc, rewriter, adaptor.getSendbuf (), elemType);
628
+ auto [recvPtr, recvSize] =
629
+ getRawPtrAndSize (loc, rewriter, adaptor.getRecvbuf (), elemType);
630
+ Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
631
+ Value mpiOp = mpiTraits->getMPIOp (loc, rewriter, op.getOp ());
632
+ Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
633
+ // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
634
+ // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
635
+ auto funcType = LLVM::LLVMFunctionType::get (
636
+ i32 , {ptrType, ptrType, i32 , dataType.getType (), mpiOp.getType (),
637
+ commWorld.getType ()});
638
+ // get or create function declaration:
639
+ LLVM::LLVMFuncOp funcDecl =
640
+ getOrDefineFunction (moduleOp, loc, rewriter, " MPI_Allreduce" , funcType);
641
+
642
+ // replace op with function call
643
+ auto funcCall = rewriter.create <LLVM::CallOp>(
644
+ loc, funcDecl,
645
+ ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
646
+
647
+ if (op.getRetval ())
648
+ rewriter.replaceOp (op, funcCall.getResult ());
649
+ else
650
+ rewriter.eraseOp (op);
651
+
652
+ return success ();
653
+ }
654
+ };
655
+
477
656
// ===----------------------------------------------------------------------===//
478
657
// ConvertToLLVMPatternInterface implementation
479
658
// ===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498
677
void mpi::populateMPIToLLVMConversionPatterns (LLVMTypeConverter &converter,
499
678
RewritePatternSet &patterns) {
500
679
patterns.add <CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
501
- SendOpLowering, RecvOpLowering>(converter);
680
+ SendOpLowering, RecvOpLowering, AllReduceOpLowering >(converter);
502
681
}
503
682
504
683
void mpi::registerConvertMPIToLLVMInterface (DialectRegistry ®istry) {
0 commit comments