Skip to content

Commit 465cdc8

Browse files
[MLIR][SCF][OpenMP] Update reduction conversion
1 parent 78d3a7f commit 465cdc8

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
367367
// TODO: consider checking it here is already a compatible reduction
368368
// declaration and use it instead of redeclaring.
369369
SmallVector<Attribute> reductionDeclSymbols;
370+
SmallVector<omp::ReductionDeclareOp> ompReductionDecls;
370371
auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
371372
for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
372373
omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce, i);
374+
ompReductionDecls.push_back(decl);
373375
if (!decl)
374376
return failure();
375377
reductionDeclSymbols.push_back(
@@ -398,11 +400,39 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
398400
// Replace the reduction operations contained in this loop. Must be done
399401
// here rather than in a separate pattern to have access to the list of
400402
// reduction variables.
403+
unsigned int reductionIndex = 0;
401404
for (auto [x, y] :
402405
llvm::zip_equal(reductionVariables, reduce.getOperands())) {
403406
OpBuilder::InsertionGuard guard(rewriter);
404407
rewriter.setInsertionPoint(reduce);
405-
rewriter.create<omp::ReductionOp>(reduce.getLoc(), y, x);
408+
Region &redRegion =
409+
ompReductionDecls[reductionIndex].getReductionRegion();
410+
assert(redRegion.hasOneBlock() &&
411+
"expect reduction region to have one block");
412+
Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
413+
Value pvtRedVal = rewriter.create<LLVM::LoadOp>(
414+
reduce.getLoc(), ompReductionDecls[reductionIndex].getType(),
415+
pvtRedVar);
416+
// Make a copy of the reduction combiner region in the body
417+
mlir::OpBuilder builder(rewriter.getContext());
418+
builder.setInsertionPoint(reduce);
419+
mlir::IRMapping mapper;
420+
assert(redRegion.getNumArguments() == 2 &&
421+
"expect reduction region to have two arguments");
422+
mapper.map(redRegion.getArgument(0), pvtRedVal);
423+
mapper.map(redRegion.getArgument(1), y);
424+
for (auto &op : redRegion.getOps()) {
425+
Operation *cloneOp = builder.clone(op, mapper);
426+
if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
427+
assert(yieldOp && yieldOp.getResults().size() == 1 &&
428+
"expect YieldOp in reduction region to return one result");
429+
Value redVal = yieldOp.getResults()[0];
430+
rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
431+
rewriter.eraseOp(yieldOp);
432+
break;
433+
}
434+
}
435+
reductionIndex++;
406436
}
407437
rewriter.eraseOp(reduce);
408438

mlir/test/Conversion/SCFToOpenMP/reductions.mlir

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
2727
%zero = arith.constant 0.0 : f32
2828
// CHECK: omp.parallel
2929
// CHECK: omp.wsloop
30-
// CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]]
30+
// CHECK-SAME: reduction(@[[$REDF]] %[[BUF]] -> %[[PVT_BUF:[a-z0-9]+]]
3131
// CHECK: memref.alloca_scope
3232
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
3333
step (%arg4, %step) init (%zero) -> (f32) {
3434
// CHECK: %[[CST_INNER:.*]] = arith.constant 1.0
3535
%one = arith.constant 1.0 : f32
36-
// CHECK: omp.reduction %[[CST_INNER]], %[[BUF]]
36+
// CHECK: %[[PVT_VAL:.*]] = llvm.load %[[PVT_BUF]] : !llvm.ptr -> f32
37+
// CHECK: %[[ADD_RESULT:.*]] = arith.addf %[[PVT_VAL]], %[[CST_INNER]] : f32
38+
// CHECK: llvm.store %[[ADD_RESULT]], %[[PVT_BUF]] : f32, !llvm.ptr
3739
scf.reduce(%one : f32) {
3840
^bb0(%lhs : f32, %rhs: f32):
3941
%res = arith.addf %lhs, %rhs : f32
@@ -103,10 +105,15 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
103105
%arg3 : index, %arg4 : index) {
104106
%step = arith.constant 1 : index
105107
%one = arith.constant 1 : i32
108+
// CHECK: %[[RED_VAR:.*]] = llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
109+
// CHECK: omp.wsloop reduction(@[[$REDI]] %[[RED_VAR]] -> %[[RED_PVT_VAR:.*]] : !llvm.ptr)
106110
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
107111
step (%arg4, %step) init (%one) -> (i32) {
108-
// CHECK: omp.reduction
112+
// CHECK: %[[C2:.*]] = arith.constant 2 : i32
109113
%pow2 = arith.constant 2 : i32
114+
// CHECK: %[[RED_PVT_VAL:.*]] = llvm.load %[[RED_PVT_VAR]] : !llvm.ptr -> i32
115+
// CHECK: %[[MUL_RESULT:.*]] = arith.muli %[[RED_PVT_VAL]], %[[C2]] : i32
116+
// CHECK: llvm.store %[[MUL_RESULT]], %[[RED_PVT_VAR]] : i32, !llvm.ptr
110117
scf.reduce(%pow2 : i32) {
111118
^bb0(%lhs : i32, %rhs: i32):
112119
%res = arith.muli %lhs, %rhs : i32
@@ -199,16 +206,23 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
199206

200207
// CHECK: omp.parallel
201208
// CHECK: omp.wsloop
202-
// CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
203-
// CHECK-SAME: @[[$REDF2]] -> %[[BUF2]]
209+
// CHECK-SAME: reduction(@[[$REDF1]] %[[BUF1]] -> %[[PVT_BUF1:[a-z0-9]+]]
210+
// CHECK-SAME: @[[$REDF2]] %[[BUF2]] -> %[[PVT_BUF2:[a-z0-9]+]]
204211
// CHECK: memref.alloca_scope
205212
%res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
206213
step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
214+
// CHECK: %[[CST_ONE:.*]] = arith.constant 1.0{{.*}} : f32
207215
%one = arith.constant 1.0 : f32
208-
// CHECK: arith.fptosi
216+
// CHECK: %[[CST_INT_ONE:.*]] = arith.fptosi
209217
%1 = arith.fptosi %one : f32 to i64
210-
// CHECK: omp.reduction %{{.*}}, %[[BUF1]]
211-
// CHECK: omp.reduction %{{.*}}, %[[BUF2]]
218+
// CHECK: %[[PVT_VAL1:.*]] = llvm.load %[[PVT_BUF1]] : !llvm.ptr -> f32
219+
// CHECK: %[[TEMP1:.*]] = arith.cmpf oge, %[[PVT_VAL1]], %[[CST_ONE]] : f32
220+
// CHECK: %[[CMP_VAL1:.*]] = arith.select %[[TEMP1]], %[[PVT_VAL1]], %[[CST_ONE]] : f32
221+
// CHECK: llvm.store %[[CMP_VAL1]], %[[PVT_BUF1]] : f32, !llvm.ptr
222+
// CHECK: %[[PVT_VAL2:.*]] = llvm.load %[[PVT_BUF2]] : !llvm.ptr -> i64
223+
// CHECK: %[[TEMP2:.*]] = arith.cmpi slt, %[[PVT_VAL2]], %[[CST_INT_ONE]] : i64
224+
// CHECK: %[[CMP_VAL2:.*]] = arith.select %[[TEMP2]], %[[CST_INT_ONE]], %[[PVT_VAL2]] : i64
225+
// CHECK: llvm.store %[[CMP_VAL2]], %[[PVT_BUF2]] : i64, !llvm.ptr
212226
scf.reduce(%one, %1 : f32, i64) {
213227
^bb0(%lhs : f32, %rhs: f32):
214228
%cmp = arith.cmpf oge, %lhs, %rhs : f32

0 commit comments

Comments
 (0)