Skip to content

Commit 9a33993

Browse files
[MLIR][OpenMP] Address review comments
1 parent 110679e commit 9a33993

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -525,13 +525,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
525525
accumulator variables in `reduction_vars` and symbols referring to reduction
526526
declarations in the `reductions` attribute. Each reduction is identified
527527
by the accumulator it uses and accumulators must not be repeated in the same
528-
reduction. The `omp.reduction` operation accepts the accumulator and a
529-
partial value which is considered to be produced by the current loop
530-
iteration for the given reduction. If multiple values are produced for the
531-
same accumulator, i.e. there are multiple `omp.reduction`s, the last value
532-
is taken. The reduction declaration specifies how to combine the values from
533-
each iteration into the final value, which is available in the accumulator
534-
after the loop completes.
528+
reduction. A private variable corresponding to the accumulator is used in
529+
place of the accumulator inside the body of the worksharing-loop. The
530+
reduction declaration specifies how to combine the values from each
531+
iteration into the final value, which is available in the accumulator after
532+
the loop completes.
535533

536534
The optional `schedule_val` attribute specifies the loop schedule for this
537535
loop, determining how the loop is distributed across the parallel threads.

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -400,19 +400,20 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
400400
// Replace the reduction operations contained in this loop. Must be done
401401
// here rather than in a separate pattern to have access to the list of
402402
// reduction variables.
403-
unsigned int reductionIndex = 0;
404-
for (auto [x, y] :
405-
llvm::zip_equal(reductionVariables, reduce.getOperands())) {
403+
for (auto [x, y, rD] : llvm::zip_equal(
404+
reductionVariables, reduce.getOperands(), ompReductionDecls)) {
406405
OpBuilder::InsertionGuard guard(rewriter);
407406
rewriter.setInsertionPoint(reduce);
408-
Region &redRegion =
409-
ompReductionDecls[reductionIndex].getReductionRegion();
407+
Region &redRegion = rD.getReductionRegion();
408+
// The SCF dialect by definition contains only structured operations
409+
// and hence the SCF reduction region will contain a single block.
410+
// The ompReductionDecls region is a copy of the SCF reduction region
411+
// and hence has the same property.
410412
assert(redRegion.hasOneBlock() &&
411413
"expect reduction region to have one block");
412414
Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
413-
Value pvtRedVal = rewriter.create<LLVM::LoadOp>(
414-
reduce.getLoc(), ompReductionDecls[reductionIndex].getType(),
415-
pvtRedVar);
415+
Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
416+
rD.getType(), pvtRedVar);
416417
// Make a copy of the reduction combiner region in the body
417418
mlir::OpBuilder builder(rewriter.getContext());
418419
builder.setInsertionPoint(reduce);
@@ -432,7 +433,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
432433
break;
433434
}
434435
}
435-
reductionIndex++;
436436
}
437437
rewriter.eraseOp(reduce);
438438

0 commit comments

Comments
 (0)