Skip to content

Commit d0e88a5

Browse files
Kureeyuxuanchen1997
authored andcommitted
[MLIR][SCF] fix scf.index_switch fold convergence (#98535) (#98680)
If the `scf.index_switch` op has no result, the current fold logic results in an infinite loop (see #98535). The is because `fold` mechanism does not support *erasing* zero-result ops. This PR moves the fold logic to a canonicalizer and fix the issue.
1 parent 9c332c2 commit d0e88a5

File tree

3 files changed

+52
-25
lines changed

3 files changed

+52
-25
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
11591159
Block &getCaseBlock(unsigned idx);
11601160
}];
11611161

1162-
let hasFolder = 1;
1162+
let hasCanonicalizer = 1;
11631163
let hasVerifier = 1;
11641164
}
11651165

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4297,33 +4297,42 @@ void IndexSwitchOp::getRegionInvocationBounds(
42974297
bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
42984298
}
42994299

4300-
LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
4301-
SmallVectorImpl<OpFoldResult> &results) {
4302-
std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
4303-
if (!maybeCst.has_value())
4304-
return failure();
4305-
int64_t cst = *maybeCst;
4306-
int64_t caseIdx, e = getNumCases();
4307-
for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4308-
if (cst == getCases()[caseIdx])
4309-
break;
4310-
}
4300+
struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4301+
using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
43114302

4312-
Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
4313-
: getDefaultRegion();
4314-
Block &source = r.front();
4315-
results.assign(source.getTerminator()->getOperands().begin(),
4316-
source.getTerminator()->getOperands().end());
4303+
LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4304+
PatternRewriter &rewriter) const override {
4305+
// If `op.getArg()` is a constant, select the region that matches with
4306+
// the constant value. Use the default region if no matche is found.
4307+
std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4308+
if (!maybeCst.has_value())
4309+
return failure();
4310+
int64_t cst = *maybeCst;
4311+
int64_t caseIdx, e = op.getNumCases();
4312+
for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4313+
if (cst == op.getCases()[caseIdx])
4314+
break;
4315+
}
43174316

4318-
Block *pDestination = (*this)->getBlock();
4319-
if (!pDestination)
4320-
return failure();
4321-
Block::iterator insertionPoint = (*this)->getIterator();
4322-
pDestination->getOperations().splice(insertionPoint, source.getOperations(),
4323-
source.getOperations().begin(),
4324-
std::prev(source.getOperations().end()));
4317+
Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4318+
: op.getDefaultRegion();
4319+
Block &source = r.front();
4320+
Operation *terminator = source.getTerminator();
4321+
SmallVector<Value> results = terminator->getOperands();
43254322

4326-
return success();
4323+
rewriter.inlineBlockBefore(&source, op);
4324+
rewriter.eraseOp(terminator);
4325+
// Repalce the operation with a potentially empty list of results.
4326+
// Fold mechanism doesn't support the case where the result list is empty.
4327+
rewriter.replaceOp(op, results);
4328+
4329+
return success();
4330+
}
4331+
};
4332+
4333+
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4334+
MLIRContext *context) {
4335+
results.add<FoldConstantCase>(context);
43274336
}
43284337

43294338
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,3 +1846,21 @@ func.func @index_switch_fold() -> (f32, f32) {
18461846
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1.000000e+00 : f32
18471847
// CHECK-NEXT: %[[c42:.*]] = arith.constant 4.200000e+01 : f32
18481848
// CHECK-NEXT: return %[[c1]], %[[c42]] : f32, f32
1849+
1850+
// -----
1851+
1852+
func.func @index_switch_fold_no_res() {
1853+
%c1 = arith.constant 1 : index
1854+
scf.index_switch %c1
1855+
case 0 {
1856+
scf.yield
1857+
}
1858+
default {
1859+
"test.op"() : () -> ()
1860+
scf.yield
1861+
}
1862+
return
1863+
}
1864+
1865+
// CHECK-LABEL: func.func @index_switch_fold_no_res()
1866+
// CHECK-NEXT: "test.op"() : () -> ()

0 commit comments

Comments
 (0)