@@ -75,10 +75,6 @@ namespace {
75
75
// / This class wraps a IRMapping to provide recursive lookup
76
76
// / functionality, i.e. we will traverse if the mapped value also has a mapping.
77
77
struct ConversionValueMapping {
78
- // / Return "true" if an SSA value is mapped to the given value. May return
79
- // / false positives.
80
- bool isMappedTo (Value value) const { return mappedTo.contains (value); }
81
-
82
78
// / Lookup the most recently mapped value with the desired type in the
83
79
// / mapping.
84
80
// /
@@ -103,18 +99,22 @@ struct ConversionValueMapping {
103
99
assert (it != oldVal && " inserting cyclic mapping" );
104
100
});
105
101
mapping.map (oldVal, newVal);
106
- mappedTo.insert (newVal);
107
102
}
108
103
109
104
// / Drop the last mapping for the given value.
110
105
void erase (Value value) { mapping.erase (value); }
111
106
107
+ // / Returns the inverse raw value mapping (without recursive query support).
108
+ DenseMap<Value, SmallVector<Value>> getInverse () const {
109
+ DenseMap<Value, SmallVector<Value>> inverse;
110
+ for (auto &it : mapping.getValueMap ())
111
+ inverse[it.second ].push_back (it.first );
112
+ return inverse;
113
+ }
114
+
112
115
private:
113
116
// / Current value mappings.
114
117
IRMapping mapping;
115
-
116
- // / All SSA values that are mapped to. May contain false positives.
117
- DenseSet<Value> mappedTo;
118
118
};
119
119
} // namespace
120
120
@@ -434,23 +434,29 @@ class MoveBlockRewrite : public BlockRewrite {
434
434
class BlockTypeConversionRewrite : public BlockRewrite {
435
435
public:
436
436
BlockTypeConversionRewrite (ConversionPatternRewriterImpl &rewriterImpl,
437
- Block *block, Block *origBlock)
437
+ Block *block, Block *origBlock,
438
+ const TypeConverter *converter)
438
439
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
439
- origBlock (origBlock) {}
440
+ origBlock (origBlock), converter(converter) {}
440
441
441
442
static bool classof (const IRRewrite *rewrite) {
442
443
return rewrite->getKind () == Kind::BlockTypeConversion;
443
444
}
444
445
445
446
Block *getOrigBlock () const { return origBlock; }
446
447
448
+ const TypeConverter *getConverter () const { return converter; }
449
+
447
450
void commit (RewriterBase &rewriter) override ;
448
451
449
452
void rollback () override ;
450
453
451
454
private:
452
455
// / The original block that was requested to have its signature converted.
453
456
Block *origBlock;
457
+
458
+ // / The type converter used to convert the arguments.
459
+ const TypeConverter *converter;
454
460
};
455
461
456
462
// / Replacing a block argument. This rewrite is not immediately reflected in the
@@ -459,10 +465,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
459
465
class ReplaceBlockArgRewrite : public BlockRewrite {
460
466
public:
461
467
ReplaceBlockArgRewrite (ConversionPatternRewriterImpl &rewriterImpl,
462
- Block *block, BlockArgument arg,
463
- const TypeConverter *converter)
464
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
465
- converter (converter) {}
468
+ Block *block, BlockArgument arg)
469
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
466
470
467
471
static bool classof (const IRRewrite *rewrite) {
468
472
return rewrite->getKind () == Kind::ReplaceBlockArg;
@@ -474,9 +478,6 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
474
478
475
479
private:
476
480
BlockArgument arg;
477
-
478
- // / The current type converter when the block argument was replaced.
479
- const TypeConverter *converter;
480
481
};
481
482
482
483
// / An operation rewrite.
@@ -626,6 +627,8 @@ class ReplaceOperationRewrite : public OperationRewrite {
626
627
627
628
void cleanup (RewriterBase &rewriter) override ;
628
629
630
+ const TypeConverter *getConverter () const { return converter; }
631
+
629
632
private:
630
633
// / An optional type converter that can be used to materialize conversions
631
634
// / between the new and old values if necessary.
@@ -822,14 +825,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
822
825
ValueRange replacements, Value originalValue,
823
826
const TypeConverter *converter);
824
827
825
- // / Find a replacement value for the given SSA value in the conversion value
826
- // / mapping. The replacement value must have the same type as the given SSA
827
- // / value. If there is no replacement value with the correct type, find the
828
- // / latest replacement value (regardless of the type) and build a source
829
- // / materialization.
830
- Value findOrBuildReplacementValue (Value value,
831
- const TypeConverter *converter);
832
-
833
828
// ===--------------------------------------------------------------------===//
834
829
// Rewriter Notification Hooks
835
830
// ===--------------------------------------------------------------------===//
@@ -975,7 +970,7 @@ void BlockTypeConversionRewrite::rollback() {
975
970
}
976
971
977
972
void ReplaceBlockArgRewrite::commit (RewriterBase &rewriter) {
978
- Value repl = rewriterImpl.findOrBuildReplacementValue (arg, converter );
973
+ Value repl = rewriterImpl.mapping . lookupOrNull (arg, arg. getType () );
979
974
if (!repl)
980
975
return ;
981
976
@@ -1004,7 +999,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1004
999
// Compute replacement values.
1005
1000
SmallVector<Value> replacements =
1006
1001
llvm::map_to_vector (op->getResults (), [&](OpResult result) {
1007
- return rewriterImpl.findOrBuildReplacementValue (result, converter );
1002
+ return rewriterImpl.mapping . lookupOrNull (result, result. getType () );
1008
1003
});
1009
1004
1010
1005
// Notify the listener that the operation is about to be replaced.
@@ -1074,10 +1069,8 @@ void UnresolvedMaterializationRewrite::rollback() {
1074
1069
void ConversionPatternRewriterImpl::applyRewrites () {
1075
1070
// Commit all rewrites.
1076
1071
IRRewriter rewriter (context, config.listener );
1077
- // Note: New rewrites may be added during the "commit" phase and the
1078
- // `rewrites` vector may reallocate.
1079
- for (size_t i = 0 ; i < rewrites.size (); ++i)
1080
- rewrites[i]->commit (rewriter);
1072
+ for (auto &rewrite : rewrites)
1073
+ rewrite->commit (rewriter);
1081
1074
1082
1075
// Clean up all rewrites.
1083
1076
for (auto &rewrite : rewrites)
@@ -1282,7 +1275,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1282
1275
/* inputs=*/ ValueRange (),
1283
1276
/* outputType=*/ origArgType, /* originalType=*/ Type (), converter);
1284
1277
mapping.map (origArg, repl);
1285
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter );
1278
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1286
1279
continue ;
1287
1280
}
1288
1281
@@ -1292,7 +1285,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1292
1285
" invalid to provide a replacement value when the argument isn't "
1293
1286
" dropped" );
1294
1287
mapping.map (origArg, repl);
1295
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter );
1288
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1296
1289
continue ;
1297
1290
}
1298
1291
@@ -1305,10 +1298,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1305
1298
insertNTo1Materialization (
1306
1299
OpBuilder::InsertPoint (newBlock, newBlock->begin ()), origArg.getLoc (),
1307
1300
/* replacements=*/ replArgs, /* outputValue=*/ origArg, converter);
1308
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter );
1301
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1309
1302
}
1310
1303
1311
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
1304
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter );
1312
1305
1313
1306
// Erase the old block. (It is just unlinked for now and will be erased during
1314
1307
// cleanup.)
@@ -1378,41 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
1378
1371
}
1379
1372
}
1380
1373
1381
- Value ConversionPatternRewriterImpl::findOrBuildReplacementValue (
1382
- Value value, const TypeConverter *converter) {
1383
- // Find a replacement value with the same type.
1384
- Value repl = mapping.lookupOrNull (value, value.getType ());
1385
- if (repl)
1386
- return repl;
1387
-
1388
- // Check if the value is dead. No replacement value is needed in that case.
1389
- // This is an approximate check that may have false negatives but does not
1390
- // require computing and traversing an inverse mapping. (We may end up
1391
- // building source materializations that are never used and that fold away.)
1392
- if (llvm::all_of (value.getUsers (),
1393
- [&](Operation *op) { return replacedOps.contains (op); }) &&
1394
- !mapping.isMappedTo (value))
1395
- return Value ();
1396
-
1397
- // No replacement value was found. Get the latest replacement value
1398
- // (regardless of the type) and build a source materialization to the
1399
- // original type.
1400
- repl = mapping.lookupOrNull (value);
1401
- if (!repl) {
1402
- // No replacement value is registered in the mapping. This means that the
1403
- // value is dropped and no longer needed. (If the value were still needed,
1404
- // a source materialization producing a replacement value "out of thin air"
1405
- // would have already been created during `replaceOp` or
1406
- // `applySignatureConversion`.)
1407
- return Value ();
1408
- }
1409
- Value castValue = buildUnresolvedMaterialization (
1410
- MaterializationKind::Source, computeInsertPoint (repl), value.getLoc (),
1411
- /* inputs=*/ repl, /* outputType=*/ value.getType (),
1412
- /* originalType=*/ Type (), converter);
1413
- return castValue;
1414
- }
1415
-
1416
1374
// ===----------------------------------------------------------------------===//
1417
1375
// Rewriter Notification Hooks
1418
1376
@@ -1639,8 +1597,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1639
1597
<< " '(in region of '" << parentOp->getName ()
1640
1598
<< " '(" << from.getOwner ()->getParentOp () << " )\n " ;
1641
1599
});
1642
- impl->appendRewrite <ReplaceBlockArgRewrite>(from.getOwner (), from,
1643
- impl->currentTypeConverter );
1600
+ impl->appendRewrite <ReplaceBlockArgRewrite>(from.getOwner (), from);
1644
1601
impl->mapping .map (impl->mapping .lookupOrDefault (from), to);
1645
1602
}
1646
1603
@@ -2460,6 +2417,10 @@ struct OperationConverter {
2460
2417
// / Converts an operation with the given rewriter.
2461
2418
LogicalResult convert (ConversionPatternRewriter &rewriter, Operation *op);
2462
2419
2420
+ // / This method is called after the conversion process to legalize any
2421
+ // / remaining artifacts and complete the conversion.
2422
+ void finalize (ConversionPatternRewriter &rewriter);
2423
+
2463
2424
// / Dialect conversion configuration.
2464
2425
ConversionConfig config;
2465
2426
@@ -2580,6 +2541,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2580
2541
if (failed (convert (rewriter, op)))
2581
2542
return rewriterImpl.undoRewrites (), failure ();
2582
2543
2544
+ // Now that all of the operations have been converted, finalize the conversion
2545
+ // process to ensure any lingering conversion artifacts are cleaned up and
2546
+ // legalized.
2547
+ finalize (rewriter);
2548
+
2583
2549
// After a successful conversion, apply rewrites.
2584
2550
rewriterImpl.applyRewrites ();
2585
2551
@@ -2613,6 +2579,80 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2613
2579
return success ();
2614
2580
}
2615
2581
2582
+ // / Finds a user of the given value, or of any other value that the given value
2583
+ // / replaced, that was not replaced in the conversion process.
2584
+ static Operation *findLiveUserOfReplaced (
2585
+ Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2586
+ const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2587
+ SmallVector<Value> worklist = {initialValue};
2588
+ while (!worklist.empty ()) {
2589
+ Value value = worklist.pop_back_val ();
2590
+
2591
+ // Walk the users of this value to see if there are any live users that
2592
+ // weren't replaced during conversion.
2593
+ auto liveUserIt = llvm::find_if_not (value.getUsers (), [&](Operation *user) {
2594
+ return rewriterImpl.isOpIgnored (user);
2595
+ });
2596
+ if (liveUserIt != value.user_end ())
2597
+ return *liveUserIt;
2598
+ auto mapIt = inverseMapping.find (value);
2599
+ if (mapIt != inverseMapping.end ())
2600
+ worklist.append (mapIt->second );
2601
+ }
2602
+ return nullptr ;
2603
+ }
2604
+
2605
+ // / Helper function that returns the replaced values and the type converter if
2606
+ // / the given rewrite object is an "operation replacement" or a "block type
2607
+ // / conversion" (which corresponds to a "block replacement"). Otherwise, return
2608
+ // / an empty ValueRange and a null type converter pointer.
2609
+ static std::pair<ValueRange, const TypeConverter *>
2610
+ getReplacedValues (IRRewrite *rewrite) {
2611
+ if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2612
+ return {opRewrite->getOperation ()->getResults (), opRewrite->getConverter ()};
2613
+ if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2614
+ return {blockRewrite->getOrigBlock ()->getArguments (),
2615
+ blockRewrite->getConverter ()};
2616
+ return {};
2617
+ }
2618
+
2619
+ void OperationConverter::finalize (ConversionPatternRewriter &rewriter) {
2620
+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2621
+ DenseMap<Value, SmallVector<Value>> inverseMapping =
2622
+ rewriterImpl.mapping .getInverse ();
2623
+
2624
+ // Process requested value replacements.
2625
+ for (unsigned i = 0 , e = rewriterImpl.rewrites .size (); i < e; ++i) {
2626
+ ValueRange replacedValues;
2627
+ const TypeConverter *converter;
2628
+ std::tie (replacedValues, converter) =
2629
+ getReplacedValues (rewriterImpl.rewrites [i].get ());
2630
+ for (Value originalValue : replacedValues) {
2631
+ // If the type of this value changed and the value is still live, we need
2632
+ // to materialize a conversion.
2633
+ if (rewriterImpl.mapping .lookupOrNull (originalValue,
2634
+ originalValue.getType ()))
2635
+ continue ;
2636
+ Operation *liveUser =
2637
+ findLiveUserOfReplaced (originalValue, rewriterImpl, inverseMapping);
2638
+ if (!liveUser)
2639
+ continue ;
2640
+
2641
+ // Legalize this value replacement.
2642
+ Value newValue = rewriterImpl.mapping .lookupOrNull (originalValue);
2643
+ assert (newValue && " replacement value not found" );
2644
+ Value castValue = rewriterImpl.buildUnresolvedMaterialization (
2645
+ MaterializationKind::Source, computeInsertPoint (newValue),
2646
+ originalValue.getLoc (),
2647
+ /* inputs=*/ newValue, /* outputType=*/ originalValue.getType (),
2648
+ /* originalType=*/ Type (), converter);
2649
+ rewriterImpl.mapping .map (originalValue, castValue);
2650
+ inverseMapping[castValue].push_back (originalValue);
2651
+ llvm::erase (inverseMapping[newValue], originalValue);
2652
+ }
2653
+ }
2654
+ }
2655
+
2616
2656
// ===----------------------------------------------------------------------===//
2617
2657
// Reconcile Unrealized Casts
2618
2658
// ===----------------------------------------------------------------------===//
0 commit comments