Skip to content

Commit 4056d93

Browse files
Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase" (#117094)
Reverts #116934 This commit broke the build.
1 parent aa65473 commit 4056d93

File tree

1 file changed

+112
-72
lines changed

1 file changed

+112
-72
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 112 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,6 @@ namespace {
7575
/// This class wraps a IRMapping to provide recursive lookup
7676
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
7777
struct ConversionValueMapping {
78-
/// Return "true" if an SSA value is mapped to the given value. May return
79-
/// false positives.
80-
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
81-
8278
/// Lookup the most recently mapped value with the desired type in the
8379
/// mapping.
8480
///
@@ -103,18 +99,22 @@ struct ConversionValueMapping {
10399
assert(it != oldVal && "inserting cyclic mapping");
104100
});
105101
mapping.map(oldVal, newVal);
106-
mappedTo.insert(newVal);
107102
}
108103

109104
/// Drop the last mapping for the given value.
110105
void erase(Value value) { mapping.erase(value); }
111106

107+
/// Returns the inverse raw value mapping (without recursive query support).
108+
DenseMap<Value, SmallVector<Value>> getInverse() const {
109+
DenseMap<Value, SmallVector<Value>> inverse;
110+
for (auto &it : mapping.getValueMap())
111+
inverse[it.second].push_back(it.first);
112+
return inverse;
113+
}
114+
112115
private:
113116
/// Current value mappings.
114117
IRMapping mapping;
115-
116-
/// All SSA values that are mapped to. May contain false positives.
117-
DenseSet<Value> mappedTo;
118118
};
119119
} // namespace
120120

@@ -434,23 +434,29 @@ class MoveBlockRewrite : public BlockRewrite {
434434
class BlockTypeConversionRewrite : public BlockRewrite {
435435
public:
436436
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
437-
Block *block, Block *origBlock)
437+
Block *block, Block *origBlock,
438+
const TypeConverter *converter)
438439
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
439-
origBlock(origBlock) {}
440+
origBlock(origBlock), converter(converter) {}
440441

441442
static bool classof(const IRRewrite *rewrite) {
442443
return rewrite->getKind() == Kind::BlockTypeConversion;
443444
}
444445

445446
Block *getOrigBlock() const { return origBlock; }
446447

448+
const TypeConverter *getConverter() const { return converter; }
449+
447450
void commit(RewriterBase &rewriter) override;
448451

449452
void rollback() override;
450453

451454
private:
452455
/// The original block that was requested to have its signature converted.
453456
Block *origBlock;
457+
458+
/// The type converter used to convert the arguments.
459+
const TypeConverter *converter;
454460
};
455461

456462
/// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -459,10 +465,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
459465
class ReplaceBlockArgRewrite : public BlockRewrite {
460466
public:
461467
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
462-
Block *block, BlockArgument arg,
463-
const TypeConverter *converter)
464-
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
465-
converter(converter) {}
468+
Block *block, BlockArgument arg)
469+
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
466470

467471
static bool classof(const IRRewrite *rewrite) {
468472
return rewrite->getKind() == Kind::ReplaceBlockArg;
@@ -474,9 +478,6 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
474478

475479
private:
476480
BlockArgument arg;
477-
478-
/// The current type converter when the block argument was replaced.
479-
const TypeConverter *converter;
480481
};
481482

482483
/// An operation rewrite.
@@ -626,6 +627,8 @@ class ReplaceOperationRewrite : public OperationRewrite {
626627

627628
void cleanup(RewriterBase &rewriter) override;
628629

630+
const TypeConverter *getConverter() const { return converter; }
631+
629632
private:
630633
/// An optional type converter that can be used to materialize conversions
631634
/// between the new and old values if necessary.
@@ -822,14 +825,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
822825
ValueRange replacements, Value originalValue,
823826
const TypeConverter *converter);
824827

825-
/// Find a replacement value for the given SSA value in the conversion value
826-
/// mapping. The replacement value must have the same type as the given SSA
827-
/// value. If there is no replacement value with the correct type, find the
828-
/// latest replacement value (regardless of the type) and build a source
829-
/// materialization.
830-
Value findOrBuildReplacementValue(Value value,
831-
const TypeConverter *converter);
832-
833828
//===--------------------------------------------------------------------===//
834829
// Rewriter Notification Hooks
835830
//===--------------------------------------------------------------------===//
@@ -975,7 +970,7 @@ void BlockTypeConversionRewrite::rollback() {
975970
}
976971

977972
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
978-
Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
973+
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
979974
if (!repl)
980975
return;
981976

@@ -1004,7 +999,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1004999
// Compute replacement values.
10051000
SmallVector<Value> replacements =
10061001
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1007-
return rewriterImpl.findOrBuildReplacementValue(result, converter);
1002+
return rewriterImpl.mapping.lookupOrNull(result, result.getType());
10081003
});
10091004

10101005
// Notify the listener that the operation is about to be replaced.
@@ -1074,10 +1069,8 @@ void UnresolvedMaterializationRewrite::rollback() {
10741069
void ConversionPatternRewriterImpl::applyRewrites() {
10751070
// Commit all rewrites.
10761071
IRRewriter rewriter(context, config.listener);
1077-
// Note: New rewrites may be added during the "commit" phase and the
1078-
// `rewrites` vector may reallocate.
1079-
for (size_t i = 0; i < rewrites.size(); ++i)
1080-
rewrites[i]->commit(rewriter);
1072+
for (auto &rewrite : rewrites)
1073+
rewrite->commit(rewriter);
10811074

10821075
// Clean up all rewrites.
10831076
for (auto &rewrite : rewrites)
@@ -1282,7 +1275,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12821275
/*inputs=*/ValueRange(),
12831276
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
12841277
mapping.map(origArg, repl);
1285-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1278+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
12861279
continue;
12871280
}
12881281

@@ -1292,7 +1285,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12921285
"invalid to provide a replacement value when the argument isn't "
12931286
"dropped");
12941287
mapping.map(origArg, repl);
1295-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1288+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
12961289
continue;
12971290
}
12981291

@@ -1305,10 +1298,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13051298
insertNTo1Materialization(
13061299
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
13071300
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1308-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1301+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13091302
}
13101303

1311-
appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
1304+
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
13121305

13131306
// Erase the old block. (It is just unlinked for now and will be erased during
13141307
// cleanup.)
@@ -1378,41 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
13781371
}
13791372
}
13801373

1381-
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1382-
Value value, const TypeConverter *converter) {
1383-
// Find a replacement value with the same type.
1384-
Value repl = mapping.lookupOrNull(value, value.getType());
1385-
if (repl)
1386-
return repl;
1387-
1388-
// Check if the value is dead. No replacement value is needed in that case.
1389-
// This is an approximate check that may have false negatives but does not
1390-
// require computing and traversing an inverse mapping. (We may end up
1391-
// building source materializations that are never used and that fold away.)
1392-
if (llvm::all_of(value.getUsers(),
1393-
[&](Operation *op) { return replacedOps.contains(op); }) &&
1394-
!mapping.isMappedTo(value))
1395-
return Value();
1396-
1397-
// No replacement value was found. Get the latest replacement value
1398-
// (regardless of the type) and build a source materialization to the
1399-
// original type.
1400-
repl = mapping.lookupOrNull(value);
1401-
if (!repl) {
1402-
// No replacement value is registered in the mapping. This means that the
1403-
// value is dropped and no longer needed. (If the value were still needed,
1404-
// a source materialization producing a replacement value "out of thin air"
1405-
// would have already been created during `replaceOp` or
1406-
// `applySignatureConversion`.)
1407-
return Value();
1408-
}
1409-
Value castValue = buildUnresolvedMaterialization(
1410-
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
1411-
/*inputs=*/repl, /*outputType=*/value.getType(),
1412-
/*originalType=*/Type(), converter);
1413-
return castValue;
1414-
}
1415-
14161374
//===----------------------------------------------------------------------===//
14171375
// Rewriter Notification Hooks
14181376

@@ -1639,8 +1597,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
16391597
<< "'(in region of '" << parentOp->getName()
16401598
<< "'(" << from.getOwner()->getParentOp() << ")\n";
16411599
});
1642-
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1643-
impl->currentTypeConverter);
1600+
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
16441601
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
16451602
}
16461603

@@ -2460,6 +2417,10 @@ struct OperationConverter {
24602417
/// Converts an operation with the given rewriter.
24612418
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
24622419

2420+
/// This method is called after the conversion process to legalize any
2421+
/// remaining artifacts and complete the conversion.
2422+
void finalize(ConversionPatternRewriter &rewriter);
2423+
24632424
/// Dialect conversion configuration.
24642425
ConversionConfig config;
24652426

@@ -2580,6 +2541,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25802541
if (failed(convert(rewriter, op)))
25812542
return rewriterImpl.undoRewrites(), failure();
25822543

2544+
// Now that all of the operations have been converted, finalize the conversion
2545+
// process to ensure any lingering conversion artifacts are cleaned up and
2546+
// legalized.
2547+
finalize(rewriter);
2548+
25832549
// After a successful conversion, apply rewrites.
25842550
rewriterImpl.applyRewrites();
25852551

@@ -2613,6 +2579,80 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
26132579
return success();
26142580
}
26152581

2582+
/// Finds a user of the given value, or of any other value that the given value
2583+
/// replaced, that was not replaced in the conversion process.
2584+
static Operation *findLiveUserOfReplaced(
2585+
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2586+
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2587+
SmallVector<Value> worklist = {initialValue};
2588+
while (!worklist.empty()) {
2589+
Value value = worklist.pop_back_val();
2590+
2591+
// Walk the users of this value to see if there are any live users that
2592+
// weren't replaced during conversion.
2593+
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2594+
return rewriterImpl.isOpIgnored(user);
2595+
});
2596+
if (liveUserIt != value.user_end())
2597+
return *liveUserIt;
2598+
auto mapIt = inverseMapping.find(value);
2599+
if (mapIt != inverseMapping.end())
2600+
worklist.append(mapIt->second);
2601+
}
2602+
return nullptr;
2603+
}
2604+
2605+
/// Helper function that returns the replaced values and the type converter if
2606+
/// the given rewrite object is an "operation replacement" or a "block type
2607+
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
2608+
/// an empty ValueRange and a null type converter pointer.
2609+
static std::pair<ValueRange, const TypeConverter *>
2610+
getReplacedValues(IRRewrite *rewrite) {
2611+
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2612+
return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2613+
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2614+
return {blockRewrite->getOrigBlock()->getArguments(),
2615+
blockRewrite->getConverter()};
2616+
return {};
2617+
}
2618+
2619+
void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2620+
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2621+
DenseMap<Value, SmallVector<Value>> inverseMapping =
2622+
rewriterImpl.mapping.getInverse();
2623+
2624+
// Process requested value replacements.
2625+
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2626+
ValueRange replacedValues;
2627+
const TypeConverter *converter;
2628+
std::tie(replacedValues, converter) =
2629+
getReplacedValues(rewriterImpl.rewrites[i].get());
2630+
for (Value originalValue : replacedValues) {
2631+
// If the type of this value changed and the value is still live, we need
2632+
// to materialize a conversion.
2633+
if (rewriterImpl.mapping.lookupOrNull(originalValue,
2634+
originalValue.getType()))
2635+
continue;
2636+
Operation *liveUser =
2637+
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
2638+
if (!liveUser)
2639+
continue;
2640+
2641+
// Legalize this value replacement.
2642+
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
2643+
assert(newValue && "replacement value not found");
2644+
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2645+
MaterializationKind::Source, computeInsertPoint(newValue),
2646+
originalValue.getLoc(),
2647+
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2648+
/*originalType=*/Type(), converter);
2649+
rewriterImpl.mapping.map(originalValue, castValue);
2650+
inverseMapping[castValue].push_back(originalValue);
2651+
llvm::erase(inverseMapping[newValue], originalValue);
2652+
}
2653+
}
2654+
}
2655+
26162656
//===----------------------------------------------------------------------===//
26172657
// Reconcile Unrealized Casts
26182658
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)