Skip to content

[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase #117097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 73 additions & 112 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ namespace {
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
struct ConversionValueMapping {
/// Return "true" if an SSA value is mapped to the given value. May return
/// false positives.
bool isMappedTo(Value value) const { return mappedTo.contains(value); }

/// Lookup the most recently mapped value with the desired type in the
/// mapping.
///
Expand All @@ -99,22 +103,18 @@ struct ConversionValueMapping {
assert(it != oldVal && "inserting cyclic mapping");
});
mapping.map(oldVal, newVal);
mappedTo.insert(newVal);
}

/// Drop the last mapping for the given value.
void erase(Value value) { mapping.erase(value); }

/// Returns the inverse raw value mapping (without recursive query support).
DenseMap<Value, SmallVector<Value>> getInverse() const {
DenseMap<Value, SmallVector<Value>> inverse;
for (auto &it : mapping.getValueMap())
inverse[it.second].push_back(it.first);
return inverse;
}

private:
/// Current value mappings.
IRMapping mapping;

/// All SSA values that are mapped to. May contain false positives.
DenseSet<Value> mappedTo;
};
} // namespace

Expand Down Expand Up @@ -434,29 +434,23 @@ class MoveBlockRewrite : public BlockRewrite {
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block, Block *origBlock,
const TypeConverter *converter)
Block *block, Block *origBlock)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
origBlock(origBlock), converter(converter) {}
origBlock(origBlock) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
}

Block *getOrigBlock() const { return origBlock; }

const TypeConverter *getConverter() const { return converter; }

void commit(RewriterBase &rewriter) override;

void rollback() override;

private:
/// The original block that was requested to have its signature converted.
Block *origBlock;

/// The type converter used to convert the arguments.
const TypeConverter *converter;
};

/// Replacing a block argument. This rewrite is not immediately reflected in the
Expand All @@ -465,8 +459,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
class ReplaceBlockArgRewrite : public BlockRewrite {
public:
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block, BlockArgument arg)
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
Block *block, BlockArgument arg,
const TypeConverter *converter)
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
converter(converter) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::ReplaceBlockArg;
Expand All @@ -478,6 +474,9 @@ class ReplaceBlockArgRewrite : public BlockRewrite {

private:
BlockArgument arg;

/// The current type converter when the block argument was replaced.
const TypeConverter *converter;
};

/// An operation rewrite.
Expand Down Expand Up @@ -627,8 +626,6 @@ class ReplaceOperationRewrite : public OperationRewrite {

void cleanup(RewriterBase &rewriter) override;

const TypeConverter *getConverter() const { return converter; }

private:
/// An optional type converter that can be used to materialize conversions
/// between the new and old values if necessary.
Expand Down Expand Up @@ -825,6 +822,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange replacements, Value originalValue,
const TypeConverter *converter);

/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
/// value. If there is no replacement value with the correct type, find the
/// latest replacement value (regardless of the type) and build a source
/// materialization.
Value findOrBuildReplacementValue(Value value,
const TypeConverter *converter);

//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -970,7 +975,7 @@ void BlockTypeConversionRewrite::rollback() {
}

void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
if (!repl)
return;

Expand Down Expand Up @@ -999,7 +1004,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
// Compute replacement values.
SmallVector<Value> replacements =
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
return rewriterImpl.mapping.lookupOrNull(result, result.getType());
return rewriterImpl.findOrBuildReplacementValue(result, converter);
});

// Notify the listener that the operation is about to be replaced.
Expand Down Expand Up @@ -1069,8 +1074,10 @@ void UnresolvedMaterializationRewrite::rollback() {
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
IRRewriter rewriter(context, config.listener);
for (auto &rewrite : rewrites)
rewrite->commit(rewriter);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
rewrites[i]->commit(rewriter);

// Clean up all rewrites.
for (auto &rewrite : rewrites)
Expand Down Expand Up @@ -1275,7 +1282,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}

Expand All @@ -1285,7 +1292,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}

Expand All @@ -1298,10 +1305,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
insertNTo1Materialization(
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}

appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
appendRewrite<BlockTypeConversionRewrite>(newBlock, block);

// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
Expand Down Expand Up @@ -1371,6 +1378,42 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
}
}

Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
// Find a replacement value with the same type.
Value repl = mapping.lookupOrNull(value, value.getType());
if (repl)
return repl;

// Check if the value is dead. No replacement value is needed in that case.
// This is an approximate check that may have false negatives but does not
// require computing and traversing an inverse mapping. (We may end up
// building source materializations that are never used and that fold away.)
if (llvm::all_of(value.getUsers(),
[&](Operation *op) { return replacedOps.contains(op); }) &&
!mapping.isMappedTo(value))
return Value();

// No replacement value was found. Get the latest replacement value
// (regardless of the type) and build a source materialization to the
// original type.
repl = mapping.lookupOrNull(value);
if (!repl) {
// No replacement value is registered in the mapping. This means that the
// value is dropped and no longer needed. (If the value were still needed,
// a source materialization producing a replacement value "out of thin air"
// would have already been created during `replaceOp` or
// `applySignatureConversion`.)
return Value();
}
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
/*inputs=*/repl, /*outputType=*/value.getType(),
/*originalType=*/Type(), converter);
mapping.map(value, castValue);
return castValue;
}

//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks

Expand Down Expand Up @@ -1597,7 +1640,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}

Expand Down Expand Up @@ -2417,10 +2461,6 @@ struct OperationConverter {
/// Converts an operation with the given rewriter.
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

/// This method is called after the conversion process to legalize any
/// remaining artifacts and complete the conversion.
void finalize(ConversionPatternRewriter &rewriter);

/// Dialect conversion configuration.
ConversionConfig config;

Expand Down Expand Up @@ -2541,11 +2581,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (failed(convert(rewriter, op)))
return rewriterImpl.undoRewrites(), failure();

// Now that all of the operations have been converted, finalize the conversion
// process to ensure any lingering conversion artifacts are cleaned up and
// legalized.
finalize(rewriter);

// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();

Expand Down Expand Up @@ -2579,80 +2614,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}

/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
SmallVector<Value> worklist = {initialValue};
while (!worklist.empty()) {
Value value = worklist.pop_back_val();

// Walk the users of this value to see if there are any live users that
// weren't replaced during conversion.
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
if (liveUserIt != value.user_end())
return *liveUserIt;
auto mapIt = inverseMapping.find(value);
if (mapIt != inverseMapping.end())
worklist.append(mapIt->second);
}
return nullptr;
}

/// Helper function that returns the replaced values and the type converter if
/// the given rewrite object is an "operation replacement" or a "block type
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
/// an empty ValueRange and a null type converter pointer.
static std::pair<ValueRange, const TypeConverter *>
getReplacedValues(IRRewrite *rewrite) {
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
return {blockRewrite->getOrigBlock()->getArguments(),
blockRewrite->getConverter()};
return {};
}

void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();

// Process requested value replacements.
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
ValueRange replacedValues;
const TypeConverter *converter;
std::tie(replacedValues, converter) =
getReplacedValues(rewriterImpl.rewrites[i].get());
for (Value originalValue : replacedValues) {
// If the type of this value changed and the value is still live, we need
// to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(originalValue,
originalValue.getType()))
continue;
Operation *liveUser =
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
if (!liveUser)
continue;

// Legalize this value replacement.
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(newValue),
originalValue.getLoc(),
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
/*originalType=*/Type(), converter);
rewriterImpl.mapping.map(originalValue, castValue);
inverseMapping[castValue].push_back(originalValue);
llvm::erase(inverseMapping[newValue], originalValue);
}
}
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
Expand Down
Loading