Skip to content

Commit 0fd4cb8

Browse files
[mlir][Transforms] Dialect conversion: Unify materialization of value replacements
PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`). This PR merges the two functions and moves the implementation directly into `finalize`. This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterates over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed.
1 parent 066359e commit 0fd4cb8

File tree

1 file changed

+42
-92
lines changed

1 file changed

+42
-92
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 42 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,17 +2336,6 @@ struct OperationConverter {
23362336
/// remaining artifacts and complete the conversion.
23372337
LogicalResult finalize(ConversionPatternRewriter &rewriter);
23382338

2339-
/// Legalize the types of converted block arguments.
2340-
LogicalResult
2341-
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2342-
ConversionPatternRewriterImpl &rewriterImpl);
2343-
2344-
/// Legalize the types of converted op results.
2345-
LogicalResult legalizeConvertedOpResultTypes(
2346-
ConversionPatternRewriter &rewriter,
2347-
ConversionPatternRewriterImpl &rewriterImpl,
2348-
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2349-
23502339
/// Dialect conversion configuration.
23512340
ConversionConfig config;
23522341

@@ -2510,19 +2499,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25102499
return success();
25112500
}
25122501

2513-
LogicalResult
2514-
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2515-
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2516-
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2517-
return failure();
2518-
DenseMap<Value, SmallVector<Value>> inverseMapping =
2519-
rewriterImpl.mapping.getInverse();
2520-
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
2521-
inverseMapping)))
2522-
return failure();
2523-
return success();
2524-
}
2525-
25262502
/// Finds a user of the given value, or of any other value that the given value
25272503
/// replaced, that was not replaced in the conversion process.
25282504
static Operation *findLiveUserOfReplaced(
@@ -2546,87 +2522,61 @@ static Operation *findLiveUserOfReplaced(
25462522
return nullptr;
25472523
}
25482524

2549-
LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2550-
ConversionPatternRewriter &rewriter,
2551-
ConversionPatternRewriterImpl &rewriterImpl,
2552-
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2553-
// Process requested operation replacements.
2554-
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2555-
auto *opReplacement =
2556-
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
2557-
if (!opReplacement)
2558-
continue;
2559-
Operation *op = opReplacement->getOperation();
2560-
for (OpResult result : op->getResults()) {
2561-
// If the type of this op result changed and the result is still live,
2562-
// we need to materialize a conversion.
2563-
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
2525+
/// Helper function that returns the replaced values and the type converter if
2526+
/// the given rewrite object is an "operation replacement" or a "block type
2527+
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
2528+
/// an empty ValueRange and a null type converter pointer.
2529+
static std::pair<ValueRange, const TypeConverter *>
2530+
getReplacedValues(IRRewrite *rewrite) {
2531+
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2532+
return std::make_pair(opRewrite->getOperation()->getResults(),
2533+
opRewrite->getConverter());
2534+
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2535+
return std::make_pair(blockRewrite->getOrigBlock()->getArguments(),
2536+
blockRewrite->getConverter());
2537+
return std::make_pair(ValueRange(), nullptr);
2538+
}
2539+
2540+
LogicalResult
2541+
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2542+
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2543+
DenseMap<Value, SmallVector<Value>> inverseMapping =
2544+
rewriterImpl.mapping.getInverse();
2545+
2546+
// Process requested value replacements.
2547+
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2548+
ValueRange replacedValues;
2549+
const TypeConverter *converter;
2550+
std::tie(replacedValues, converter) =
2551+
getReplacedValues(rewriterImpl.rewrites[i].get());
2552+
for (Value originalValue : replacedValues) {
2553+
// If the type of this value changed and the value is still live, we need
2554+
// to materialize a conversion.
2555+
if (rewriterImpl.mapping.lookupOrNull(originalValue,
2556+
originalValue.getType()))
25642557
continue;
25652558
Operation *liveUser =
2566-
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2559+
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
25672560
if (!liveUser)
25682561
continue;
25692562

2570-
// Legalize this result.
2571-
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2563+
// Legalize this value replacement.
2564+
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
25722565
assert(newValue && "replacement value not found");
25732566
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2574-
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
2575-
/*inputs=*/newValue, /*outputType=*/result.getType(),
2576-
opReplacement->getConverter());
2577-
rewriterImpl.mapping.map(result, castValue);
2578-
inverseMapping[castValue].push_back(result);
2579-
llvm::erase(inverseMapping[newValue], result);
2567+
MaterializationKind::Source, computeInsertPoint(newValue),
2568+
originalValue.getLoc(),
2569+
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2570+
converter);
2571+
rewriterImpl.mapping.map(originalValue, castValue);
2572+
inverseMapping[castValue].push_back(originalValue);
2573+
llvm::erase(inverseMapping[newValue], originalValue);
25802574
}
25812575
}
25822576

25832577
return success();
25842578
}
25852579

2586-
LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2587-
ConversionPatternRewriter &rewriter,
2588-
ConversionPatternRewriterImpl &rewriterImpl) {
2589-
// Functor used to check if all users of a value will be dead after
2590-
// conversion.
2591-
// TODO: This should probably query the inverse mapping, same as in
2592-
// `legalizeConvertedOpResultTypes`.
2593-
auto findLiveUser = [&](Value val) {
2594-
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2595-
return rewriterImpl.isOpIgnored(user);
2596-
});
2597-
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2598-
};
2599-
// Note: `rewrites` may be reallocated as the loop is running.
2600-
for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2601-
++i) {
2602-
auto &rewrite = rewriterImpl.rewrites[i];
2603-
if (auto *blockTypeConversionRewrite =
2604-
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
2605-
// Process the remapping for each of the original arguments.
2606-
for (Value origArg :
2607-
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
2608-
// If the type of this argument changed and the argument is still live,
2609-
// we need to materialize a conversion.
2610-
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
2611-
continue;
2612-
Operation *liveUser = findLiveUser(origArg);
2613-
if (!liveUser)
2614-
continue;
2615-
2616-
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
2617-
assert(replacementValue && "replacement value not found");
2618-
Value repl = rewriterImpl.buildUnresolvedMaterialization(
2619-
MaterializationKind::Source, computeInsertPoint(replacementValue),
2620-
origArg.getLoc(), /*inputs=*/replacementValue,
2621-
/*outputType=*/origArg.getType(),
2622-
blockTypeConversionRewrite->getConverter());
2623-
rewriterImpl.mapping.map(origArg, repl);
2624-
}
2625-
}
2626-
}
2627-
return success();
2628-
}
2629-
26302580
//===----------------------------------------------------------------------===//
26312581
// Reconcile Unrealized Casts
26322582
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)