Skip to content

Commit 6cbd47b

Browse files
[MLIR][OpenMP] Address review comments
1 parent 465cdc8 commit 6cbd47b

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
443443
accumulator variables in `reduction_vars` and symbols referring to reduction
444444
declarations in the `reductions` attribute. Each reduction is identified
445445
by the accumulator it uses and accumulators must not be repeated in the same
446-
reduction. The `omp.reduction` operation accepts the accumulator and a
447-
partial value which is considered to be produced by the current loop
448-
iteration for the given reduction. If multiple values are produced for the
449-
same accumulator, i.e. there are multiple `omp.reduction`s, the last value
450-
is taken. The reduction declaration specifies how to combine the values from
451-
each iteration into the final value, which is available in the accumulator
452-
after the loop completes.
446+
reduction. A private variable corresponding to the accumulator is used in
447+
place of the accumulator inside the body of the worksharing-loop. The
448+
reduction declaration specifies how to combine the values from each
449+
iteration into the final value, which is available in the accumulator after
450+
the loop completes.
453451

454452
The optional `schedule_val` attribute specifies the loop schedule for this
455453
loop, determining how the loop is distributed across the parallel threads.

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -400,19 +400,20 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
400400
// Replace the reduction operations contained in this loop. Must be done
401401
// here rather than in a separate pattern to have access to the list of
402402
// reduction variables.
403-
unsigned int reductionIndex = 0;
404-
for (auto [x, y] :
405-
llvm::zip_equal(reductionVariables, reduce.getOperands())) {
403+
for (auto [x, y, rD] : llvm::zip_equal(
404+
reductionVariables, reduce.getOperands(), ompReductionDecls)) {
406405
OpBuilder::InsertionGuard guard(rewriter);
407406
rewriter.setInsertionPoint(reduce);
408-
Region &redRegion =
409-
ompReductionDecls[reductionIndex].getReductionRegion();
407+
Region &redRegion = rD.getReductionRegion();
408+
// The SCF dialect by definition contains only structured operations
409+
// and hence the SCF reduction region will contain a single block.
410+
// The ompReductionDecls region is a copy of the SCF reduction region
411+
// and hence has the same property.
410412
assert(redRegion.hasOneBlock() &&
411413
"expect reduction region to have one block");
412414
Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
413-
Value pvtRedVal = rewriter.create<LLVM::LoadOp>(
414-
reduce.getLoc(), ompReductionDecls[reductionIndex].getType(),
415-
pvtRedVar);
415+
Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
416+
rD.getType(), pvtRedVar);
416417
// Make a copy of the reduction combiner region in the body
417418
mlir::OpBuilder builder(rewriter.getContext());
418419
builder.setInsertionPoint(reduce);
@@ -432,7 +433,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
432433
break;
433434
}
434435
}
435-
reductionIndex++;
436436
}
437437
rewriter.eraseOp(reduce);
438438

0 commit comments

Comments
 (0)