Skip to content

Commit 90f1a0e

Browse files
[mlir][Transforms] Dialect conversion: Align handling of dropped values
Handle dropped block arguments and dropped op results in the same way: build a source materialization (that may fold away if unused). This simplifies the code base a bit and makes it possible to merge `legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes` in a future commit. These two functions are almost doing the same thing now. This commit also fixes a bug where circular materializations were built, e.g.: ``` %0 = "builtin.unrealized_conversion_cast"(%1) : (!a) -> !b %1 = "builtin.unrealized_conversion_cast"(%0) : (!b) -> !a // No further uses of %0, %1. ``` This happened when: 1. An op was erased. (No replacement values provided.) 2. A conversion pattern for another op builds a replacement value (first cast op) during `remapValues`, but that SSA value is not used during the pattern application. 3. During the finalization phase, `legalizeConvertedOpResultTypes` thinks that the erased op is alive because of the cast op that was built in Step 2. It builds a cast from that replacement value to the original type. 4. During the commit phase, all uses of the original op are repalced with the casted value produced in Step 3. We have generated circular IR.
1 parent 63dab72 commit 90f1a0e

File tree

2 files changed

+28
-116
lines changed

2 files changed

+28
-116
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 26 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
941941
/// to modify/access them is invalid rewriter API usage.
942942
SetVector<Operation *> replacedOps;
943943

944+
DenseSet<Operation *> unresolvedMaterializations;
945+
944946
/// The current type converter, or nullptr if no type converter is currently
945947
/// active.
946948
const TypeConverter *currentTypeConverter = nullptr;
@@ -1066,6 +1068,7 @@ void UnresolvedMaterializationRewrite::rollback() {
10661068
for (Value input : op->getOperands())
10671069
rewriterImpl.mapping.erase(input);
10681070
}
1071+
rewriterImpl.unresolvedMaterializations.erase(op);
10691072
op->erase();
10701073
}
10711074

@@ -1347,6 +1350,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13471350
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13481351
auto convertOp =
13491352
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1353+
unresolvedMaterializations.insert(convertOp);
13501354
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13511355
return convertOp.getResult(0);
13521356
}
@@ -1385,9 +1389,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13851389
// Create mappings for each of the new result values.
13861390
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13871391
if (!newValue) {
1388-
resultChanged = true;
1389-
continue;
1392+
// This result was dropped and no replacement value was provided.
1393+
if (unresolvedMaterializations.contains(op)) {
1394+
// Do not create another materializations if we are erasing a
1395+
// materialization.
1396+
resultChanged = true;
1397+
continue;
1398+
}
1399+
1400+
// Materialize a replacement value "out of thin air".
1401+
newValue = buildUnresolvedMaterialization(
1402+
MaterializationKind::Source, computeInsertPoint(result),
1403+
result.getLoc(), /*inputs=*/ValueRange(),
1404+
/*outputType=*/result.getType(), currentTypeConverter);
13901405
}
1406+
13911407
// Remap, and check for any result type changes.
13921408
mapping.map(result, newValue);
13931409
resultChanged |= (newValue.getType() != result.getType());
@@ -2359,11 +2375,6 @@ struct OperationConverter {
23592375
ConversionPatternRewriterImpl &rewriterImpl,
23602376
DenseMap<Value, SmallVector<Value>> &inverseMapping);
23612377

2362-
/// Legalize an operation result that was marked as "erased".
2363-
LogicalResult
2364-
legalizeErasedResult(Operation *op, OpResult result,
2365-
ConversionPatternRewriterImpl &rewriterImpl);
2366-
23672378
/// Dialect conversion configuration.
23682379
ConversionConfig config;
23692380

@@ -2455,77 +2466,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
24552466
return failure();
24562467
}
24572468

2458-
/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results
2459-
/// are not used (transitively) by any op that is not in the given list of
2460-
/// cast ops.
2461-
///
2462-
/// In particular, this function erases cyclic casts that may be inserted
2463-
/// during the dialect conversion process. E.g.:
2464-
/// %0 = unrealized_conversion_cast(%1)
2465-
/// %1 = unrealized_conversion_cast(%0)
2466-
// Note: This step will become unnecessary when
2467-
// https://github.com/llvm/llvm-project/pull/106760 has been merged.
2468-
static void eraseDeadUnrealizedCasts(
2469-
ArrayRef<UnrealizedConversionCastOp> castOps,
2470-
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2471-
// Ops that have already been visited or are currently being visited.
2472-
DenseSet<Operation *> visited;
2473-
// Set of all cast ops for faster lookups.
2474-
DenseSet<Operation *> castOpSet;
2475-
// Set of all cast ops that have been determined to be alive.
2476-
DenseSet<Operation *> live;
2477-
2478-
for (UnrealizedConversionCastOp op : castOps)
2479-
castOpSet.insert(op);
2480-
2481-
// Visit a cast operation. Return "true" if the operation is live.
2482-
std::function<bool(Operation *)> visit = [&](Operation *op) -> bool {
2483-
// No need to traverse any IR if the op was already marked as live.
2484-
if (live.contains(op))
2485-
return true;
2486-
2487-
// Do not visit ops multiple times. If we find a circle, no live user was
2488-
// found on the current path.
2489-
if (!visited.insert(op).second)
2490-
return false;
2491-
2492-
// Visit all users.
2493-
for (Operation *user : op->getUsers()) {
2494-
// If the user is not an unrealized_conversion_cast op, then the given op
2495-
// is live.
2496-
if (!castOpSet.contains(user)) {
2497-
live.insert(op);
2498-
return true;
2499-
}
2500-
// Otherwise, it is live if a live op can be reached from one of its
2501-
// users (which must all be unrealized_conversion_cast ops).
2502-
if (visit(user)) {
2503-
live.insert(op);
2504-
return true;
2505-
}
2506-
}
2507-
2508-
return false;
2509-
};
2510-
2511-
// Visit all cast ops.
2512-
for (UnrealizedConversionCastOp op : castOps) {
2513-
visit(op);
2514-
visited.clear();
2515-
}
2516-
2517-
// Erase all cast ops that are dead.
2518-
for (UnrealizedConversionCastOp op : castOps) {
2519-
if (live.contains(op)) {
2520-
if (remainingCastOps)
2521-
remainingCastOps->push_back(op);
2522-
continue;
2523-
}
2524-
op->dropAllUses();
2525-
op->erase();
2526-
}
2527-
}
2528-
25292469
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25302470
if (ops.empty())
25312471
return success();
@@ -2584,14 +2524,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25842524
// Reconcile all UnrealizedConversionCastOps that were inserted by the
25852525
// dialect conversion frameworks. (Not the one that were inserted by
25862526
// patterns.)
2587-
SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
2588-
eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1);
2589-
reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2);
2527+
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2528+
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
25902529

25912530
// Try to legalize all unresolved materializations.
25922531
if (config.buildMaterializations) {
25932532
IRRewriter rewriter(rewriterImpl.context, config.listener);
2594-
for (UnrealizedConversionCastOp castOp : remainingCastOps2) {
2533+
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
25952534
auto it = rewriteMap.find(castOp.getOperation());
25962535
assert(it != rewriteMap.end() && "inconsistent state");
25972536
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
@@ -2650,26 +2589,18 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
26502589
continue;
26512590
Operation *op = opReplacement->getOperation();
26522591
for (OpResult result : op->getResults()) {
2653-
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2654-
2655-
// If the operation result was replaced with null, all of the uses of this
2656-
// value should be replaced.
2657-
if (!newValue) {
2658-
if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2659-
return failure();
2660-
continue;
2661-
}
2662-
2663-
// Otherwise, check to see if the type of the result changed.
2664-
if (result.getType() == newValue.getType())
2592+
// If the type of this op result changed and the result is still live,
2593+
// we need to materialize a conversion.
2594+
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
26652595
continue;
2666-
26672596
Operation *liveUser =
26682597
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
26692598
if (!liveUser)
26702599
continue;
26712600

26722601
// Legalize this result.
2602+
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2603+
assert(newValue && "replacement value not found");
26732604
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
26742605
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
26752606
/*inputs=*/newValue, /*outputType=*/result.getType(),
@@ -2727,25 +2658,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
27272658
return success();
27282659
}
27292660

2730-
LogicalResult OperationConverter::legalizeErasedResult(
2731-
Operation *op, OpResult result,
2732-
ConversionPatternRewriterImpl &rewriterImpl) {
2733-
// If the operation result was replaced with null, all of the uses of this
2734-
// value should be replaced.
2735-
auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2736-
return rewriterImpl.isOpIgnored(user);
2737-
});
2738-
if (liveUserIt != result.user_end()) {
2739-
InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2740-
<< op->getName() << "' marked as erased";
2741-
diag.attachNote(liveUserIt->getLoc())
2742-
<< "found live user of result #" << result.getResultNumber() << ": "
2743-
<< *liveUserIt;
2744-
return failure();
2745-
}
2746-
return success();
2747-
}
2748-
27492661
//===----------------------------------------------------------------------===//
27502662
// Reconcile Unrealized Casts
27512663
//===----------------------------------------------------------------------===//

mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
// Test that an error is emitted when an operation is marked as "erased", but
44
// has users that live across the conversion.
55
func.func @remove_all_ops(%arg0: i32) -> i32 {
6-
// expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}}
6+
// expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}}
77
%0 = "test.illegal_op_a"() : () -> i32
8-
// expected-note@below {{found live user of result #0: func.return %0 : i32}}
8+
// expected-note@below {{see existing live user here}}
99
return %0 : i32
1010
}

0 commit comments

Comments
 (0)