Skip to content

Commit eb42868

Browse files
authored
[MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriver (#77258)
Make GreedyPatternRewriteDriver handle failures of `materializeConstant` gracefully. Previously it was not checking whether the returned op was null and crashing. This PR handles it similarly to how OperationFolder does it.
1 parent c1023c5 commit eb42868

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
434434
SmallVector<OpFoldResult> foldResults;
435435
if (succeeded(op->fold(foldResults))) {
436436
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
437-
changed = true;
438437
if (foldResults.empty()) {
439438
// Op was modified in-place.
440439
notifyOperationModified(op);
440+
changed = true;
441441
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
442442
if (config.scope && failed(verify(config.scope->getParentOp())))
443443
llvm::report_fatal_error("IR failed to verify after folding");
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
451451
OpBuilder::InsertionGuard g(*this);
452452
setInsertionPoint(op);
453453
SmallVector<Value> replacements;
454+
bool materializationSucceeded = true;
454455
for (auto [ofr, resultType] :
455456
llvm::zip_equal(foldResults, op->getResultTypes())) {
456457
if (auto value = ofr.dyn_cast<Value>()) {
@@ -462,18 +463,41 @@ bool GreedyPatternRewriteDriver::processWorklist() {
462463
// Materialize Attributes as SSA values.
463464
Operation *constOp = op->getDialect()->materializeConstant(
464465
*this, ofr.get<Attribute>(), resultType, op->getLoc());
466+
467+
if (!constOp) {
468+
// If materialization fails, cleanup any operations generated for
469+
// the previous results.
470+
llvm::SmallDenseSet<Operation *> replacementOps;
471+
for (Value replacement : replacements) {
472+
assert(replacement.use_empty() &&
473+
"folder reused existing op for one result but constant "
474+
"materialization failed for another result");
475+
replacementOps.insert(replacement.getDefiningOp());
476+
}
477+
for (Operation *op : replacementOps) {
478+
eraseOp(op);
479+
}
480+
481+
materializationSucceeded = false;
482+
break;
483+
}
484+
465485
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
466486
"materializeConstant produced op that is not a ConstantLike");
467487
assert(constOp->getResultTypes()[0] == resultType &&
468488
"materializeConstant produced incorrect result type");
469489
replacements.push_back(constOp->getResult(0));
470490
}
471-
replaceOp(op, replacements);
491+
492+
if (materializationSucceeded) {
493+
replaceOp(op, replacements);
494+
changed = true;
472495
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
473-
if (config.scope && failed(verify(config.scope->getParentOp())))
474-
llvm::report_fatal_error("IR failed to verify after folding");
496+
if (config.scope && failed(verify(config.scope->getParentOp())))
497+
llvm::report_fatal_error("IR failed to verify after folding");
475498
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
476-
continue;
499+
continue;
500+
}
477501
}
478502
}
479503

mlir/test/Transforms/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
12241224
// CHECK-NEXT: scf.yield %[[ALLOC3_2]]
12251225
// CHECK: memref.dealloc %[[ALLOC1]]
12261226
// CHECK-NEXT: return %[[ALLOC2]]
1227+
1228+
// -----
1229+
1230+
// CHECK-LABEL: func @test_materialize_failure
1231+
func.func @test_materialize_failure() -> i64 {
1232+
%const = index.constant 1234
1233+
// Cannot materialize this castu's output constant.
1234+
// CHECK: index.castu
1235+
%u = index.castu %const : index to i64
1236+
return %u: i64
1237+
}

0 commit comments

Comments
 (0)