@@ -400,19 +400,20 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
400
400
// Replace the reduction operations contained in this loop. Must be done
401
401
// here rather than in a separate pattern to have access to the list of
402
402
// reduction variables.
403
- unsigned int reductionIndex = 0 ;
404
- for (auto [x, y] :
405
- llvm::zip_equal (reductionVariables, reduce.getOperands ())) {
403
+ for (auto [x, y, rD] : llvm::zip_equal (
404
+ reductionVariables, reduce.getOperands (), ompReductionDecls)) {
406
405
OpBuilder::InsertionGuard guard (rewriter);
407
406
rewriter.setInsertionPoint (reduce);
408
- Region &redRegion =
409
- ompReductionDecls[reductionIndex].getReductionRegion ();
407
+ Region &redRegion = rD.getReductionRegion ();
408
+ // The SCF dialect by definition contains only structured operations
409
+ // and hence the SCF reduction region will contain a single block.
410
+ // The ompReductionDecls region is a copy of the SCF reduction region
411
+ // and hence has the same property.
410
412
assert (redRegion.hasOneBlock () &&
411
413
" expect reduction region to have one block" );
412
414
Value pvtRedVar = parallelOp.getRegion ().addArgument (x.getType (), loc);
413
- Value pvtRedVal = rewriter.create <LLVM::LoadOp>(
414
- reduce.getLoc (), ompReductionDecls[reductionIndex].getType (),
415
- pvtRedVar);
415
+ Value pvtRedVal = rewriter.create <LLVM::LoadOp>(reduce.getLoc (),
416
+ rD.getType (), pvtRedVar);
416
417
// Make a copy of the reduction combiner region in the body
417
418
mlir::OpBuilder builder (rewriter.getContext ());
418
419
builder.setInsertionPoint (reduce);
@@ -432,7 +433,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
432
433
break ;
433
434
}
434
435
}
435
- reductionIndex++;
436
436
}
437
437
rewriter.eraseOp (reduce);
438
438
0 commit comments