@@ -941,6 +941,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
941
941
// / to modify/access them is invalid rewriter API usage.
942
942
SetVector<Operation *> replacedOps;
943
943
944
+ DenseSet<Operation *> unresolvedMaterializations;
945
+
944
946
// / The current type converter, or nullptr if no type converter is currently
945
947
// / active.
946
948
const TypeConverter *currentTypeConverter = nullptr ;
@@ -1066,6 +1068,7 @@ void UnresolvedMaterializationRewrite::rollback() {
1066
1068
for (Value input : op->getOperands ())
1067
1069
rewriterImpl.mapping .erase (input);
1068
1070
}
1071
+ rewriterImpl.unresolvedMaterializations .erase (op);
1069
1072
op->erase ();
1070
1073
}
1071
1074
@@ -1347,6 +1350,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1347
1350
builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
1348
1351
auto convertOp =
1349
1352
builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1353
+ unresolvedMaterializations.insert (convertOp);
1350
1354
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1351
1355
return convertOp.getResult (0 );
1352
1356
}
@@ -1385,9 +1389,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1385
1389
// Create mappings for each of the new result values.
1386
1390
for (auto [newValue, result] : llvm::zip (newValues, op->getResults ())) {
1387
1391
if (!newValue) {
1388
- resultChanged = true ;
1389
- continue ;
1392
+ // This result was dropped and no replacement value was provided.
1393
+ if (unresolvedMaterializations.contains (op)) {
1394
+ // Do not create another materializations if we are erasing a
1395
+ // materialization.
1396
+ resultChanged = true ;
1397
+ continue ;
1398
+ }
1399
+
1400
+ // Materialize a replacement value "out of thin air".
1401
+ newValue = buildUnresolvedMaterialization (
1402
+ MaterializationKind::Source, computeInsertPoint (result),
1403
+ result.getLoc (), /* inputs=*/ ValueRange (),
1404
+ /* outputType=*/ result.getType (), currentTypeConverter);
1390
1405
}
1406
+
1391
1407
// Remap, and check for any result type changes.
1392
1408
mapping.map (result, newValue);
1393
1409
resultChanged |= (newValue.getType () != result.getType ());
@@ -2359,11 +2375,6 @@ struct OperationConverter {
2359
2375
ConversionPatternRewriterImpl &rewriterImpl,
2360
2376
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2361
2377
2362
- // / Legalize an operation result that was marked as "erased".
2363
- LogicalResult
2364
- legalizeErasedResult (Operation *op, OpResult result,
2365
- ConversionPatternRewriterImpl &rewriterImpl);
2366
-
2367
2378
// / Dialect conversion configuration.
2368
2379
ConversionConfig config;
2369
2380
@@ -2455,77 +2466,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2455
2466
return failure ();
2456
2467
}
2457
2468
2458
- // / Erase all dead unrealized_conversion_cast ops. An op is dead if its results
2459
- // / are not used (transitively) by any op that is not in the given list of
2460
- // / cast ops.
2461
- // /
2462
- // / In particular, this function erases cyclic casts that may be inserted
2463
- // / during the dialect conversion process. E.g.:
2464
- // / %0 = unrealized_conversion_cast(%1)
2465
- // / %1 = unrealized_conversion_cast(%0)
2466
- // Note: This step will become unnecessary when
2467
- // https://github.com/llvm/llvm-project/pull/106760 has been merged.
2468
- static void eraseDeadUnrealizedCasts (
2469
- ArrayRef<UnrealizedConversionCastOp> castOps,
2470
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2471
- // Ops that have already been visited or are currently being visited.
2472
- DenseSet<Operation *> visited;
2473
- // Set of all cast ops for faster lookups.
2474
- DenseSet<Operation *> castOpSet;
2475
- // Set of all cast ops that have been determined to be alive.
2476
- DenseSet<Operation *> live;
2477
-
2478
- for (UnrealizedConversionCastOp op : castOps)
2479
- castOpSet.insert (op);
2480
-
2481
- // Visit a cast operation. Return "true" if the operation is live.
2482
- std::function<bool (Operation *)> visit = [&](Operation *op) -> bool {
2483
- // No need to traverse any IR if the op was already marked as live.
2484
- if (live.contains (op))
2485
- return true ;
2486
-
2487
- // Do not visit ops multiple times. If we find a circle, no live user was
2488
- // found on the current path.
2489
- if (!visited.insert (op).second )
2490
- return false ;
2491
-
2492
- // Visit all users.
2493
- for (Operation *user : op->getUsers ()) {
2494
- // If the user is not an unrealized_conversion_cast op, then the given op
2495
- // is live.
2496
- if (!castOpSet.contains (user)) {
2497
- live.insert (op);
2498
- return true ;
2499
- }
2500
- // Otherwise, it is live if a live op can be reached from one of its
2501
- // users (which must all be unrealized_conversion_cast ops).
2502
- if (visit (user)) {
2503
- live.insert (op);
2504
- return true ;
2505
- }
2506
- }
2507
-
2508
- return false ;
2509
- };
2510
-
2511
- // Visit all cast ops.
2512
- for (UnrealizedConversionCastOp op : castOps) {
2513
- visit (op);
2514
- visited.clear ();
2515
- }
2516
-
2517
- // Erase all cast ops that are dead.
2518
- for (UnrealizedConversionCastOp op : castOps) {
2519
- if (live.contains (op)) {
2520
- if (remainingCastOps)
2521
- remainingCastOps->push_back (op);
2522
- continue ;
2523
- }
2524
- op->dropAllUses ();
2525
- op->erase ();
2526
- }
2527
- }
2528
-
2529
2469
LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
2530
2470
if (ops.empty ())
2531
2471
return success ();
@@ -2584,14 +2524,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2584
2524
// Reconcile all UnrealizedConversionCastOps that were inserted by the
2585
2525
// dialect conversion frameworks. (Not the one that were inserted by
2586
2526
// patterns.)
2587
- SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
2588
- eraseDeadUnrealizedCasts (allCastOps, &remainingCastOps1);
2589
- reconcileUnrealizedCasts (remainingCastOps1, &remainingCastOps2);
2527
+ SmallVector<UnrealizedConversionCastOp> remainingCastOps;
2528
+ reconcileUnrealizedCasts (allCastOps, &remainingCastOps);
2590
2529
2591
2530
// Try to legalize all unresolved materializations.
2592
2531
if (config.buildMaterializations ) {
2593
2532
IRRewriter rewriter (rewriterImpl.context , config.listener );
2594
- for (UnrealizedConversionCastOp castOp : remainingCastOps2 ) {
2533
+ for (UnrealizedConversionCastOp castOp : remainingCastOps ) {
2595
2534
auto it = rewriteMap.find (castOp.getOperation ());
2596
2535
assert (it != rewriteMap.end () && " inconsistent state" );
2597
2536
if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
@@ -2650,26 +2589,18 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2650
2589
continue ;
2651
2590
Operation *op = opReplacement->getOperation ();
2652
2591
for (OpResult result : op->getResults ()) {
2653
- Value newValue = rewriterImpl.mapping .lookupOrNull (result);
2654
-
2655
- // If the operation result was replaced with null, all of the uses of this
2656
- // value should be replaced.
2657
- if (!newValue) {
2658
- if (failed (legalizeErasedResult (op, result, rewriterImpl)))
2659
- return failure ();
2660
- continue ;
2661
- }
2662
-
2663
- // Otherwise, check to see if the type of the result changed.
2664
- if (result.getType () == newValue.getType ())
2592
+ // If the type of this op result changed and the result is still live,
2593
+ // we need to materialize a conversion.
2594
+ if (rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
2665
2595
continue ;
2666
-
2667
2596
Operation *liveUser =
2668
2597
findLiveUserOfReplaced (result, rewriterImpl, inverseMapping);
2669
2598
if (!liveUser)
2670
2599
continue ;
2671
2600
2672
2601
// Legalize this result.
2602
+ Value newValue = rewriterImpl.mapping .lookupOrNull (result);
2603
+ assert (newValue && " replacement value not found" );
2673
2604
Value castValue = rewriterImpl.buildUnresolvedMaterialization (
2674
2605
MaterializationKind::Source, computeInsertPoint (result), op->getLoc (),
2675
2606
/* inputs=*/ newValue, /* outputType=*/ result.getType (),
@@ -2727,25 +2658,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2727
2658
return success ();
2728
2659
}
2729
2660
2730
- LogicalResult OperationConverter::legalizeErasedResult (
2731
- Operation *op, OpResult result,
2732
- ConversionPatternRewriterImpl &rewriterImpl) {
2733
- // If the operation result was replaced with null, all of the uses of this
2734
- // value should be replaced.
2735
- auto liveUserIt = llvm::find_if_not (result.getUsers (), [&](Operation *user) {
2736
- return rewriterImpl.isOpIgnored (user);
2737
- });
2738
- if (liveUserIt != result.user_end ()) {
2739
- InFlightDiagnostic diag = op->emitError (" failed to legalize operation '" )
2740
- << op->getName () << " ' marked as erased" ;
2741
- diag.attachNote (liveUserIt->getLoc ())
2742
- << " found live user of result #" << result.getResultNumber () << " : "
2743
- << *liveUserIt;
2744
- return failure ();
2745
- }
2746
- return success ();
2747
- }
2748
-
2749
2661
// ===----------------------------------------------------------------------===//
2750
2662
// Reconcile Unrealized Casts
2751
2663
// ===----------------------------------------------------------------------===//
0 commit comments