-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms][NFC] Turn block type conversion into IRRewrite
#81756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Turn block type conversion into IRRewrite
#81756
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be commited (upon success) or rolled back (upon failure). Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via Overview of changes:
Patch is 40.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81756.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 67b076b295eae8..b2baa88879b6e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,12 +154,13 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewrites, unsigned numIgnoredOperations)
+ unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations) {}
+ numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -178,6 +179,9 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
+
+ /// The current number of erased operations/blocks.
+ unsigned numErased;
};
//===----------------------------------------------------------------------===//
@@ -292,374 +296,6 @@ static Value buildUnresolvedTargetMaterialization(
outputType, outputType, converter, unresolvedMaterializations);
}
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
- ArgConverter(
- PatternRewriter &rewriter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
- : rewriter(rewriter),
- unresolvedMaterializations(unresolvedMaterializations) {}
-
- /// This structure contains the information pertaining to an argument that has
- /// been converted.
- struct ConvertedArgInfo {
- ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value castValue = nullptr)
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
- /// The start index of in the new argument list that contains arguments that
- /// replace the original.
- unsigned newArgIdx;
-
- /// The number of arguments that replaced the original argument.
- unsigned newArgSize;
-
- /// The cast value that was created to cast from the new arguments to the
- /// old. This only used if 'newArgSize' > 1.
- Value castValue;
- };
-
- /// This structure contains information pertaining to a block that has had its
- /// signature converted.
- struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
- : origBlock(origBlock), converter(converter) {}
-
- /// The original block that was requested to have its signature converted.
- Block *origBlock;
-
- /// The conversion information for each of the arguments. The information is
- /// std::nullopt if the argument was dropped during conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
- /// The type converter used to convert the arguments.
- const TypeConverter *converter;
- };
-
- //===--------------------------------------------------------------------===//
- // Rewrite Application
- //===--------------------------------------------------------------------===//
-
- /// Erase any rewrites registered for the blocks within the given operation
- /// which is about to be removed. This merely drops the rewrites without
- /// undoing them.
- void notifyOpRemoved(Operation *op);
-
- /// Cleanup and undo any generated conversions for the arguments of block.
- /// This method replaces the new block with the original, reverting the IR to
- /// its original state.
- void discardRewrites(Block *block);
-
- /// Fully replace uses of the old arguments with the new.
- void applyRewrites(ConversionValueMapping &mapping);
-
- /// Materialize any necessary conversions for converted arguments that have
- /// live users, using the provided `findLiveUser` to search for a user that
- /// survives the conversion process.
- LogicalResult
- materializeLiveConversions(ConversionValueMapping &mapping,
- OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser);
-
- //===--------------------------------------------------------------------===//
- // Conversion
- //===--------------------------------------------------------------------===//
-
- /// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. Returns `block` if it did
- /// not require conversion.
- FailureOr<Block *>
- convertSignature(Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// Apply the given signature conversion on the given block. The new block
- /// containing the updated signature is returned. If no conversions were
- /// necessary, e.g. if the block has no arguments, `block` is returned.
- /// `converter` is used to generate any necessary cast operations that
- /// translate between the origin argument types and those specified in the
- /// signature conversion.
- Block *applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// A collection of blocks that have had their arguments converted. This is a
- /// map from the new replacement block, back to the original block.
- llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
- /// The pattern rewriter to use when materializing conversions.
- PatternRewriter &rewriter;
-
- /// An ordered set of unresolved materializations during conversion.
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
- if (conversionInfo.empty())
- return;
-
- for (Region ®ion : op->getRegions()) {
- for (Block &block : region) {
- // Drop any rewrites from within.
- for (Operation &nestedOp : block)
- if (nestedOp.getNumRegions())
- notifyOpRemoved(&nestedOp);
-
- // Check if this block was converted.
- auto it = conversionInfo.find(&block);
- if (it == conversionInfo.end())
- continue;
-
- // Drop all uses of the original arguments and delete the original block.
- Block *origBlock = it->second.origBlock;
- for (BlockArgument arg : origBlock->getArguments())
- arg.dropAllUses();
- conversionInfo.erase(it);
- }
- }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
- auto it = conversionInfo.find(block);
- if (it == conversionInfo.end())
- return;
- Block *origBlock = it->second.origBlock;
-
- // Drop all uses of the new block arguments and replace uses of the new block.
- for (int i = block->getNumArguments() - 1; i >= 0; --i)
- block->getArgument(i).dropAllUses();
- block->replaceAllUsesWith(origBlock);
-
- // Move the operations back the original block, move the original block back
- // into its original location and the delete the new block.
- origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
- block->erase();
-
- conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
- for (auto &info : conversionInfo) {
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
- BlockArgument origArg = origBlock->getArgument(i);
-
- // Handle the case of a 1->0 value mapping.
- if (!argInfo) {
- if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
- origArg.replaceAllUsesWith(newArg);
- continue;
- }
-
- // Otherwise this is a 1->1+ value mapping.
- Value castValue = argInfo->castValue;
- assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
- // If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty()) {
- origArg.replaceAllUsesWith(
- mapping.lookupOrDefault(castValue, origArg.getType()));
- }
- }
-
- delete origBlock;
- blockInfo.origBlock = nullptr;
- }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
- ConversionValueMapping &mapping, OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser) {
- for (auto &info : conversionInfo) {
- Block *newBlock = info.first;
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- // If the type of this argument changed and the argument is still live, we
- // need to materialize a conversion.
- BlockArgument origArg = origBlock->getArgument(i);
- if (mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = mapping.lookupOrDefault(origArg);
- bool isDroppedArg = replacementValue == origArg;
- if (isDroppedArg)
- rewriter.setInsertionPointToStart(newBlock);
- else
- rewriter.setInsertionPointAfterValue(replacementValue);
- Value newArg;
- if (blockInfo.converter) {
- newArg = blockInfo.converter->materializeSourceConversion(
- rewriter, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(replacementValue));
- assert((!newArg || newArg.getType() == origArg.getType()) &&
- "materialization hook did not provide a value of the expected "
- "type");
- }
- if (!newArg) {
- InFlightDiagnostic diag =
- emitError(origArg.getLoc())
- << "failed to materialize conversion for block argument #" << i
- << " that remained live after conversion, type was "
- << origArg.getType();
- if (!isDroppedArg)
- diag << ", with target type " << replacementValue.getType();
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- return failure();
- }
- mapping.map(origArg, newArg);
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
- Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // Check if the block was already converted.
- // * If the block is mapped in `conversionInfo`, it is a converted block.
- // * If the block is detached, conservatively assume that it is going to be
- // deleted; it is likely the old block (before it was converted).
- if (conversionInfo.count(block) || !block->getParent())
- return block;
- // If a converter wasn't provided, and the block wasn't already converted,
- // there is nothing we can do.
- if (!converter)
- return failure();
-
- // Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion, mapping,
- argReplacements);
- return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // If no arguments are being changed or added, there is nothing to do.
- unsigned origArgCount = block->getNumArguments();
- auto convertedTypes = signatureConversion.getConvertedTypes();
- if (origArgCount == 0 && convertedTypes.empty())
- return block;
-
- // Split the block at the beginning to get a new block to use for the updated
- // signature.
- Block *newBlock = block->splitBlock(block->begin());
- block->replaceAllUsesWith(newBlock);
- // Unlink the block, but do not erase it yet, so that the change can be rolled
- // back.
- block->getParent()->getBlocks().remove(block);
-
- // Map all new arguments to the location of the argument they originate from.
- SmallVector<Location> newLocs(convertedTypes.size(),
- rewriter.getUnknownLoc());
- for (unsigned i = 0; i < origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap || inputMap->replacementValue)
- continue;
- Location origLoc = block->getArgument(i).getLoc();
- for (unsigned j = 0; j < inputMap->size; ++j)
- newLocs[inputMap->inputNo + j] = origLoc;
- }
-
- SmallVector<Value, 4> newArgRange(
- newBlock->addArguments(convertedTypes, newLocs));
- ArrayRef<Value> newArgs(newArgRange);
-
- // Remap each of the original arguments as determined by the signature
- // conversion.
- ConvertedBlockInfo info(block, converter);
- info.argInfo.resize(origArgCount);
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(newBlock);
- for (unsigned i = 0; i != origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap)
- continue;
- BlockArgument origArg = block->getArgument(i);
-
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
- continue;
- }
-
- // Otherwise, this is a 1->1+ mapping.
- auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Value newArg;
-
- // If this is a 1->1 mapping and the types of new and replacement arguments
- // match (i.e. it's an identity map), then the argument is mapped to its
- // original type.
- // FIXME: We simply pass through the replacement argument if there wasn't a
- // converter, which isn't great as it allows implicit type conversions to
- // appear. We should properly restructure this code to handle cases where a
- // converter isn't provided and also to properly handle the case where an
- // argument materialization is actually a temporary source materialization
- // (e.g. in the case of 1->N).
- if (replArgs.size() == 1 &&
- (!converter || replArgs[0].getType() == origArg.getType())) {
- newArg = replArgs.front();
- } else {
- Type origOutputType = origArg.getType();
-
- // Legalize the argument output type.
- Type outputType = origOutputType;
- if (Type legalOutputType = converter->convertType(outputType))
- outputType = legalOutputType;
-
- newArg = buildUnresolvedArgumentMaterialization(
- rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
- }
-
- mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
- info.argInfo[i] =
- ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
- }
-
- conversionInfo.insert({newBlock, std::move(info)});
- return newBlock;
-}
-
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
@@ -694,6 +330,12 @@ class IRRewrite {
/// Commit the rewrite.
virtual void commit() {}
+ /// Erase the given op (unless it was already erased).
+ void eraseOp(Operation *op);
+
+ /// Erase the given block (unless it was already erased).
+ void eraseBlock(Block *block);
+
Kind getKind() const { return kind; }
static bool classof(const IRRewrite *rewrite) { return true; }
@@ -744,8 +386,7 @@ class CreateBlockRewrite : public BlockRewrite {
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
};
@@ -881,8 +522,7 @@ class SplitBlockRewrite : public BlockRewrite {
// Merge back the block that was split out.
originalBlock->getOperations().splice(originalBlock->end(),
block->getOperations());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
private:
@@ -890,20 +530,59 @@ class SplitBlockRewrite : public BlockRewrite {
Block *originalBlock;
};
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+ /// The start index of in the new argument list that contains arguments that
+ /// replace the original.
+ unsigned newArgIdx;
+
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+ /// The cast value that was created to cast from the new arguments to the
+ /// old. This only used if 'newArgSize' > 1.
+ Value castValue;
+};
+
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block)
- : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+ BlockTypeConversionRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+ const TypeConverter *converter)
+ : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+ origBlock(origBlock), argInfo(argInfo), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- // TODO: Block type conversions are currently committed in
- // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+ /// Materialize any necessary conversions for converted arguments that have
+ /// live users, using the provided `findLiveUser` to search for a user that
+ /// survives the conversion process.
+ LogicalResult
+ materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+ void commit() override;
+
void rollback() override;
+
+private:
+ /// The original block that was requested to have its signature converted.
+ Block *origBlock;
+
+ /// The conversion information for each of the arguments. The information is
+ /// std::nullopt if the argument was dropped during conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+ /// The type converter used to convert the arguments.
+ const TypeConverter...
[truncated]
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be commited (upon success) or rolled back (upon failure). Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via Overview of changes:
Patch is 40.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81756.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 67b076b295eae8..b2baa88879b6e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,12 +154,13 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewrites, unsigned numIgnoredOperations)
+ unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations) {}
+ numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -178,6 +179,9 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
+
+ /// The current number of erased operations/blocks.
+ unsigned numErased;
};
//===----------------------------------------------------------------------===//
@@ -292,374 +296,6 @@ static Value buildUnresolvedTargetMaterialization(
outputType, outputType, converter, unresolvedMaterializations);
}
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
- ArgConverter(
- PatternRewriter &rewriter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
- : rewriter(rewriter),
- unresolvedMaterializations(unresolvedMaterializations) {}
-
- /// This structure contains the information pertaining to an argument that has
- /// been converted.
- struct ConvertedArgInfo {
- ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value castValue = nullptr)
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
- /// The start index of in the new argument list that contains arguments that
- /// replace the original.
- unsigned newArgIdx;
-
- /// The number of arguments that replaced the original argument.
- unsigned newArgSize;
-
- /// The cast value that was created to cast from the new arguments to the
- /// old. This only used if 'newArgSize' > 1.
- Value castValue;
- };
-
- /// This structure contains information pertaining to a block that has had its
- /// signature converted.
- struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
- : origBlock(origBlock), converter(converter) {}
-
- /// The original block that was requested to have its signature converted.
- Block *origBlock;
-
- /// The conversion information for each of the arguments. The information is
- /// std::nullopt if the argument was dropped during conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
- /// The type converter used to convert the arguments.
- const TypeConverter *converter;
- };
-
- //===--------------------------------------------------------------------===//
- // Rewrite Application
- //===--------------------------------------------------------------------===//
-
- /// Erase any rewrites registered for the blocks within the given operation
- /// which is about to be removed. This merely drops the rewrites without
- /// undoing them.
- void notifyOpRemoved(Operation *op);
-
- /// Cleanup and undo any generated conversions for the arguments of block.
- /// This method replaces the new block with the original, reverting the IR to
- /// its original state.
- void discardRewrites(Block *block);
-
- /// Fully replace uses of the old arguments with the new.
- void applyRewrites(ConversionValueMapping &mapping);
-
- /// Materialize any necessary conversions for converted arguments that have
- /// live users, using the provided `findLiveUser` to search for a user that
- /// survives the conversion process.
- LogicalResult
- materializeLiveConversions(ConversionValueMapping &mapping,
- OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser);
-
- //===--------------------------------------------------------------------===//
- // Conversion
- //===--------------------------------------------------------------------===//
-
- /// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. Returns `block` if it did
- /// not require conversion.
- FailureOr<Block *>
- convertSignature(Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// Apply the given signature conversion on the given block. The new block
- /// containing the updated signature is returned. If no conversions were
- /// necessary, e.g. if the block has no arguments, `block` is returned.
- /// `converter` is used to generate any necessary cast operations that
- /// translate between the origin argument types and those specified in the
- /// signature conversion.
- Block *applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// A collection of blocks that have had their arguments converted. This is a
- /// map from the new replacement block, back to the original block.
- llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
- /// The pattern rewriter to use when materializing conversions.
- PatternRewriter &rewriter;
-
- /// An ordered set of unresolved materializations during conversion.
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
- if (conversionInfo.empty())
- return;
-
- for (Region ®ion : op->getRegions()) {
- for (Block &block : region) {
- // Drop any rewrites from within.
- for (Operation &nestedOp : block)
- if (nestedOp.getNumRegions())
- notifyOpRemoved(&nestedOp);
-
- // Check if this block was converted.
- auto it = conversionInfo.find(&block);
- if (it == conversionInfo.end())
- continue;
-
- // Drop all uses of the original arguments and delete the original block.
- Block *origBlock = it->second.origBlock;
- for (BlockArgument arg : origBlock->getArguments())
- arg.dropAllUses();
- conversionInfo.erase(it);
- }
- }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
- auto it = conversionInfo.find(block);
- if (it == conversionInfo.end())
- return;
- Block *origBlock = it->second.origBlock;
-
- // Drop all uses of the new block arguments and replace uses of the new block.
- for (int i = block->getNumArguments() - 1; i >= 0; --i)
- block->getArgument(i).dropAllUses();
- block->replaceAllUsesWith(origBlock);
-
- // Move the operations back the original block, move the original block back
- // into its original location and the delete the new block.
- origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
- block->erase();
-
- conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
- for (auto &info : conversionInfo) {
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
- BlockArgument origArg = origBlock->getArgument(i);
-
- // Handle the case of a 1->0 value mapping.
- if (!argInfo) {
- if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
- origArg.replaceAllUsesWith(newArg);
- continue;
- }
-
- // Otherwise this is a 1->1+ value mapping.
- Value castValue = argInfo->castValue;
- assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
- // If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty()) {
- origArg.replaceAllUsesWith(
- mapping.lookupOrDefault(castValue, origArg.getType()));
- }
- }
-
- delete origBlock;
- blockInfo.origBlock = nullptr;
- }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
- ConversionValueMapping &mapping, OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser) {
- for (auto &info : conversionInfo) {
- Block *newBlock = info.first;
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- // If the type of this argument changed and the argument is still live, we
- // need to materialize a conversion.
- BlockArgument origArg = origBlock->getArgument(i);
- if (mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = mapping.lookupOrDefault(origArg);
- bool isDroppedArg = replacementValue == origArg;
- if (isDroppedArg)
- rewriter.setInsertionPointToStart(newBlock);
- else
- rewriter.setInsertionPointAfterValue(replacementValue);
- Value newArg;
- if (blockInfo.converter) {
- newArg = blockInfo.converter->materializeSourceConversion(
- rewriter, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(replacementValue));
- assert((!newArg || newArg.getType() == origArg.getType()) &&
- "materialization hook did not provide a value of the expected "
- "type");
- }
- if (!newArg) {
- InFlightDiagnostic diag =
- emitError(origArg.getLoc())
- << "failed to materialize conversion for block argument #" << i
- << " that remained live after conversion, type was "
- << origArg.getType();
- if (!isDroppedArg)
- diag << ", with target type " << replacementValue.getType();
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- return failure();
- }
- mapping.map(origArg, newArg);
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
- Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // Check if the block was already converted.
- // * If the block is mapped in `conversionInfo`, it is a converted block.
- // * If the block is detached, conservatively assume that it is going to be
- // deleted; it is likely the old block (before it was converted).
- if (conversionInfo.count(block) || !block->getParent())
- return block;
- // If a converter wasn't provided, and the block wasn't already converted,
- // there is nothing we can do.
- if (!converter)
- return failure();
-
- // Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion, mapping,
- argReplacements);
- return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // If no arguments are being changed or added, there is nothing to do.
- unsigned origArgCount = block->getNumArguments();
- auto convertedTypes = signatureConversion.getConvertedTypes();
- if (origArgCount == 0 && convertedTypes.empty())
- return block;
-
- // Split the block at the beginning to get a new block to use for the updated
- // signature.
- Block *newBlock = block->splitBlock(block->begin());
- block->replaceAllUsesWith(newBlock);
- // Unlink the block, but do not erase it yet, so that the change can be rolled
- // back.
- block->getParent()->getBlocks().remove(block);
-
- // Map all new arguments to the location of the argument they originate from.
- SmallVector<Location> newLocs(convertedTypes.size(),
- rewriter.getUnknownLoc());
- for (unsigned i = 0; i < origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap || inputMap->replacementValue)
- continue;
- Location origLoc = block->getArgument(i).getLoc();
- for (unsigned j = 0; j < inputMap->size; ++j)
- newLocs[inputMap->inputNo + j] = origLoc;
- }
-
- SmallVector<Value, 4> newArgRange(
- newBlock->addArguments(convertedTypes, newLocs));
- ArrayRef<Value> newArgs(newArgRange);
-
- // Remap each of the original arguments as determined by the signature
- // conversion.
- ConvertedBlockInfo info(block, converter);
- info.argInfo.resize(origArgCount);
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(newBlock);
- for (unsigned i = 0; i != origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap)
- continue;
- BlockArgument origArg = block->getArgument(i);
-
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
- continue;
- }
-
- // Otherwise, this is a 1->1+ mapping.
- auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Value newArg;
-
- // If this is a 1->1 mapping and the types of new and replacement arguments
- // match (i.e. it's an identity map), then the argument is mapped to its
- // original type.
- // FIXME: We simply pass through the replacement argument if there wasn't a
- // converter, which isn't great as it allows implicit type conversions to
- // appear. We should properly restructure this code to handle cases where a
- // converter isn't provided and also to properly handle the case where an
- // argument materialization is actually a temporary source materialization
- // (e.g. in the case of 1->N).
- if (replArgs.size() == 1 &&
- (!converter || replArgs[0].getType() == origArg.getType())) {
- newArg = replArgs.front();
- } else {
- Type origOutputType = origArg.getType();
-
- // Legalize the argument output type.
- Type outputType = origOutputType;
- if (Type legalOutputType = converter->convertType(outputType))
- outputType = legalOutputType;
-
- newArg = buildUnresolvedArgumentMaterialization(
- rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
- }
-
- mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
- info.argInfo[i] =
- ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
- }
-
- conversionInfo.insert({newBlock, std::move(info)});
- return newBlock;
-}
-
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
@@ -694,6 +330,12 @@ class IRRewrite {
/// Commit the rewrite.
virtual void commit() {}
+ /// Erase the given op (unless it was already erased).
+ void eraseOp(Operation *op);
+
+ /// Erase the given block (unless it was already erased).
+ void eraseBlock(Block *block);
+
Kind getKind() const { return kind; }
static bool classof(const IRRewrite *rewrite) { return true; }
@@ -744,8 +386,7 @@ class CreateBlockRewrite : public BlockRewrite {
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
};
@@ -881,8 +522,7 @@ class SplitBlockRewrite : public BlockRewrite {
// Merge back the block that was split out.
originalBlock->getOperations().splice(originalBlock->end(),
block->getOperations());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
private:
@@ -890,20 +530,59 @@ class SplitBlockRewrite : public BlockRewrite {
Block *originalBlock;
};
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+ /// The start index of in the new argument list that contains arguments that
+ /// replace the original.
+ unsigned newArgIdx;
+
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+ /// The cast value that was created to cast from the new arguments to the
+ /// old. This only used if 'newArgSize' > 1.
+ Value castValue;
+};
+
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block)
- : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+ BlockTypeConversionRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+ const TypeConverter *converter)
+ : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+ origBlock(origBlock), argInfo(argInfo), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- // TODO: Block type conversions are currently committed in
- // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+ /// Materialize any necessary conversions for converted arguments that have
+ /// live users, using the provided `findLiveUser` to search for a user that
+ /// survives the conversion process.
+ LogicalResult
+ materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+ void commit() override;
+
void rollback() override;
+
+private:
+ /// The original block that was requested to have its signature converted.
+ Block *origBlock;
+
+ /// The conversion information for each of the arguments. The information is
+ /// std::nullopt if the argument was dropped during conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+ /// The type converter used to convert the arguments.
+ const TypeConverter...
[truncated]
|
c7afdb2
to
a79501e
Compare
ae08e91
to
dcd13b8
Compare
a79501e
to
a7fffc3
Compare
dcd13b8
to
61e82f6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, just small nits. Thanks
5dc79f6
to
445908f
Compare
61e82f6
to
68cb259
Compare
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be commited (upon success) or rolled back (upon failure). Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via `BlockTypeConversionRewrite::rollback`, but there was no `BlockTypeConversionRewrite::commit` equivalent. Overview of changes: * Remove `ArgConverter`, an internal helper class that kept track of all block type conversions. There is now a separate `BlockTypeConversionRewrite` for each block type conversion. * No more special handling for block type conversions. They are now normal "IR rewrites", just like "block creation" or "block movement". In particular, trigger "commits" of block type conversion via `BlockTypeConversionRewrite::commit`. * Remove `ArgConverter::notifyOpRemoved`. This function was used to inform the `ArgConverter` that an operation was erased, to prevent a double-free of operations in certain situations. It would be unpractical to add a `notifyOpRemoved` API to `IRRewrite`. Instead, erasing ops/block should go through a new `SingleEraseRewriter` (that is owned by the `ConversionPatternRewriterImpl`) if there is chance of double-free. This rewriter ignores `eraseOp`/`eraseBlock` if the op/block was already freed. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
68cb259
to
04eb7bd
Compare
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
Until now, the signature conversion of a block was only a "partial" IR rewrite. Rollbacks were triggered via
BlockTypeConversionRewrite::rollback
, but there was noBlockTypeConversionRewrite::commit
equivalent.Overview of changes:
ArgConverter
, an internal helper class that kept track of all block type conversions. There is now a separateBlockTypeConversionRewrite
for each block type conversion.BlockTypeConversionRewrite::commit
.ArgConverter::notifyOpRemoved
. This function was used to inform theArgConverter
that an operation was erased, to prevent a double-free of operations in certain situations. It would be unpractical to add anotifyOpRemoved
API toIRRewrite
. Instead, erasing ops/block should go through a newSingleEraseRewriter
(that is owned by theConversionPatternRewriterImpl
) if there is chance of double-free. This rewriter ignoreseraseOp
/eraseBlock
if the op/block was already freed.