@@ -4297,33 +4297,42 @@ void IndexSwitchOp::getRegionInvocationBounds(
4297
4297
bounds.emplace_back (/* lb=*/ 0 , /* ub=*/ i == liveIndex);
4298
4298
}
4299
4299
4300
- LogicalResult IndexSwitchOp::fold (FoldAdaptor adaptor,
4301
- SmallVectorImpl<OpFoldResult> &results) {
4302
- std::optional<int64_t > maybeCst = getConstantIntValue (getArg ());
4303
- if (!maybeCst.has_value ())
4304
- return failure ();
4305
- int64_t cst = *maybeCst;
4306
- int64_t caseIdx, e = getNumCases ();
4307
- for (caseIdx = 0 ; caseIdx < e; ++caseIdx) {
4308
- if (cst == getCases ()[caseIdx])
4309
- break ;
4310
- }
4300
+ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4301
+ using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4311
4302
4312
- Region &r = (caseIdx < getNumCases ()) ? getCaseRegions ()[caseIdx]
4313
- : getDefaultRegion ();
4314
- Block &source = r.front ();
4315
- results.assign (source.getTerminator ()->getOperands ().begin (),
4316
- source.getTerminator ()->getOperands ().end ());
4303
+ LogicalResult matchAndRewrite (scf::IndexSwitchOp op,
4304
+ PatternRewriter &rewriter) const override {
4305
+ // If `op.getArg()` is a constant, select the region that matches with
4306
+ // the constant value. Use the default region if no matche is found.
4307
+ std::optional<int64_t > maybeCst = getConstantIntValue (op.getArg ());
4308
+ if (!maybeCst.has_value ())
4309
+ return failure ();
4310
+ int64_t cst = *maybeCst;
4311
+ int64_t caseIdx, e = op.getNumCases ();
4312
+ for (caseIdx = 0 ; caseIdx < e; ++caseIdx) {
4313
+ if (cst == op.getCases ()[caseIdx])
4314
+ break ;
4315
+ }
4317
4316
4318
- Block *pDestination = (*this )->getBlock ();
4319
- if (!pDestination)
4320
- return failure ();
4321
- Block::iterator insertionPoint = (*this )->getIterator ();
4322
- pDestination->getOperations ().splice (insertionPoint, source.getOperations (),
4323
- source.getOperations ().begin (),
4324
- std::prev (source.getOperations ().end ()));
4317
+ Region &r = (caseIdx < op.getNumCases ()) ? op.getCaseRegions ()[caseIdx]
4318
+ : op.getDefaultRegion ();
4319
+ Block &source = r.front ();
4320
+ Operation *terminator = source.getTerminator ();
4321
+ SmallVector<Value> results = terminator->getOperands ();
4325
4322
4326
- return success ();
4323
+ rewriter.inlineBlockBefore (&source, op);
4324
+ rewriter.eraseOp (terminator);
4325
+ // Repalce the operation with a potentially empty list of results.
4326
+ // Fold mechanism doesn't support the case where the result list is empty.
4327
+ rewriter.replaceOp (op, results);
4328
+
4329
+ return success ();
4330
+ }
4331
+ };
4332
+
4333
+ void IndexSwitchOp::getCanonicalizationPatterns (RewritePatternSet &results,
4334
+ MLIRContext *context) {
4335
+ results.add <FoldConstantCase>(context);
4327
4336
}
4328
4337
4329
4338
// ===----------------------------------------------------------------------===//
0 commit comments