Skip to content

Commit 3761b67

Browse files
[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (#117097)
This is a re-upload of #116934, which was reverted. The dialect conversion driver has three phases: - **Create** `IRRewrite` objects as the IR is traversed. - **Finalize** `IRRewrite` objects. During this phase, source materializations for mismatching value types are created. (E.g., when `Value` is replaced with a `Value` of different type, but there is a user of the original value that was not modified because it is already legal.) - **Commit** `IRRewrite` objects. During this phase, all remaining IR modifications are materialized. In particular, SSA values are actually being replaced during this phase. This commit removes the "finalize" phase. This simplifies the code base a bit and avoids one traversal over the `IRRewrite` stack. Source materializations are now built during the "commit" phase, right before an SSA value is being replaced. This commit also removes the "inverse mapping" of the conversion value mapping, which was used to predict if an SSA value will be dead at the end of the conversion. This check is replaced with an approximate check that does not require an inverse mapping. (A false positive for `v` can occur if another value `v2` is mapped to `v` and `v2` turns out to be dead at the end of the conversion. This case is not expected to happen very often.) This reduces the complexity of the driver a bit and removes one potential source of bugs. (There have been bugs in the usage of the inverse mapping in the past.) `BlockTypeConversionRewrite` no longer stores a pointer to the type converter. This pointer is now stored in `ReplaceBlockArgRewrite`. This commit is in preparation of merging the 1:1 and 1:N dialect conversion driver. It simplifies the upcoming changes around the conversion value mapping. (API surface of the conversion value mapping is reduced.)
1 parent f84fc44 commit 3761b67

File tree

1 file changed

+73
-112
lines changed

1 file changed

+73
-112
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 73 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ namespace {
7575
/// This class wraps a IRMapping to provide recursive lookup
7676
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
7777
struct ConversionValueMapping {
78+
/// Return "true" if an SSA value is mapped to the given value. May return
79+
/// false positives.
80+
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
81+
7882
/// Lookup the most recently mapped value with the desired type in the
7983
/// mapping.
8084
///
@@ -99,22 +103,18 @@ struct ConversionValueMapping {
99103
assert(it != oldVal && "inserting cyclic mapping");
100104
});
101105
mapping.map(oldVal, newVal);
106+
mappedTo.insert(newVal);
102107
}
103108

104109
/// Drop the last mapping for the given value.
105110
void erase(Value value) { mapping.erase(value); }
106111

107-
/// Returns the inverse raw value mapping (without recursive query support).
108-
DenseMap<Value, SmallVector<Value>> getInverse() const {
109-
DenseMap<Value, SmallVector<Value>> inverse;
110-
for (auto &it : mapping.getValueMap())
111-
inverse[it.second].push_back(it.first);
112-
return inverse;
113-
}
114-
115112
private:
116113
/// Current value mappings.
117114
IRMapping mapping;
115+
116+
/// All SSA values that are mapped to. May contain false positives.
117+
DenseSet<Value> mappedTo;
118118
};
119119
} // namespace
120120

@@ -434,29 +434,23 @@ class MoveBlockRewrite : public BlockRewrite {
434434
class BlockTypeConversionRewrite : public BlockRewrite {
435435
public:
436436
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
437-
Block *block, Block *origBlock,
438-
const TypeConverter *converter)
437+
Block *block, Block *origBlock)
439438
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
440-
origBlock(origBlock), converter(converter) {}
439+
origBlock(origBlock) {}
441440

442441
static bool classof(const IRRewrite *rewrite) {
443442
return rewrite->getKind() == Kind::BlockTypeConversion;
444443
}
445444

446445
Block *getOrigBlock() const { return origBlock; }
447446

448-
const TypeConverter *getConverter() const { return converter; }
449-
450447
void commit(RewriterBase &rewriter) override;
451448

452449
void rollback() override;
453450

454451
private:
455452
/// The original block that was requested to have its signature converted.
456453
Block *origBlock;
457-
458-
/// The type converter used to convert the arguments.
459-
const TypeConverter *converter;
460454
};
461455

462456
/// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -465,8 +459,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
465459
class ReplaceBlockArgRewrite : public BlockRewrite {
466460
public:
467461
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
468-
Block *block, BlockArgument arg)
469-
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
462+
Block *block, BlockArgument arg,
463+
const TypeConverter *converter)
464+
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
465+
converter(converter) {}
470466

471467
static bool classof(const IRRewrite *rewrite) {
472468
return rewrite->getKind() == Kind::ReplaceBlockArg;
@@ -478,6 +474,9 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
478474

479475
private:
480476
BlockArgument arg;
477+
478+
/// The current type converter when the block argument was replaced.
479+
const TypeConverter *converter;
481480
};
482481

483482
/// An operation rewrite.
@@ -627,8 +626,6 @@ class ReplaceOperationRewrite : public OperationRewrite {
627626

628627
void cleanup(RewriterBase &rewriter) override;
629628

630-
const TypeConverter *getConverter() const { return converter; }
631-
632629
private:
633630
/// An optional type converter that can be used to materialize conversions
634631
/// between the new and old values if necessary.
@@ -825,6 +822,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
825822
ValueRange replacements, Value originalValue,
826823
const TypeConverter *converter);
827824

825+
/// Find a replacement value for the given SSA value in the conversion value
826+
/// mapping. The replacement value must have the same type as the given SSA
827+
/// value. If there is no replacement value with the correct type, find the
828+
/// latest replacement value (regardless of the type) and build a source
829+
/// materialization.
830+
Value findOrBuildReplacementValue(Value value,
831+
const TypeConverter *converter);
832+
828833
//===--------------------------------------------------------------------===//
829834
// Rewriter Notification Hooks
830835
//===--------------------------------------------------------------------===//
@@ -970,7 +975,7 @@ void BlockTypeConversionRewrite::rollback() {
970975
}
971976

972977
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
973-
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
978+
Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
974979
if (!repl)
975980
return;
976981

@@ -999,7 +1004,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
9991004
// Compute replacement values.
10001005
SmallVector<Value> replacements =
10011006
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1002-
return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1007+
return rewriterImpl.findOrBuildReplacementValue(result, converter);
10031008
});
10041009

10051010
// Notify the listener that the operation is about to be replaced.
@@ -1069,8 +1074,10 @@ void UnresolvedMaterializationRewrite::rollback() {
10691074
void ConversionPatternRewriterImpl::applyRewrites() {
10701075
// Commit all rewrites.
10711076
IRRewriter rewriter(context, config.listener);
1072-
for (auto &rewrite : rewrites)
1073-
rewrite->commit(rewriter);
1077+
// Note: New rewrites may be added during the "commit" phase and the
1078+
// `rewrites` vector may reallocate.
1079+
for (size_t i = 0; i < rewrites.size(); ++i)
1080+
rewrites[i]->commit(rewriter);
10741081

10751082
// Clean up all rewrites.
10761083
for (auto &rewrite : rewrites)
@@ -1275,7 +1282,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12751282
/*inputs=*/ValueRange(),
12761283
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
12771284
mapping.map(origArg, repl);
1278-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1285+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
12791286
continue;
12801287
}
12811288

@@ -1285,7 +1292,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12851292
"invalid to provide a replacement value when the argument isn't "
12861293
"dropped");
12871294
mapping.map(origArg, repl);
1288-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1295+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
12891296
continue;
12901297
}
12911298

@@ -1298,10 +1305,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12981305
insertNTo1Materialization(
12991306
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
13001307
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1301-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1308+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
13021309
}
13031310

1304-
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1311+
appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
13051312

13061313
// Erase the old block. (It is just unlinked for now and will be erased during
13071314
// cleanup.)
@@ -1371,6 +1378,42 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
13711378
}
13721379
}
13731380

1381+
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1382+
Value value, const TypeConverter *converter) {
1383+
// Find a replacement value with the same type.
1384+
Value repl = mapping.lookupOrNull(value, value.getType());
1385+
if (repl)
1386+
return repl;
1387+
1388+
// Check if the value is dead. No replacement value is needed in that case.
1389+
// This is an approximate check that may have false negatives but does not
1390+
// require computing and traversing an inverse mapping. (We may end up
1391+
// building source materializations that are never used and that fold away.)
1392+
if (llvm::all_of(value.getUsers(),
1393+
[&](Operation *op) { return replacedOps.contains(op); }) &&
1394+
!mapping.isMappedTo(value))
1395+
return Value();
1396+
1397+
// No replacement value was found. Get the latest replacement value
1398+
// (regardless of the type) and build a source materialization to the
1399+
// original type.
1400+
repl = mapping.lookupOrNull(value);
1401+
if (!repl) {
1402+
// No replacement value is registered in the mapping. This means that the
1403+
// value is dropped and no longer needed. (If the value were still needed,
1404+
// a source materialization producing a replacement value "out of thin air"
1405+
// would have already been created during `replaceOp` or
1406+
// `applySignatureConversion`.)
1407+
return Value();
1408+
}
1409+
Value castValue = buildUnresolvedMaterialization(
1410+
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
1411+
/*inputs=*/repl, /*outputType=*/value.getType(),
1412+
/*originalType=*/Type(), converter);
1413+
mapping.map(value, castValue);
1414+
return castValue;
1415+
}
1416+
13741417
//===----------------------------------------------------------------------===//
13751418
// Rewriter Notification Hooks
13761419

@@ -1597,7 +1640,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
15971640
<< "'(in region of '" << parentOp->getName()
15981641
<< "'(" << from.getOwner()->getParentOp() << ")\n";
15991642
});
1600-
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1643+
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1644+
impl->currentTypeConverter);
16011645
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
16021646
}
16031647

@@ -2417,10 +2461,6 @@ struct OperationConverter {
24172461
/// Converts an operation with the given rewriter.
24182462
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
24192463

2420-
/// This method is called after the conversion process to legalize any
2421-
/// remaining artifacts and complete the conversion.
2422-
void finalize(ConversionPatternRewriter &rewriter);
2423-
24242464
/// Dialect conversion configuration.
24252465
ConversionConfig config;
24262466

@@ -2541,11 +2581,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25412581
if (failed(convert(rewriter, op)))
25422582
return rewriterImpl.undoRewrites(), failure();
25432583

2544-
// Now that all of the operations have been converted, finalize the conversion
2545-
// process to ensure any lingering conversion artifacts are cleaned up and
2546-
// legalized.
2547-
finalize(rewriter);
2548-
25492584
// After a successful conversion, apply rewrites.
25502585
rewriterImpl.applyRewrites();
25512586

@@ -2579,80 +2614,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25792614
return success();
25802615
}
25812616

2582-
/// Finds a user of the given value, or of any other value that the given value
2583-
/// replaced, that was not replaced in the conversion process.
2584-
static Operation *findLiveUserOfReplaced(
2585-
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2586-
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2587-
SmallVector<Value> worklist = {initialValue};
2588-
while (!worklist.empty()) {
2589-
Value value = worklist.pop_back_val();
2590-
2591-
// Walk the users of this value to see if there are any live users that
2592-
// weren't replaced during conversion.
2593-
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2594-
return rewriterImpl.isOpIgnored(user);
2595-
});
2596-
if (liveUserIt != value.user_end())
2597-
return *liveUserIt;
2598-
auto mapIt = inverseMapping.find(value);
2599-
if (mapIt != inverseMapping.end())
2600-
worklist.append(mapIt->second);
2601-
}
2602-
return nullptr;
2603-
}
2604-
2605-
/// Helper function that returns the replaced values and the type converter if
2606-
/// the given rewrite object is an "operation replacement" or a "block type
2607-
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
2608-
/// an empty ValueRange and a null type converter pointer.
2609-
static std::pair<ValueRange, const TypeConverter *>
2610-
getReplacedValues(IRRewrite *rewrite) {
2611-
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2612-
return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2613-
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2614-
return {blockRewrite->getOrigBlock()->getArguments(),
2615-
blockRewrite->getConverter()};
2616-
return {};
2617-
}
2618-
2619-
void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2620-
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2621-
DenseMap<Value, SmallVector<Value>> inverseMapping =
2622-
rewriterImpl.mapping.getInverse();
2623-
2624-
// Process requested value replacements.
2625-
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2626-
ValueRange replacedValues;
2627-
const TypeConverter *converter;
2628-
std::tie(replacedValues, converter) =
2629-
getReplacedValues(rewriterImpl.rewrites[i].get());
2630-
for (Value originalValue : replacedValues) {
2631-
// If the type of this value changed and the value is still live, we need
2632-
// to materialize a conversion.
2633-
if (rewriterImpl.mapping.lookupOrNull(originalValue,
2634-
originalValue.getType()))
2635-
continue;
2636-
Operation *liveUser =
2637-
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
2638-
if (!liveUser)
2639-
continue;
2640-
2641-
// Legalize this value replacement.
2642-
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
2643-
assert(newValue && "replacement value not found");
2644-
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2645-
MaterializationKind::Source, computeInsertPoint(newValue),
2646-
originalValue.getLoc(),
2647-
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2648-
/*originalType=*/Type(), converter);
2649-
rewriterImpl.mapping.map(originalValue, castValue);
2650-
inverseMapping[castValue].push_back(originalValue);
2651-
llvm::erase(inverseMapping[newValue], originalValue);
2652-
}
2653-
}
2654-
}
2655-
26562617
//===----------------------------------------------------------------------===//
26572618
// Reconcile Unrealized Casts
26582619
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)