Skip to content

Commit f202d32

Browse files
[mlir][SCF] Add canonicalization pattern for scf::For to eliminate yields that just forward.
For instance: ``` func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) { %a = call @make_i32() : () -> (i32) %b = call @make_i32() : () -> (i32) %r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) { %c = call @make_i32() : () -> (i32) scf.yield %0, %c, %2 : i32, i32, i32 } return %r#0, %r#1, %r#2 : i32, i32, i32 } ``` Canonicalizes as: ``` func @for_yields_3(%arg0: index, %arg1: index, %arg2: index) -> (i32, i32, i32) { %0 = call @make_i32() : () -> i32 %1 = call @make_i32() : () -> i32 %2 = scf.for %arg3 = %arg0 to %arg1 step %arg2 iter_args(%arg4 = %0) -> (i32) { %3 = call @make_i32() : () -> i32 scf.yield %3 : i32 } return %0, %2, %1 : i32, i32, i32 } ``` Differential Revision: https://reviews.llvm.org/D90745
1 parent 825e517 commit f202d32

File tree

3 files changed

+151
-0
lines changed

3 files changed

+151
-0
lines changed

mlir/include/mlir/Dialect/SCF/SCFOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ def ForOp : SCF_Op<"for",
197197
/// value for `index`.
198198
OperandRange getSuccessorEntryOperands(unsigned index);
199199
}];
200+
201+
let hasCanonicalizer = 1;
200202
}
201203

202204
def IfOp : SCF_Op<"if",

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,120 @@ ValueVector mlir::scf::buildLoopNest(
370370
});
371371
}
372372

373+
namespace {
374+
// Fold away ForOp iter arguments that are also yielded by the op.
375+
// These arguments must be defined outside of the ForOp region and can just be
376+
// forwarded after simplifying the op inits, yields and returns.
377+
//
378+
// The implementation uses `mergeBlockBefore` to steal the content of the
379+
// original ForOp and avoid cloning.
380+
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
381+
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
382+
383+
LogicalResult matchAndRewrite(scf::ForOp forOp,
384+
PatternRewriter &rewriter) const final {
385+
bool canonicalize = false;
386+
Block &block = forOp.region().front();
387+
auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
388+
389+
// An internal flat vector of block transfer
390+
// arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
391+
// transformed block argument mappings. This plays the role of a
392+
// BlockAndValueMapping for the particular use case of calling into
393+
// `mergeBlockBefore`.
394+
SmallVector<bool, 4> keepMask;
395+
keepMask.reserve(yieldOp.getNumOperands());
396+
SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
397+
newResultValues;
398+
newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
399+
newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
400+
newIterArgs.reserve(forOp.getNumIterOperands());
401+
newYieldValues.reserve(yieldOp.getNumOperands());
402+
newResultValues.reserve(forOp.getNumResults());
403+
for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
404+
forOp.getRegionIterArgs(), // iter inside region
405+
yieldOp.getOperands() // iter yield
406+
)) {
407+
// Forwarded is `true` when the region `iter` argument is yielded.
408+
bool forwarded = (std::get<1>(it) == std::get<2>(it));
409+
keepMask.push_back(!forwarded);
410+
canonicalize |= forwarded;
411+
if (forwarded) {
412+
newBlockTransferArgs.push_back(std::get<0>(it));
413+
newResultValues.push_back(std::get<0>(it));
414+
continue;
415+
}
416+
newIterArgs.push_back(std::get<0>(it));
417+
newYieldValues.push_back(std::get<2>(it));
418+
newBlockTransferArgs.push_back(Value()); // placeholder with null value
419+
newResultValues.push_back(Value()); // placeholder with null value
420+
}
421+
422+
if (!canonicalize)
423+
return failure();
424+
425+
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
426+
forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
427+
newIterArgs);
428+
Block &newBlock = newForOp.region().front();
429+
430+
// Replace the null placeholders with newly constructed values.
431+
newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
432+
for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
433+
idx != e; ++idx) {
434+
Value &blockTransferArg = newBlockTransferArgs[1 + idx];
435+
Value &newResultVal = newResultValues[idx];
436+
assert((blockTransferArg && newResultVal) ||
437+
(!blockTransferArg && !newResultVal));
438+
if (!blockTransferArg) {
439+
blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
440+
newResultVal = newForOp.getResult(collapsedIdx++);
441+
}
442+
}
443+
444+
Block &oldBlock = forOp.region().front();
445+
assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
446+
"unexpected argument size mismatch");
447+
448+
// No results case: the scf::ForOp builder already created a zero
449+
// reult terminator. Merge before this terminator and just get rid of the
450+
// original terminator that has been merged in.
451+
if (newIterArgs.empty()) {
452+
auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
453+
rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
454+
rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
455+
rewriter.replaceOp(forOp, newResultValues);
456+
return success();
457+
}
458+
459+
// No terminator case: merge and rewrite the merged terminator.
460+
auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
461+
OpBuilder::InsertionGuard g(rewriter);
462+
rewriter.setInsertionPoint(mergedTerminator);
463+
SmallVector<Value, 4> filteredOperands;
464+
filteredOperands.reserve(newResultValues.size());
465+
for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
466+
if (keepMask[idx])
467+
filteredOperands.push_back(mergedTerminator.getOperand(idx));
468+
rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
469+
filteredOperands);
470+
};
471+
472+
rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
473+
auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
474+
cloneFilteredTerminator(mergedYieldOp);
475+
rewriter.eraseOp(mergedYieldOp);
476+
rewriter.replaceOp(forOp, newResultValues);
477+
return success();
478+
}
479+
};
480+
} // namespace
481+
482+
void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
483+
MLIRContext *context) {
484+
results.insert<ForOpIterArgsFolder>(context);
485+
}
486+
373487
//===----------------------------------------------------------------------===//
374488
// IfOp
375489
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,38 @@ func @all_unused() {
137137
// CHECK: call @side_effect() : () -> ()
138138
// CHECK: }
139139
// CHECK: return
140+
141+
// -----
142+
143+
func @make_i32() -> i32
144+
145+
func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
146+
%a = call @make_i32() : () -> (i32)
147+
%b = scf.for %i = %lb to %ub step %step iter_args(%0 = %a) -> i32 {
148+
scf.yield %0 : i32
149+
}
150+
return %b : i32
151+
}
152+
153+
// CHECK-LABEL: func @for_yields_2
154+
// CHECK-NEXT: %[[R:.*]] = call @make_i32() : () -> i32
155+
// CHECK-NEXT: return %[[R]] : i32
156+
157+
func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
158+
%a = call @make_i32() : () -> (i32)
159+
%b = call @make_i32() : () -> (i32)
160+
%r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
161+
%c = call @make_i32() : () -> (i32)
162+
scf.yield %0, %c, %2 : i32, i32, i32
163+
}
164+
return %r#0, %r#1, %r#2 : i32, i32, i32
165+
}
166+
167+
// CHECK-LABEL: func @for_yields_3
168+
// CHECK-NEXT: %[[a:.*]] = call @make_i32() : () -> i32
169+
// CHECK-NEXT: %[[b:.*]] = call @make_i32() : () -> i32
170+
// CHECK-NEXT: %[[r1:.*]] = scf.for {{.*}} iter_args(%arg4 = %[[a]]) -> (i32) {
171+
// CHECK-NEXT: %[[c:.*]] = call @make_i32() : () -> i32
172+
// CHECK-NEXT: scf.yield %[[c]] : i32
173+
// CHECK-NEXT: }
174+
// CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32

0 commit comments

Comments
 (0)