Skip to content

Commit 6093c26

Browse files
[mlir][Transforms] Dialect conversion: Align handling of dropped values (#106760)
Handle dropped block arguments and dropped op results in the same way: build a source materialization (that may fold away if unused). This simplifies the code base a bit and makes it possible to merge `legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes` in a future commit. These two functions are almost doing the same thing now. As a side effect, this commit also changes the dialect conversion such that temporary circular cast ops are no longer generated. (There was a workaround in #107109 that can now be removed again.) Example: ``` %0 = "builtin.unrealized_conversion_cast"(%1) : (!a) -> !b %1 = "builtin.unrealized_conversion_cast"(%0) : (!b) -> !a // No further uses of %0, %1. ``` This happened when: 1. An op was erased. (No replacement values provided.) 2. A conversion pattern for another op builds a replacement value for the erased op's results (first cast op) during `remapValues`, but that SSA value is not used during the pattern application. 3. During the finalization phase, `legalizeConvertedOpResultTypes` thinks that the erased op is alive because of the cast op that was built in Step 2. It builds a cast from that replacement value to the original type. 4. During the commit phase, all uses of the original op are replaced with the casted value produced in Step 3. We have generated circular IR. This problem can be avoided by making sure that source materializations are generated for all dropped results. This ensures that we always have some replacement SSA value in the mapping. Previously, we sometimes had a value mapped and sometimes not. (No more special casing is needed anymore to distinguish between "value dropped" or "value replaced with SSA value".)
1 parent 229f391 commit 6093c26

File tree

2 files changed

+32
-131
lines changed

2 files changed

+32
-131
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 30 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -624,10 +624,9 @@ class ModifyOperationRewrite : public OperationRewrite {
624624
class ReplaceOperationRewrite : public OperationRewrite {
625625
public:
626626
ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
627-
Operation *op, const TypeConverter *converter,
628-
bool changedResults)
627+
Operation *op, const TypeConverter *converter)
629628
: OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
630-
converter(converter), changedResults(changedResults) {}
629+
converter(converter) {}
631630

632631
static bool classof(const IRRewrite *rewrite) {
633632
return rewrite->getKind() == Kind::ReplaceOperation;
@@ -641,15 +640,10 @@ class ReplaceOperationRewrite : public OperationRewrite {
641640

642641
const TypeConverter *getConverter() const { return converter; }
643642

644-
bool hasChangedResults() const { return changedResults; }
645-
646643
private:
647644
/// An optional type converter that can be used to materialize conversions
648645
/// between the new and old values if necessary.
649646
const TypeConverter *converter;
650-
651-
/// A boolean flag that indicates whether result types have changed or not.
652-
bool changedResults;
653647
};
654648

655649
class CreateOperationRewrite : public OperationRewrite {
@@ -941,6 +935,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
941935
/// to modify/access them is invalid rewriter API usage.
942936
SetVector<Operation *> replacedOps;
943937

938+
/// A set of all unresolved materializations.
939+
DenseSet<Operation *> unresolvedMaterializations;
940+
944941
/// The current type converter, or nullptr if no type converter is currently
945942
/// active.
946943
const TypeConverter *currentTypeConverter = nullptr;
@@ -1066,6 +1063,7 @@ void UnresolvedMaterializationRewrite::rollback() {
10661063
for (Value input : op->getOperands())
10671064
rewriterImpl.mapping.erase(input);
10681065
}
1066+
rewriterImpl.unresolvedMaterializations.erase(op);
10691067
op->erase();
10701068
}
10711069

@@ -1347,6 +1345,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13471345
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13481346
auto convertOp =
13491347
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1348+
unresolvedMaterializations.insert(convertOp);
13501349
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13511350
return convertOp.getResult(0);
13521351
}
@@ -1379,22 +1378,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13791378
assert(newValues.size() == op->getNumResults());
13801379
assert(!ignoredOps.contains(op) && "operation was already replaced");
13811380

1382-
// Track if any of the results changed, e.g. erased and replaced with null.
1383-
bool resultChanged = false;
1384-
13851381
// Create mappings for each of the new result values.
13861382
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13871383
if (!newValue) {
1388-
resultChanged = true;
1389-
continue;
1384+
// This result was dropped and no replacement value was provided.
1385+
if (unresolvedMaterializations.contains(op)) {
1386+
// Do not create another materializations if we are erasing a
1387+
// materialization.
1388+
continue;
1389+
}
1390+
1391+
// Materialize a replacement value "out of thin air".
1392+
newValue = buildUnresolvedMaterialization(
1393+
MaterializationKind::Source, computeInsertPoint(result),
1394+
result.getLoc(), /*inputs=*/ValueRange(),
1395+
/*outputType=*/result.getType(), currentTypeConverter);
13901396
}
1397+
13911398
// Remap, and check for any result type changes.
13921399
mapping.map(result, newValue);
1393-
resultChanged |= (newValue.getType() != result.getType());
13941400
}
13951401

1396-
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
1397-
resultChanged);
1402+
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
13981403

13991404
// Mark this operation and all nested ops as replaced.
14001405
op->walk([&](Operation *op) { replacedOps.insert(op); });
@@ -2359,11 +2364,6 @@ struct OperationConverter {
23592364
ConversionPatternRewriterImpl &rewriterImpl,
23602365
DenseMap<Value, SmallVector<Value>> &inverseMapping);
23612366

2362-
/// Legalize an operation result that was marked as "erased".
2363-
LogicalResult
2364-
legalizeErasedResult(Operation *op, OpResult result,
2365-
ConversionPatternRewriterImpl &rewriterImpl);
2366-
23672367
/// Dialect conversion configuration.
23682368
ConversionConfig config;
23692369

@@ -2455,77 +2455,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
24552455
return failure();
24562456
}
24572457

2458-
/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results
2459-
/// are not used (transitively) by any op that is not in the given list of
2460-
/// cast ops.
2461-
///
2462-
/// In particular, this function erases cyclic casts that may be inserted
2463-
/// during the dialect conversion process. E.g.:
2464-
/// %0 = unrealized_conversion_cast(%1)
2465-
/// %1 = unrealized_conversion_cast(%0)
2466-
// Note: This step will become unnecessary when
2467-
// https://github.com/llvm/llvm-project/pull/106760 has been merged.
2468-
static void eraseDeadUnrealizedCasts(
2469-
ArrayRef<UnrealizedConversionCastOp> castOps,
2470-
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2471-
// Ops that have already been visited or are currently being visited.
2472-
DenseSet<Operation *> visited;
2473-
// Set of all cast ops for faster lookups.
2474-
DenseSet<Operation *> castOpSet;
2475-
// Set of all cast ops that have been determined to be alive.
2476-
DenseSet<Operation *> live;
2477-
2478-
for (UnrealizedConversionCastOp op : castOps)
2479-
castOpSet.insert(op);
2480-
2481-
// Visit a cast operation. Return "true" if the operation is live.
2482-
std::function<bool(Operation *)> visit = [&](Operation *op) -> bool {
2483-
// No need to traverse any IR if the op was already marked as live.
2484-
if (live.contains(op))
2485-
return true;
2486-
2487-
// Do not visit ops multiple times. If we find a circle, no live user was
2488-
// found on the current path.
2489-
if (!visited.insert(op).second)
2490-
return false;
2491-
2492-
// Visit all users.
2493-
for (Operation *user : op->getUsers()) {
2494-
// If the user is not an unrealized_conversion_cast op, then the given op
2495-
// is live.
2496-
if (!castOpSet.contains(user)) {
2497-
live.insert(op);
2498-
return true;
2499-
}
2500-
// Otherwise, it is live if a live op can be reached from one of its
2501-
// users (which must all be unrealized_conversion_cast ops).
2502-
if (visit(user)) {
2503-
live.insert(op);
2504-
return true;
2505-
}
2506-
}
2507-
2508-
return false;
2509-
};
2510-
2511-
// Visit all cast ops.
2512-
for (UnrealizedConversionCastOp op : castOps) {
2513-
visit(op);
2514-
visited.clear();
2515-
}
2516-
2517-
// Erase all cast ops that are dead.
2518-
for (UnrealizedConversionCastOp op : castOps) {
2519-
if (live.contains(op)) {
2520-
if (remainingCastOps)
2521-
remainingCastOps->push_back(op);
2522-
continue;
2523-
}
2524-
op->dropAllUses();
2525-
op->erase();
2526-
}
2527-
}
2528-
25292458
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25302459
if (ops.empty())
25312460
return success();
@@ -2584,14 +2513,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25842513
// Reconcile all UnrealizedConversionCastOps that were inserted by the
25852514
// dialect conversion frameworks. (Not the one that were inserted by
25862515
// patterns.)
2587-
SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
2588-
eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1);
2589-
reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2);
2516+
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2517+
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
25902518

25912519
// Try to legalize all unresolved materializations.
25922520
if (config.buildMaterializations) {
25932521
IRRewriter rewriter(rewriterImpl.context, config.listener);
2594-
for (UnrealizedConversionCastOp castOp : remainingCastOps2) {
2522+
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
25952523
auto it = rewriteMap.find(castOp.getOperation());
25962524
assert(it != rewriteMap.end() && "inconsistent state");
25972525
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
@@ -2646,30 +2574,22 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
26462574
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
26472575
auto *opReplacement =
26482576
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
2649-
if (!opReplacement || !opReplacement->hasChangedResults())
2577+
if (!opReplacement)
26502578
continue;
26512579
Operation *op = opReplacement->getOperation();
26522580
for (OpResult result : op->getResults()) {
2653-
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2654-
2655-
// If the operation result was replaced with null, all of the uses of this
2656-
// value should be replaced.
2657-
if (!newValue) {
2658-
if (failed(legalizeErasedResult(op, result, rewriterImpl)))
2659-
return failure();
2581+
// If the type of this op result changed and the result is still live,
2582+
// we need to materialize a conversion.
2583+
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
26602584
continue;
2661-
}
2662-
2663-
// Otherwise, check to see if the type of the result changed.
2664-
if (result.getType() == newValue.getType())
2665-
continue;
2666-
26672585
Operation *liveUser =
26682586
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
26692587
if (!liveUser)
26702588
continue;
26712589

26722590
// Legalize this result.
2591+
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2592+
assert(newValue && "replacement value not found");
26732593
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
26742594
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
26752595
/*inputs=*/newValue, /*outputType=*/result.getType(),
@@ -2727,25 +2647,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
27272647
return success();
27282648
}
27292649

2730-
LogicalResult OperationConverter::legalizeErasedResult(
2731-
Operation *op, OpResult result,
2732-
ConversionPatternRewriterImpl &rewriterImpl) {
2733-
// If the operation result was replaced with null, all of the uses of this
2734-
// value should be replaced.
2735-
auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2736-
return rewriterImpl.isOpIgnored(user);
2737-
});
2738-
if (liveUserIt != result.user_end()) {
2739-
InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2740-
<< op->getName() << "' marked as erased";
2741-
diag.attachNote(liveUserIt->getLoc())
2742-
<< "found live user of result #" << result.getResultNumber() << ": "
2743-
<< *liveUserIt;
2744-
return failure();
2745-
}
2746-
return success();
2747-
}
2748-
27492650
//===----------------------------------------------------------------------===//
27502651
// Reconcile Unrealized Casts
27512652
//===----------------------------------------------------------------------===//

mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
// Test that an error is emitted when an operation is marked as "erased", but
44
// has users that live across the conversion.
55
func.func @remove_all_ops(%arg0: i32) -> i32 {
6-
// expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}}
6+
// expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}}
77
%0 = "test.illegal_op_a"() : () -> i32
8-
// expected-note@below {{found live user of result #0: func.return %0 : i32}}
8+
// expected-note@below {{see existing live user here}}
99
return %0 : i32
1010
}

0 commit comments

Comments
 (0)