@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
434
434
SmallVector<OpFoldResult> foldResults;
435
435
if (succeeded (op->fold (foldResults))) {
436
436
LLVM_DEBUG (logResultWithLine (" success" , " operation was folded" ));
437
- changed = true ;
438
437
if (foldResults.empty ()) {
439
438
// Op was modified in-place.
440
439
notifyOperationModified (op);
440
+ changed = true ;
441
441
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
442
442
if (config.scope && failed (verify (config.scope ->getParentOp ())))
443
443
llvm::report_fatal_error (" IR failed to verify after folding" );
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
451
451
OpBuilder::InsertionGuard g (*this );
452
452
setInsertionPoint (op);
453
453
SmallVector<Value> replacements;
454
+ bool materializationSucceeded = true ;
454
455
for (auto [ofr, resultType] :
455
456
llvm::zip_equal (foldResults, op->getResultTypes ())) {
456
457
if (auto value = ofr.dyn_cast <Value>()) {
@@ -462,18 +463,41 @@ bool GreedyPatternRewriteDriver::processWorklist() {
462
463
// Materialize Attributes as SSA values.
463
464
Operation *constOp = op->getDialect ()->materializeConstant (
464
465
*this , ofr.get <Attribute>(), resultType, op->getLoc ());
466
+
467
+ if (!constOp) {
468
+ // If materialization fails, cleanup any operations generated for
469
+ // the previous results.
470
+ llvm::SmallDenseSet<Operation *> replacementOps;
471
+ for (Value replacement : replacements) {
472
+ assert (replacement.use_empty () &&
473
+ " folder reused existing op for one result but constant "
474
+ " materialization failed for another result" );
475
+ replacementOps.insert (replacement.getDefiningOp ());
476
+ }
477
+ for (Operation *op : replacementOps) {
478
+ eraseOp (op);
479
+ }
480
+
481
+ materializationSucceeded = false ;
482
+ break ;
483
+ }
484
+
465
485
assert (constOp->hasTrait <OpTrait::ConstantLike>() &&
466
486
" materializeConstant produced op that is not a ConstantLike" );
467
487
assert (constOp->getResultTypes ()[0 ] == resultType &&
468
488
" materializeConstant produced incorrect result type" );
469
489
replacements.push_back (constOp->getResult (0 ));
470
490
}
471
- replaceOp (op, replacements);
491
+
492
+ if (materializationSucceeded) {
493
+ replaceOp (op, replacements);
494
+ changed = true ;
472
495
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
473
- if (config.scope && failed (verify (config.scope ->getParentOp ())))
474
- llvm::report_fatal_error (" IR failed to verify after folding" );
496
+ if (config.scope && failed (verify (config.scope ->getParentOp ())))
497
+ llvm::report_fatal_error (" IR failed to verify after folding" );
475
498
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
476
- continue ;
499
+ continue ;
500
+ }
477
501
}
478
502
}
479
503
0 commit comments