@@ -624,10 +624,9 @@ class ModifyOperationRewrite : public OperationRewrite {
624
624
class ReplaceOperationRewrite : public OperationRewrite {
625
625
public:
626
626
ReplaceOperationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
627
- Operation *op, const TypeConverter *converter,
628
- bool changedResults)
627
+ Operation *op, const TypeConverter *converter)
629
628
: OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
630
- converter (converter), changedResults(changedResults) {}
629
+ converter (converter) {}
631
630
632
631
static bool classof (const IRRewrite *rewrite) {
633
632
return rewrite->getKind () == Kind::ReplaceOperation;
@@ -641,15 +640,10 @@ class ReplaceOperationRewrite : public OperationRewrite {
641
640
642
641
const TypeConverter *getConverter () const { return converter; }
643
642
644
- bool hasChangedResults () const { return changedResults; }
645
-
646
643
private:
647
644
// / An optional type converter that can be used to materialize conversions
648
645
// / between the new and old values if necessary.
649
646
const TypeConverter *converter;
650
-
651
- // / A boolean flag that indicates whether result types have changed or not.
652
- bool changedResults;
653
647
};
654
648
655
649
class CreateOperationRewrite : public OperationRewrite {
@@ -941,6 +935,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
941
935
// / to modify/access them is invalid rewriter API usage.
942
936
SetVector<Operation *> replacedOps;
943
937
938
+ // / A set of all unresolved materializations.
939
+ DenseSet<Operation *> unresolvedMaterializations;
940
+
944
941
// / The current type converter, or nullptr if no type converter is currently
945
942
// / active.
946
943
const TypeConverter *currentTypeConverter = nullptr ;
@@ -1066,6 +1063,7 @@ void UnresolvedMaterializationRewrite::rollback() {
1066
1063
for (Value input : op->getOperands ())
1067
1064
rewriterImpl.mapping .erase (input);
1068
1065
}
1066
+ rewriterImpl.unresolvedMaterializations .erase (op);
1069
1067
op->erase ();
1070
1068
}
1071
1069
@@ -1347,6 +1345,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1347
1345
builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
1348
1346
auto convertOp =
1349
1347
builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1348
+ unresolvedMaterializations.insert (convertOp);
1350
1349
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1351
1350
return convertOp.getResult (0 );
1352
1351
}
@@ -1379,22 +1378,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1379
1378
assert (newValues.size () == op->getNumResults ());
1380
1379
assert (!ignoredOps.contains (op) && " operation was already replaced" );
1381
1380
1382
- // Track if any of the results changed, e.g. erased and replaced with null.
1383
- bool resultChanged = false ;
1384
-
1385
1381
// Create mappings for each of the new result values.
1386
1382
for (auto [newValue, result] : llvm::zip (newValues, op->getResults ())) {
1387
1383
if (!newValue) {
1388
- resultChanged = true ;
1389
- continue ;
1384
+ // This result was dropped and no replacement value was provided.
1385
+ if (unresolvedMaterializations.contains (op)) {
1386
+ // Do not create another materializations if we are erasing a
1387
+ // materialization.
1388
+ continue ;
1389
+ }
1390
+
1391
+ // Materialize a replacement value "out of thin air".
1392
+ newValue = buildUnresolvedMaterialization (
1393
+ MaterializationKind::Source, computeInsertPoint (result),
1394
+ result.getLoc (), /* inputs=*/ ValueRange (),
1395
+ /* outputType=*/ result.getType (), currentTypeConverter);
1390
1396
}
1397
+
1391
1398
// Remap, and check for any result type changes.
1392
1399
mapping.map (result, newValue);
1393
- resultChanged |= (newValue.getType () != result.getType ());
1394
1400
}
1395
1401
1396
- appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
1397
- resultChanged);
1402
+ appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1398
1403
1399
1404
// Mark this operation and all nested ops as replaced.
1400
1405
op->walk ([&](Operation *op) { replacedOps.insert (op); });
@@ -2359,11 +2364,6 @@ struct OperationConverter {
2359
2364
ConversionPatternRewriterImpl &rewriterImpl,
2360
2365
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2361
2366
2362
- // / Legalize an operation result that was marked as "erased".
2363
- LogicalResult
2364
- legalizeErasedResult (Operation *op, OpResult result,
2365
- ConversionPatternRewriterImpl &rewriterImpl);
2366
-
2367
2367
// / Dialect conversion configuration.
2368
2368
ConversionConfig config;
2369
2369
@@ -2455,77 +2455,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2455
2455
return failure ();
2456
2456
}
2457
2457
2458
- // / Erase all dead unrealized_conversion_cast ops. An op is dead if its results
2459
- // / are not used (transitively) by any op that is not in the given list of
2460
- // / cast ops.
2461
- // /
2462
- // / In particular, this function erases cyclic casts that may be inserted
2463
- // / during the dialect conversion process. E.g.:
2464
- // / %0 = unrealized_conversion_cast(%1)
2465
- // / %1 = unrealized_conversion_cast(%0)
2466
- // Note: This step will become unnecessary when
2467
- // https://github.com/llvm/llvm-project/pull/106760 has been merged.
2468
- static void eraseDeadUnrealizedCasts (
2469
- ArrayRef<UnrealizedConversionCastOp> castOps,
2470
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2471
- // Ops that have already been visited or are currently being visited.
2472
- DenseSet<Operation *> visited;
2473
- // Set of all cast ops for faster lookups.
2474
- DenseSet<Operation *> castOpSet;
2475
- // Set of all cast ops that have been determined to be alive.
2476
- DenseSet<Operation *> live;
2477
-
2478
- for (UnrealizedConversionCastOp op : castOps)
2479
- castOpSet.insert (op);
2480
-
2481
- // Visit a cast operation. Return "true" if the operation is live.
2482
- std::function<bool (Operation *)> visit = [&](Operation *op) -> bool {
2483
- // No need to traverse any IR if the op was already marked as live.
2484
- if (live.contains (op))
2485
- return true ;
2486
-
2487
- // Do not visit ops multiple times. If we find a circle, no live user was
2488
- // found on the current path.
2489
- if (!visited.insert (op).second )
2490
- return false ;
2491
-
2492
- // Visit all users.
2493
- for (Operation *user : op->getUsers ()) {
2494
- // If the user is not an unrealized_conversion_cast op, then the given op
2495
- // is live.
2496
- if (!castOpSet.contains (user)) {
2497
- live.insert (op);
2498
- return true ;
2499
- }
2500
- // Otherwise, it is live if a live op can be reached from one of its
2501
- // users (which must all be unrealized_conversion_cast ops).
2502
- if (visit (user)) {
2503
- live.insert (op);
2504
- return true ;
2505
- }
2506
- }
2507
-
2508
- return false ;
2509
- };
2510
-
2511
- // Visit all cast ops.
2512
- for (UnrealizedConversionCastOp op : castOps) {
2513
- visit (op);
2514
- visited.clear ();
2515
- }
2516
-
2517
- // Erase all cast ops that are dead.
2518
- for (UnrealizedConversionCastOp op : castOps) {
2519
- if (live.contains (op)) {
2520
- if (remainingCastOps)
2521
- remainingCastOps->push_back (op);
2522
- continue ;
2523
- }
2524
- op->dropAllUses ();
2525
- op->erase ();
2526
- }
2527
- }
2528
-
2529
2458
LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
2530
2459
if (ops.empty ())
2531
2460
return success ();
@@ -2584,14 +2513,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2584
2513
// Reconcile all UnrealizedConversionCastOps that were inserted by the
2585
2514
// dialect conversion frameworks. (Not the one that were inserted by
2586
2515
// patterns.)
2587
- SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
2588
- eraseDeadUnrealizedCasts (allCastOps, &remainingCastOps1);
2589
- reconcileUnrealizedCasts (remainingCastOps1, &remainingCastOps2);
2516
+ SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2517
+ reconcileUnrealizedCasts (allCastOps, &remainingCastOps);
2590
2518
2591
2519
// Try to legalize all unresolved materializations.
2592
2520
if (config.buildMaterializations ) {
2593
2521
IRRewriter rewriter (rewriterImpl.context , config.listener );
2594
- for (UnrealizedConversionCastOp castOp : remainingCastOps2 ) {
2522
+ for (UnrealizedConversionCastOp castOp : remainingCastOps ) {
2595
2523
auto it = rewriteMap.find (castOp.getOperation ());
2596
2524
assert (it != rewriteMap.end () && " inconsistent state" );
2597
2525
if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
@@ -2646,30 +2574,22 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2646
2574
for (unsigned i = 0 ; i < rewriterImpl.rewrites .size (); ++i) {
2647
2575
auto *opReplacement =
2648
2576
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites [i].get ());
2649
- if (!opReplacement || !opReplacement-> hasChangedResults () )
2577
+ if (!opReplacement)
2650
2578
continue ;
2651
2579
Operation *op = opReplacement->getOperation ();
2652
2580
for (OpResult result : op->getResults ()) {
2653
- Value newValue = rewriterImpl.mapping .lookupOrNull (result);
2654
-
2655
- // If the operation result was replaced with null, all of the uses of this
2656
- // value should be replaced.
2657
- if (!newValue) {
2658
- if (failed (legalizeErasedResult (op, result, rewriterImpl)))
2659
- return failure ();
2581
+ // If the type of this op result changed and the result is still live,
2582
+ // we need to materialize a conversion.
2583
+ if (rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
2660
2584
continue ;
2661
- }
2662
-
2663
- // Otherwise, check to see if the type of the result changed.
2664
- if (result.getType () == newValue.getType ())
2665
- continue ;
2666
-
2667
2585
Operation *liveUser =
2668
2586
findLiveUserOfReplaced (result, rewriterImpl, inverseMapping);
2669
2587
if (!liveUser)
2670
2588
continue ;
2671
2589
2672
2590
// Legalize this result.
2591
+ Value newValue = rewriterImpl.mapping .lookupOrNull (result);
2592
+ assert (newValue && " replacement value not found" );
2673
2593
Value castValue = rewriterImpl.buildUnresolvedMaterialization (
2674
2594
MaterializationKind::Source, computeInsertPoint (result), op->getLoc (),
2675
2595
/* inputs=*/ newValue, /* outputType=*/ result.getType (),
@@ -2727,25 +2647,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2727
2647
return success ();
2728
2648
}
2729
2649
2730
- LogicalResult OperationConverter::legalizeErasedResult (
2731
- Operation *op, OpResult result,
2732
- ConversionPatternRewriterImpl &rewriterImpl) {
2733
- // If the operation result was replaced with null, all of the uses of this
2734
- // value should be replaced.
2735
- auto liveUserIt = llvm::find_if_not (result.getUsers (), [&](Operation *user) {
2736
- return rewriterImpl.isOpIgnored (user);
2737
- });
2738
- if (liveUserIt != result.user_end ()) {
2739
- InFlightDiagnostic diag = op->emitError (" failed to legalize operation '" )
2740
- << op->getName () << " ' marked as erased" ;
2741
- diag.attachNote (liveUserIt->getLoc ())
2742
- << " found live user of result #" << result.getResultNumber () << " : "
2743
- << *liveUserIt;
2744
- return failure ();
2745
- }
2746
- return success ();
2747
- }
2748
-
2749
2650
// ===----------------------------------------------------------------------===//
2750
2651
// Reconcile Unrealized Casts
2751
2652
// ===----------------------------------------------------------------------===//
0 commit comments