@@ -211,25 +211,14 @@ static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
211
211
return decl;
212
212
}
213
213
214
- // / Returns an LLVM pointer type with the given element type, or an opaque
215
- // / pointer if 'useOpaquePointers' is true.
216
- static LLVM::LLVMPointerType getPointerType (Type elementType,
217
- bool useOpaquePointers) {
218
- if (useOpaquePointers)
219
- return LLVM::LLVMPointerType::get (elementType.getContext ());
220
- return LLVM::LLVMPointerType::get (elementType);
221
- }
222
-
223
214
// / Adds an atomic reduction combiner to the given OpenMP reduction declaration
224
215
// / using llvm.atomicrmw of the given kind.
225
216
static omp::ReductionDeclareOp addAtomicRMW (OpBuilder &builder,
226
217
LLVM::AtomicBinOp atomicKind,
227
218
omp::ReductionDeclareOp decl,
228
- scf::ReduceOp reduce,
229
- bool useOpaquePointers) {
219
+ scf::ReduceOp reduce) {
230
220
OpBuilder::InsertionGuard guard (builder);
231
- Type type = reduce.getOperand ().getType ();
232
- Type ptrType = getPointerType (type, useOpaquePointers);
221
+ auto ptrType = LLVM::LLVMPointerType::get (builder.getContext ());
233
222
Location reduceOperandLoc = reduce.getOperand ().getLoc ();
234
223
builder.createBlock (&decl.getAtomicReductionRegion (),
235
224
decl.getAtomicReductionRegion ().end (), {ptrType, ptrType},
@@ -250,8 +239,7 @@ static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
250
239
// / the neutral value, necessary for the OpenMP declaration. If the reduction
251
240
// / cannot be recognized, returns null.
252
241
static omp::ReductionDeclareOp declareReduction (PatternRewriter &builder,
253
- scf::ReduceOp reduce,
254
- bool useOpaquePointers) {
242
+ scf::ReduceOp reduce) {
255
243
Operation *container = SymbolTable::getNearestSymbolTable (reduce);
256
244
SymbolTable symbolTable (container);
257
245
@@ -272,34 +260,29 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
272
260
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
273
261
omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
274
262
builder.getFloatAttr (type, 0.0 ));
275
- return addAtomicRMW (builder, LLVM::AtomicBinOp::fadd, decl, reduce,
276
- useOpaquePointers);
263
+ return addAtomicRMW (builder, LLVM::AtomicBinOp::fadd, decl, reduce);
277
264
}
278
265
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
279
266
omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
280
267
builder.getIntegerAttr (type, 0 ));
281
- return addAtomicRMW (builder, LLVM::AtomicBinOp::add, decl, reduce,
282
- useOpaquePointers);
268
+ return addAtomicRMW (builder, LLVM::AtomicBinOp::add, decl, reduce);
283
269
}
284
270
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
285
271
omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
286
272
builder.getIntegerAttr (type, 0 ));
287
- return addAtomicRMW (builder, LLVM::AtomicBinOp::_or, decl, reduce,
288
- useOpaquePointers);
273
+ return addAtomicRMW (builder, LLVM::AtomicBinOp::_or, decl, reduce);
289
274
}
290
275
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
291
276
omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
292
277
builder.getIntegerAttr (type, 0 ));
293
- return addAtomicRMW (builder, LLVM::AtomicBinOp::_xor, decl, reduce,
294
- useOpaquePointers);
278
+ return addAtomicRMW (builder, LLVM::AtomicBinOp::_xor, decl, reduce);
295
279
}
296
280
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
297
281
omp::ReductionDeclareOp decl = createDecl (
298
282
builder, symbolTable, reduce,
299
283
builder.getIntegerAttr (
300
284
type, llvm::APInt::getAllOnes (type.getIntOrFloatBitWidth ())));
301
- return addAtomicRMW (builder, LLVM::AtomicBinOp::_and, decl, reduce,
302
- useOpaquePointers);
285
+ return addAtomicRMW (builder, LLVM::AtomicBinOp::_and, decl, reduce);
303
286
}
304
287
305
288
// Match simple binary reductions that cannot be expressed with atomicrmw.
@@ -335,7 +318,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
335
318
builder, symbolTable, reduce, minMaxValueForSignedInt (type, !isMin));
336
319
return addAtomicRMW (builder,
337
320
isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
338
- decl, reduce, useOpaquePointers );
321
+ decl, reduce);
339
322
}
340
323
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
341
324
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -347,7 +330,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
347
330
builder, symbolTable, reduce, minMaxValueForUnsignedInt (type, !isMin));
348
331
return addAtomicRMW (
349
332
builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
350
- decl, reduce, useOpaquePointers );
333
+ decl, reduce);
351
334
}
352
335
353
336
return nullptr ;
@@ -357,11 +340,8 @@ namespace {
357
340
358
341
struct ParallelOpLowering : public OpRewritePattern <scf::ParallelOp> {
359
342
360
- bool useOpaquePointers;
361
-
362
- ParallelOpLowering (MLIRContext *context, bool useOpaquePointers)
363
- : OpRewritePattern<scf::ParallelOp>(context),
364
- useOpaquePointers (useOpaquePointers) {}
343
+ ParallelOpLowering (MLIRContext *context)
344
+ : OpRewritePattern<scf::ParallelOp>(context) {}
365
345
366
346
LogicalResult matchAndRewrite (scf::ParallelOp parallelOp,
367
347
PatternRewriter &rewriter) const override {
@@ -370,8 +350,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
370
350
// declaration and use it instead of redeclaring.
371
351
SmallVector<Attribute> reductionDeclSymbols;
372
352
for (auto reduce : parallelOp.getOps <scf::ReduceOp>()) {
373
- omp::ReductionDeclareOp decl =
374
- declareReduction (rewriter, reduce, useOpaquePointers);
353
+ omp::ReductionDeclareOp decl = declareReduction (rewriter, reduce);
375
354
if (!decl)
376
355
return failure ();
377
356
reductionDeclSymbols.push_back (
@@ -385,14 +364,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
385
364
loc, rewriter.getIntegerType (64 ), rewriter.getI64IntegerAttr (1 ));
386
365
SmallVector<Value> reductionVariables;
387
366
reductionVariables.reserve (parallelOp.getNumReductions ());
367
+ auto ptrType = LLVM::LLVMPointerType::get (parallelOp.getContext ());
388
368
for (Value init : parallelOp.getInitVals ()) {
389
369
assert ((LLVM::isCompatibleType (init.getType ()) ||
390
370
isa<LLVM::PointerElementTypeInterface>(init.getType ())) &&
391
371
" cannot create a reduction variable if the type is not an LLVM "
392
372
" pointer element" );
393
- Value storage = rewriter.create <LLVM::AllocaOp>(
394
- loc, getPointerType (init.getType (), useOpaquePointers),
395
- init.getType (), one, 0 );
373
+ Value storage =
374
+ rewriter.create <LLVM::AllocaOp>(loc, ptrType, init.getType (), one, 0 );
396
375
rewriter.create <LLVM::StoreOp>(loc, init, storage);
397
376
reductionVariables.push_back (storage);
398
377
}
@@ -464,14 +443,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
464
443
};
465
444
466
445
// / Applies the conversion patterns in the given function.
467
- static LogicalResult applyPatterns (ModuleOp module , bool useOpaquePointers ) {
446
+ static LogicalResult applyPatterns (ModuleOp module ) {
468
447
ConversionTarget target (*module .getContext ());
469
448
target.addIllegalOp <scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
470
449
target.addLegalDialect <omp::OpenMPDialect, LLVM::LLVMDialect,
471
450
memref::MemRefDialect>();
472
451
473
452
RewritePatternSet patterns (module .getContext ());
474
- patterns.add <ParallelOpLowering>(module .getContext (), useOpaquePointers );
453
+ patterns.add <ParallelOpLowering>(module .getContext ());
475
454
FrozenRewritePatternSet frozen (std::move (patterns));
476
455
return applyPartialConversion (module , target, frozen);
477
456
}
@@ -484,7 +463,7 @@ struct SCFToOpenMPPass
484
463
485
464
// / Pass entry point.
486
465
void runOnOperation () override {
487
- if (failed (applyPatterns (getOperation (), useOpaquePointers )))
466
+ if (failed (applyPatterns (getOperation ())))
488
467
signalPassFailure ();
489
468
}
490
469
};
0 commit comments