@@ -75,6 +75,10 @@ namespace {
75
75
// / This class wraps a IRMapping to provide recursive lookup
76
76
// / functionality, i.e. we will traverse if the mapped value also has a mapping.
77
77
struct ConversionValueMapping {
78
+ // / Return "true" if an SSA value is mapped to the given value. May return
79
+ // / false positives.
80
+ bool isMappedTo (Value value) const { return mappedTo.contains (value); }
81
+
78
82
// / Lookup the most recently mapped value with the desired type in the
79
83
// / mapping.
80
84
// /
@@ -99,22 +103,18 @@ struct ConversionValueMapping {
99
103
assert (it != oldVal && " inserting cyclic mapping" );
100
104
});
101
105
mapping.map (oldVal, newVal);
106
+ mappedTo.insert (newVal);
102
107
}
103
108
104
109
// / Drop the last mapping for the given value.
105
110
void erase (Value value) { mapping.erase (value); }
106
111
107
- // / Returns the inverse raw value mapping (without recursive query support).
108
- DenseMap<Value, SmallVector<Value>> getInverse () const {
109
- DenseMap<Value, SmallVector<Value>> inverse;
110
- for (auto &it : mapping.getValueMap ())
111
- inverse[it.second ].push_back (it.first );
112
- return inverse;
113
- }
114
-
115
112
private:
116
113
// / Current value mappings.
117
114
IRMapping mapping;
115
+
116
+ // / All SSA values that are mapped to. May contain false positives.
117
+ DenseSet<Value> mappedTo;
118
118
};
119
119
} // namespace
120
120
@@ -434,29 +434,23 @@ class MoveBlockRewrite : public BlockRewrite {
434
434
class BlockTypeConversionRewrite : public BlockRewrite {
435
435
public:
436
436
BlockTypeConversionRewrite (ConversionPatternRewriterImpl &rewriterImpl,
437
- Block *block, Block *origBlock,
438
- const TypeConverter *converter)
437
+ Block *block, Block *origBlock)
439
438
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
440
- origBlock (origBlock), converter(converter) {}
439
+ origBlock (origBlock) {}
441
440
442
441
static bool classof (const IRRewrite *rewrite) {
443
442
return rewrite->getKind () == Kind::BlockTypeConversion;
444
443
}
445
444
446
445
Block *getOrigBlock () const { return origBlock; }
447
446
448
- const TypeConverter *getConverter () const { return converter; }
449
-
450
447
void commit (RewriterBase &rewriter) override ;
451
448
452
449
void rollback () override ;
453
450
454
451
private:
455
452
// / The original block that was requested to have its signature converted.
456
453
Block *origBlock;
457
-
458
- // / The type converter used to convert the arguments.
459
- const TypeConverter *converter;
460
454
};
461
455
462
456
// / Replacing a block argument. This rewrite is not immediately reflected in the
@@ -465,8 +459,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
465
459
class ReplaceBlockArgRewrite : public BlockRewrite {
466
460
public:
467
461
ReplaceBlockArgRewrite (ConversionPatternRewriterImpl &rewriterImpl,
468
- Block *block, BlockArgument arg)
469
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
462
+ Block *block, BlockArgument arg,
463
+ const TypeConverter *converter)
464
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
465
+ converter (converter) {}
470
466
471
467
static bool classof (const IRRewrite *rewrite) {
472
468
return rewrite->getKind () == Kind::ReplaceBlockArg;
@@ -478,6 +474,9 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
478
474
479
475
private:
480
476
BlockArgument arg;
477
+
478
+ // / The current type converter when the block argument was replaced.
479
+ const TypeConverter *converter;
481
480
};
482
481
483
482
// / An operation rewrite.
@@ -627,8 +626,6 @@ class ReplaceOperationRewrite : public OperationRewrite {
627
626
628
627
void cleanup (RewriterBase &rewriter) override ;
629
628
630
- const TypeConverter *getConverter () const { return converter; }
631
-
632
629
private:
633
630
// / An optional type converter that can be used to materialize conversions
634
631
// / between the new and old values if necessary.
@@ -825,6 +822,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
825
822
ValueRange replacements, Value originalValue,
826
823
const TypeConverter *converter);
827
824
825
+ // / Find a replacement value for the given SSA value in the conversion value
826
+ // / mapping. The replacement value must have the same type as the given SSA
827
+ // / value. If there is no replacement value with the correct type, find the
828
+ // / latest replacement value (regardless of the type) and build a source
829
+ // / materialization.
830
+ Value findOrBuildReplacementValue (Value value,
831
+ const TypeConverter *converter);
832
+
828
833
// ===--------------------------------------------------------------------===//
829
834
// Rewriter Notification Hooks
830
835
// ===--------------------------------------------------------------------===//
@@ -970,7 +975,7 @@ void BlockTypeConversionRewrite::rollback() {
970
975
}
971
976
972
977
void ReplaceBlockArgRewrite::commit (RewriterBase &rewriter) {
973
- Value repl = rewriterImpl.mapping . lookupOrNull (arg, arg. getType () );
978
+ Value repl = rewriterImpl.findOrBuildReplacementValue (arg, converter );
974
979
if (!repl)
975
980
return ;
976
981
@@ -999,7 +1004,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
999
1004
// Compute replacement values.
1000
1005
SmallVector<Value> replacements =
1001
1006
llvm::map_to_vector (op->getResults (), [&](OpResult result) {
1002
- return rewriterImpl.mapping . lookupOrNull (result, result. getType () );
1007
+ return rewriterImpl.findOrBuildReplacementValue (result, converter );
1003
1008
});
1004
1009
1005
1010
// Notify the listener that the operation is about to be replaced.
@@ -1069,8 +1074,10 @@ void UnresolvedMaterializationRewrite::rollback() {
1069
1074
void ConversionPatternRewriterImpl::applyRewrites () {
1070
1075
// Commit all rewrites.
1071
1076
IRRewriter rewriter (context, config.listener );
1072
- for (auto &rewrite : rewrites)
1073
- rewrite->commit (rewriter);
1077
+ // Note: New rewrites may be added during the "commit" phase and the
1078
+ // `rewrites` vector may reallocate.
1079
+ for (size_t i = 0 ; i < rewrites.size (); ++i)
1080
+ rewrites[i]->commit (rewriter);
1074
1081
1075
1082
// Clean up all rewrites.
1076
1083
for (auto &rewrite : rewrites)
@@ -1275,7 +1282,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1275
1282
/* inputs=*/ ValueRange (),
1276
1283
/* outputType=*/ origArgType, /* originalType=*/ Type (), converter);
1277
1284
mapping.map (origArg, repl);
1278
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1285
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter );
1279
1286
continue ;
1280
1287
}
1281
1288
@@ -1285,7 +1292,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1285
1292
" invalid to provide a replacement value when the argument isn't "
1286
1293
" dropped" );
1287
1294
mapping.map (origArg, repl);
1288
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1295
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter );
1289
1296
continue ;
1290
1297
}
1291
1298
@@ -1298,10 +1305,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1298
1305
insertNTo1Materialization (
1299
1306
OpBuilder::InsertPoint (newBlock, newBlock->begin ()), origArg.getLoc (),
1300
1307
/* replacements=*/ replArgs, /* outputValue=*/ origArg, converter);
1301
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1308
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter );
1302
1309
}
1303
1310
1304
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter );
1311
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
1305
1312
1306
1313
// Erase the old block. (It is just unlinked for now and will be erased during
1307
1314
// cleanup.)
@@ -1371,6 +1378,42 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
1371
1378
}
1372
1379
}
1373
1380
1381
+ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue (
1382
+ Value value, const TypeConverter *converter) {
1383
+ // Find a replacement value with the same type.
1384
+ Value repl = mapping.lookupOrNull (value, value.getType ());
1385
+ if (repl)
1386
+ return repl;
1387
+
1388
+ // Check if the value is dead. No replacement value is needed in that case.
1389
+ // This is an approximate check that may have false negatives but does not
1390
+ // require computing and traversing an inverse mapping. (We may end up
1391
+ // building source materializations that are never used and that fold away.)
1392
+ if (llvm::all_of (value.getUsers (),
1393
+ [&](Operation *op) { return replacedOps.contains (op); }) &&
1394
+ !mapping.isMappedTo (value))
1395
+ return Value ();
1396
+
1397
+ // No replacement value was found. Get the latest replacement value
1398
+ // (regardless of the type) and build a source materialization to the
1399
+ // original type.
1400
+ repl = mapping.lookupOrNull (value);
1401
+ if (!repl) {
1402
+ // No replacement value is registered in the mapping. This means that the
1403
+ // value is dropped and no longer needed. (If the value were still needed,
1404
+ // a source materialization producing a replacement value "out of thin air"
1405
+ // would have already been created during `replaceOp` or
1406
+ // `applySignatureConversion`.)
1407
+ return Value ();
1408
+ }
1409
+ Value castValue = buildUnresolvedMaterialization (
1410
+ MaterializationKind::Source, computeInsertPoint (repl), value.getLoc (),
1411
+ /* inputs=*/ repl, /* outputType=*/ value.getType (),
1412
+ /* originalType=*/ Type (), converter);
1413
+ mapping.map (value, castValue);
1414
+ return castValue;
1415
+ }
1416
+
1374
1417
// ===----------------------------------------------------------------------===//
1375
1418
// Rewriter Notification Hooks
1376
1419
@@ -1597,7 +1640,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1597
1640
<< " '(in region of '" << parentOp->getName ()
1598
1641
<< " '(" << from.getOwner ()->getParentOp () << " )\n " ;
1599
1642
});
1600
- impl->appendRewrite <ReplaceBlockArgRewrite>(from.getOwner (), from);
1643
+ impl->appendRewrite <ReplaceBlockArgRewrite>(from.getOwner (), from,
1644
+ impl->currentTypeConverter );
1601
1645
impl->mapping .map (impl->mapping .lookupOrDefault (from), to);
1602
1646
}
1603
1647
@@ -2417,10 +2461,6 @@ struct OperationConverter {
2417
2461
// / Converts an operation with the given rewriter.
2418
2462
LogicalResult convert (ConversionPatternRewriter &rewriter, Operation *op);
2419
2463
2420
- // / This method is called after the conversion process to legalize any
2421
- // / remaining artifacts and complete the conversion.
2422
- void finalize (ConversionPatternRewriter &rewriter);
2423
-
2424
2464
// / Dialect conversion configuration.
2425
2465
ConversionConfig config;
2426
2466
@@ -2541,11 +2581,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2541
2581
if (failed (convert (rewriter, op)))
2542
2582
return rewriterImpl.undoRewrites (), failure ();
2543
2583
2544
- // Now that all of the operations have been converted, finalize the conversion
2545
- // process to ensure any lingering conversion artifacts are cleaned up and
2546
- // legalized.
2547
- finalize (rewriter);
2548
-
2549
2584
// After a successful conversion, apply rewrites.
2550
2585
rewriterImpl.applyRewrites ();
2551
2586
@@ -2579,80 +2614,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2579
2614
return success ();
2580
2615
}
2581
2616
2582
- // / Finds a user of the given value, or of any other value that the given value
2583
- // / replaced, that was not replaced in the conversion process.
2584
- static Operation *findLiveUserOfReplaced (
2585
- Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2586
- const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2587
- SmallVector<Value> worklist = {initialValue};
2588
- while (!worklist.empty ()) {
2589
- Value value = worklist.pop_back_val ();
2590
-
2591
- // Walk the users of this value to see if there are any live users that
2592
- // weren't replaced during conversion.
2593
- auto liveUserIt = llvm::find_if_not (value.getUsers (), [&](Operation *user) {
2594
- return rewriterImpl.isOpIgnored (user);
2595
- });
2596
- if (liveUserIt != value.user_end ())
2597
- return *liveUserIt;
2598
- auto mapIt = inverseMapping.find (value);
2599
- if (mapIt != inverseMapping.end ())
2600
- worklist.append (mapIt->second );
2601
- }
2602
- return nullptr ;
2603
- }
2604
-
2605
- // / Helper function that returns the replaced values and the type converter if
2606
- // / the given rewrite object is an "operation replacement" or a "block type
2607
- // / conversion" (which corresponds to a "block replacement"). Otherwise, return
2608
- // / an empty ValueRange and a null type converter pointer.
2609
- static std::pair<ValueRange, const TypeConverter *>
2610
- getReplacedValues (IRRewrite *rewrite) {
2611
- if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2612
- return {opRewrite->getOperation ()->getResults (), opRewrite->getConverter ()};
2613
- if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2614
- return {blockRewrite->getOrigBlock ()->getArguments (),
2615
- blockRewrite->getConverter ()};
2616
- return {};
2617
- }
2618
-
2619
- void OperationConverter::finalize (ConversionPatternRewriter &rewriter) {
2620
- ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2621
- DenseMap<Value, SmallVector<Value>> inverseMapping =
2622
- rewriterImpl.mapping .getInverse ();
2623
-
2624
- // Process requested value replacements.
2625
- for (unsigned i = 0 , e = rewriterImpl.rewrites .size (); i < e; ++i) {
2626
- ValueRange replacedValues;
2627
- const TypeConverter *converter;
2628
- std::tie (replacedValues, converter) =
2629
- getReplacedValues (rewriterImpl.rewrites [i].get ());
2630
- for (Value originalValue : replacedValues) {
2631
- // If the type of this value changed and the value is still live, we need
2632
- // to materialize a conversion.
2633
- if (rewriterImpl.mapping .lookupOrNull (originalValue,
2634
- originalValue.getType ()))
2635
- continue ;
2636
- Operation *liveUser =
2637
- findLiveUserOfReplaced (originalValue, rewriterImpl, inverseMapping);
2638
- if (!liveUser)
2639
- continue ;
2640
-
2641
- // Legalize this value replacement.
2642
- Value newValue = rewriterImpl.mapping .lookupOrNull (originalValue);
2643
- assert (newValue && " replacement value not found" );
2644
- Value castValue = rewriterImpl.buildUnresolvedMaterialization (
2645
- MaterializationKind::Source, computeInsertPoint (newValue),
2646
- originalValue.getLoc (),
2647
- /* inputs=*/ newValue, /* outputType=*/ originalValue.getType (),
2648
- /* originalType=*/ Type (), converter);
2649
- rewriterImpl.mapping .map (originalValue, castValue);
2650
- inverseMapping[castValue].push_back (originalValue);
2651
- llvm::erase (inverseMapping[newValue], originalValue);
2652
- }
2653
- }
2654
- }
2655
-
2656
2617
// ===----------------------------------------------------------------------===//
2657
2618
// Reconcile Unrealized Casts
2658
2619
// ===----------------------------------------------------------------------===//
0 commit comments