Skip to content

Commit 3430a36

Browse files
committed
address more review comments
1 parent cc95d75 commit 3430a36

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,9 +629,9 @@ FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
629629
llvm::append_range(inits, newInitOperands);
630630
scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
631631
getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
632-
inits, getMapping());
632+
inits, getMapping(),
633+
/*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
633634

634-
rewriter.eraseOp(newLoop.getTerminator());
635635
// Move the loop body to the new op.
636636
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
637637
newLoop.getBody()->getArguments().take_front(

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,6 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
11241124
},
11251125
[&](RewriterBase &b, LoopLikeOpInterface source,
11261126
LoopLikeOpInterface &target, IRMapping mapping) {
1127-
auto sourceFor = cast<scf::ForOp>(source);
11281127
auto targetFor = cast<scf::ForOp>(target);
11291128
auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
11301129
b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
@@ -1151,8 +1150,9 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
11511150

11521151
rewriter.setInsertionPoint(source);
11531152
auto fusedLoop = rewriter.create<scf::ParallelOp>(
1154-
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1155-
source.getStep(), newInitVars);
1153+
rewriter.getFusedLoc(target.getLoc(), source.getLoc()),
1154+
source.getLowerBound(), source.getUpperBound(), source.getStep(),
1155+
newInitVars);
11561156
Block *newBlock = fusedLoop.getBody();
11571157
rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
11581158
newBlock->getArguments());
@@ -1168,8 +1168,8 @@ scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
11681168
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
11691169
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
11701170

1171-
auto newReduceOp =
1172-
rewriter.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
1171+
auto newReduceOp = rewriter.create<scf::ReduceOp>(
1172+
rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs);
11731173

11741174
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
11751175
term1.getReductions(), term2.getReductions()))) {

0 commit comments

Comments
 (0)