-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms][NFC] Turn unresolved materializations into IRRewrite
s
#81761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Turn unresolved materializations into IRRewrite
s
#81761
Conversation
3873a3e
to
6701034
Compare
ab0cd8c
to
51a7e12
Compare
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure). This commit turns the creation of unresolved materializations ( Patch is 25.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81761.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5b7ad4e7b8e281..4ef26a739e4ea1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,15 +152,11 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
- unsigned numIgnoredOperations, unsigned numErased)
- : numUnresolvedMaterializations(numUnresolvedMaterializations),
- numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+ RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
+ : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
numErased(numErased) {}
- /// The current number of unresolved materializations.
- unsigned numUnresolvedMaterializations;
-
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -171,109 +167,10 @@ struct RewriterState {
unsigned numErased;
};
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
- /// The type of materialization.
- enum Kind {
- /// This materialization materializes a conversion for an illegal block
- /// argument type, to a legal one.
- Argument,
-
- /// This materialization materializes a conversion from an illegal type to a
- /// legal one.
- Target
- };
-
- UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
- const TypeConverter *converter = nullptr,
- Kind kind = Target, Type origOutputType = nullptr)
- : op(op), converterAndKind(converter, kind),
- origOutputType(origOutputType) {}
-
- /// Return the temporary conversion operation inserted for this
- /// materialization.
- UnrealizedConversionCastOp getOp() const { return op; }
-
- /// Return the type converter of this materialization (which may be null).
- const TypeConverter *getConverter() const {
- return converterAndKind.getPointer();
- }
-
- /// Return the kind of this materialization.
- Kind getKind() const { return converterAndKind.getInt(); }
-
- /// Set the kind of this materialization.
- void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
- /// Return the original illegal output type of the input values.
- Type getOrigOutputType() const { return origOutputType; }
-
-private:
- /// The unresolved materialization operation created during conversion.
- UnrealizedConversionCastOp op;
-
- /// The corresponding type converter to use when resolving this
- /// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
- /// The original output type. This is only used for argument conversions.
- Type origOutputType;
-};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
- UnresolvedMaterialization::Kind kind, Block *insertBlock,
- Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
- Type origOutputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- // Avoid materializing an unnecessary cast.
- if (inputs.size() == 1 && inputs.front().getType() == outputType)
- return inputs.front();
-
- // Create an unresolved materialization. We use a new OpBuilder to avoid
- // tracking the materialization like we do for other operations.
- OpBuilder builder(insertBlock, insertPt);
- auto convertOp =
- builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.emplace_back(convertOp, converter, kind,
- origOutputType);
- return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
- PatternRewriter &rewriter, Location loc, ValueRange inputs,
- Type origOutputType, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
- rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
- converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
- Location loc, Value input, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- Block *insertBlock = input.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(input))
- insertPt = ++inputRes.getOwner()->getIterator();
-
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
- outputType, outputType, converter, unresolvedMaterializations);
-}
-
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
-namespace {
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
@@ -295,7 +192,8 @@ class IRRewrite {
MoveOperation,
ModifyOperation,
ReplaceOperation,
- CreateOperation
+ CreateOperation,
+ UnresolvedMaterialization
};
virtual ~IRRewrite() = default;
@@ -602,7 +500,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::CreateOperation;
+ rewrite->getKind() <= Kind::UnresolvedMaterialization;
}
protected:
@@ -721,6 +619,70 @@ class CreateOperationRewrite : public OperationRewrite {
void rollback() override;
};
+
+/// The type of materialization.
+enum MaterializationKind {
+ /// This materialization materializes a conversion for an illegal block
+ /// argument type, to a legal one.
+ Argument,
+
+ /// This materialization materializes a conversion from an illegal type to a
+ /// legal one.
+ Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+ UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+ MaterializationKind kind = MaterializationKind::Target,
+ Type origOutputType = nullptr)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::UnresolvedMaterialization;
+ }
+
+ UnrealizedConversionCastOp getOperation() const {
+ return cast<UnrealizedConversionCastOp>(op);
+ }
+
+ void rollback() override;
+
+ void cleanup() override;
+
+ /// Return the type converter of this materialization (which may be null).
+ const TypeConverter *getConverter() const {
+ return converterAndKind.getPointer();
+ }
+
+ /// Return the kind of this materialization.
+ MaterializationKind getMaterializationKind() const {
+ return converterAndKind.getInt();
+ }
+
+ /// Set the kind of this materialization.
+ void setMaterializationKind(MaterializationKind kind) {
+ converterAndKind.setInt(kind);
+ }
+
+ /// Return the original illegal output type of the input values.
+ Type getOrigOutputType() const { return origOutputType; }
+
+private:
+ /// The corresponding type converter to use when resolving this
+ /// materialization, and the kind of this materialization.
+ llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+ converterAndKind;
+
+ /// The original output type. This is only used for argument conversions.
+ Type origOutputType;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -763,14 +725,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
notifyCallback(nullptr) {}
- /// Cleanup and destroy any generated rewrite operations. This method is
- /// invoked when the conversion process fails.
- void discardRewrites();
-
- /// Apply all requested operation rewrites. This method is invoked when the
- /// conversion process succeeds.
- void applyRewrites();
-
//===--------------------------------------------------------------------===//
// State Management
//===--------------------------------------------------------------------===//
@@ -778,6 +732,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Return the current state of the rewriter.
RewriterState getCurrentState();
+ /// Apply all requested operation rewrites. This method is invoked when the
+ /// conversion process succeeds.
+ void applyRewrites();
+
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
@@ -810,17 +768,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// removes them from being considered for legalization.
void markNestedOpsIgnored(Operation *op);
- /// Detach any operations nested in the given operation from their parent
- /// blocks, and erase the given operation. This can be used when the nested
- /// operations are scheduled for erasure themselves, so deleting the regions
- /// of the given operation together with their content would result in
- /// double-free. This happens, for example, when rolling back op creation in
- /// the reverse order and if the nested ops were created before the parent op.
- /// This function does not need to collect nested ops recursively because it
- /// is expected to also be called for each nested op when it is about to be
- /// deleted.
- void detachNestedAndErase(Operation *op);
-
//===--------------------------------------------------------------------===//
// Type Conversion
//===--------------------------------------------------------------------===//
@@ -859,6 +806,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
+ //===--------------------------------------------------------------------===//
+ // Materializations
+ //===--------------------------------------------------------------------===//
+ /// Build an unresolved materialization operation given an output type and set
+ /// of input operands.
+ Value buildUnresolvedMaterialization(MaterializationKind kind,
+ Block *insertBlock,
+ Block::iterator insertPt, Location loc,
+ ValueRange inputs, Type outputType,
+ Type origOutputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
+ Location loc, ValueRange inputs,
+ Type origOutputType,
+ Type outputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+ Type outputType,
+ const TypeConverter *converter);
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -938,10 +907,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// replacing a value with one of a different type.
ConversionValueMapping mapping;
- /// Ordered vector of all unresolved type conversion materializations during
- /// conversion.
- SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
-
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
@@ -1129,26 +1094,15 @@ void CreateOperationRewrite::rollback() {
eraseOp(op);
}
-void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
- // if (erasedIR.erasedOps.contains(op)) return;
-
- for (Region ®ion : op->getRegions()) {
- for (Block &block : region.getBlocks()) {
- while (!block.getOperations().empty())
- block.getOperations().remove(block.getOperations().begin());
- block.dropAllDefinedValueUses();
- }
+void UnresolvedMaterializationRewrite::rollback() {
+ if (getMaterializationKind() == MaterializationKind::Target) {
+ for (Value input : op->getOperands())
+ rewriterImpl.mapping.erase(input);
}
- eraseRewriter.eraseOp(op);
+ eraseOp(op);
}
-void ConversionPatternRewriterImpl::discardRewrites() {
- undoRewrites();
-
- // Remove any newly created ops.
- for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
- detachNestedAndErase(materialization.getOp());
-}
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
@@ -1156,39 +1110,20 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrite->commit();
for (auto &rewrite : rewrites)
rewrite->cleanup();
-
- // Drop all of the unresolved materialization operations created during
- // conversion.
- for (auto &mat : unresolvedMaterializations)
- eraseRewriter.eraseOp(mat.getOp());
}
//===----------------------------------------------------------------------===//
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
- ignoredOps.size(), eraseRewriter.erased.size());
+ return RewriterState(rewrites.size(), ignoredOps.size(),
+ eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Undo any rewrites.
undoRewrites(state.numRewrites);
- // Pop all of the newly inserted materializations.
- while (unresolvedMaterializations.size() !=
- state.numUnresolvedMaterializations) {
- UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
- UnrealizedConversionCastOp op = mat.getOp();
-
- // If this was a target materialization, drop the mapping that was inserted.
- if (mat.getKind() == UnresolvedMaterialization::Target) {
- for (Value input : op->getOperands())
- mapping.erase(input);
- }
- detachNestedAndErase(op);
- }
-
// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
@@ -1249,8 +1184,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
Value castValue = buildUnresolvedTargetMaterialization(
- operandLoc, newOperand, desiredType, currentTypeConverter,
- unresolvedMaterializations);
+ operandLoc, newOperand, desiredType, currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
@@ -1432,7 +1366,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
newArg = buildUnresolvedArgumentMaterialization(
rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
+ converter);
}
mapping.map(origArg, newArg);
@@ -1445,6 +1379,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
return newBlock;
}
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+ Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+ const TypeConverter *converter) {
+ // Avoid materializing an unnecessary cast.
+ if (inputs.size() == 1 && inputs.front().getType() == outputType)
+ return inputs.front();
+
+ // Create an unresolved materialization. We use a new OpBuilder to avoid
+ // tracking the materialization like we do for other operations.
+ OpBuilder builder(insertBlock, insertPt);
+ auto convertOp =
+ builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ origOutputType);
+ return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+ PatternRewriter &rewriter, Location loc, ValueRange inputs,
+ Type origOutputType, Type outputType, const TypeConverter *converter) {
+ return buildUnresolvedMaterialization(
+ MaterializationKind::Argument, rewriter.getInsertionBlock(),
+ rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+ converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+ Location loc, Value input, Type outputType,
+ const TypeConverter *converter) {
+ Block *insertBlock = input.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(input))
+ insertPt = ++inputRes.getOwner()->getIterator();
+
+ return buildUnresolvedMaterialization(MaterializationKind::Target,
+ insertBlock, insertPt, loc, input,
+ outputType, outputType, converter);
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -2497,18 +2475,18 @@ LogicalResult OperationConverter::convertOperations(
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// Now that all of the operations have been converted, finalize the conversion
// process to ensure any lingering conversion artifacts are cleaned up and
// legalized.
if (failed(finalize(rewriter)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// After a successful conversion, apply rewrites if this is not an analysis
// conversion.
if (mode == OpConversionMode::Analysis) {
- rewriterImpl.discardRewrites();
+ rewriterImpl.undoRewrites();
} else {
rewriterImpl.applyRewrites();
}
@@ -2613,11 +2591,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
/// Compute all of the unresolved materializations that will persist beyond the
/// conversion process, and require inserting a proper user materialization for.
static void computeNecessaryMaterializations(
- DenseMap<Operation *, UnresolvedMaterialization *> &materializati...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure). This commit turns the creation of unresolved materializations ( Patch is 25.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81761.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5b7ad4e7b8e281..4ef26a739e4ea1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,15 +152,11 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
- unsigned numIgnoredOperations, unsigned numErased)
- : numUnresolvedMaterializations(numUnresolvedMaterializations),
- numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+ RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
+ : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
numErased(numErased) {}
- /// The current number of unresolved materializations.
- unsigned numUnresolvedMaterializations;
-
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -171,109 +167,10 @@ struct RewriterState {
unsigned numErased;
};
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
- /// The type of materialization.
- enum Kind {
- /// This materialization materializes a conversion for an illegal block
- /// argument type, to a legal one.
- Argument,
-
- /// This materialization materializes a conversion from an illegal type to a
- /// legal one.
- Target
- };
-
- UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
- const TypeConverter *converter = nullptr,
- Kind kind = Target, Type origOutputType = nullptr)
- : op(op), converterAndKind(converter, kind),
- origOutputType(origOutputType) {}
-
- /// Return the temporary conversion operation inserted for this
- /// materialization.
- UnrealizedConversionCastOp getOp() const { return op; }
-
- /// Return the type converter of this materialization (which may be null).
- const TypeConverter *getConverter() const {
- return converterAndKind.getPointer();
- }
-
- /// Return the kind of this materialization.
- Kind getKind() const { return converterAndKind.getInt(); }
-
- /// Set the kind of this materialization.
- void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
- /// Return the original illegal output type of the input values.
- Type getOrigOutputType() const { return origOutputType; }
-
-private:
- /// The unresolved materialization operation created during conversion.
- UnrealizedConversionCastOp op;
-
- /// The corresponding type converter to use when resolving this
- /// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
- /// The original output type. This is only used for argument conversions.
- Type origOutputType;
-};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
- UnresolvedMaterialization::Kind kind, Block *insertBlock,
- Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
- Type origOutputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- // Avoid materializing an unnecessary cast.
- if (inputs.size() == 1 && inputs.front().getType() == outputType)
- return inputs.front();
-
- // Create an unresolved materialization. We use a new OpBuilder to avoid
- // tracking the materialization like we do for other operations.
- OpBuilder builder(insertBlock, insertPt);
- auto convertOp =
- builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.emplace_back(convertOp, converter, kind,
- origOutputType);
- return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
- PatternRewriter &rewriter, Location loc, ValueRange inputs,
- Type origOutputType, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
- rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
- converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
- Location loc, Value input, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- Block *insertBlock = input.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(input))
- insertPt = ++inputRes.getOwner()->getIterator();
-
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
- outputType, outputType, converter, unresolvedMaterializations);
-}
-
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
-namespace {
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
@@ -295,7 +192,8 @@ class IRRewrite {
MoveOperation,
ModifyOperation,
ReplaceOperation,
- CreateOperation
+ CreateOperation,
+ UnresolvedMaterialization
};
virtual ~IRRewrite() = default;
@@ -602,7 +500,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::CreateOperation;
+ rewrite->getKind() <= Kind::UnresolvedMaterialization;
}
protected:
@@ -721,6 +619,70 @@ class CreateOperationRewrite : public OperationRewrite {
void rollback() override;
};
+
+/// The type of materialization.
+enum MaterializationKind {
+ /// This materialization materializes a conversion for an illegal block
+ /// argument type, to a legal one.
+ Argument,
+
+ /// This materialization materializes a conversion from an illegal type to a
+ /// legal one.
+ Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+ UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+ MaterializationKind kind = MaterializationKind::Target,
+ Type origOutputType = nullptr)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::UnresolvedMaterialization;
+ }
+
+ UnrealizedConversionCastOp getOperation() const {
+ return cast<UnrealizedConversionCastOp>(op);
+ }
+
+ void rollback() override;
+
+ void cleanup() override;
+
+ /// Return the type converter of this materialization (which may be null).
+ const TypeConverter *getConverter() const {
+ return converterAndKind.getPointer();
+ }
+
+ /// Return the kind of this materialization.
+ MaterializationKind getMaterializationKind() const {
+ return converterAndKind.getInt();
+ }
+
+ /// Set the kind of this materialization.
+ void setMaterializationKind(MaterializationKind kind) {
+ converterAndKind.setInt(kind);
+ }
+
+ /// Return the original illegal output type of the input values.
+ Type getOrigOutputType() const { return origOutputType; }
+
+private:
+ /// The corresponding type converter to use when resolving this
+ /// materialization, and the kind of this materialization.
+ llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+ converterAndKind;
+
+ /// The original output type. This is only used for argument conversions.
+ Type origOutputType;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -763,14 +725,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
notifyCallback(nullptr) {}
- /// Cleanup and destroy any generated rewrite operations. This method is
- /// invoked when the conversion process fails.
- void discardRewrites();
-
- /// Apply all requested operation rewrites. This method is invoked when the
- /// conversion process succeeds.
- void applyRewrites();
-
//===--------------------------------------------------------------------===//
// State Management
//===--------------------------------------------------------------------===//
@@ -778,6 +732,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Return the current state of the rewriter.
RewriterState getCurrentState();
+ /// Apply all requested operation rewrites. This method is invoked when the
+ /// conversion process succeeds.
+ void applyRewrites();
+
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
@@ -810,17 +768,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// removes them from being considered for legalization.
void markNestedOpsIgnored(Operation *op);
- /// Detach any operations nested in the given operation from their parent
- /// blocks, and erase the given operation. This can be used when the nested
- /// operations are scheduled for erasure themselves, so deleting the regions
- /// of the given operation together with their content would result in
- /// double-free. This happens, for example, when rolling back op creation in
- /// the reverse order and if the nested ops were created before the parent op.
- /// This function does not need to collect nested ops recursively because it
- /// is expected to also be called for each nested op when it is about to be
- /// deleted.
- void detachNestedAndErase(Operation *op);
-
//===--------------------------------------------------------------------===//
// Type Conversion
//===--------------------------------------------------------------------===//
@@ -859,6 +806,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
+ //===--------------------------------------------------------------------===//
+ // Materializations
+ //===--------------------------------------------------------------------===//
+ /// Build an unresolved materialization operation given an output type and set
+ /// of input operands.
+ Value buildUnresolvedMaterialization(MaterializationKind kind,
+ Block *insertBlock,
+ Block::iterator insertPt, Location loc,
+ ValueRange inputs, Type outputType,
+ Type origOutputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
+ Location loc, ValueRange inputs,
+ Type origOutputType,
+ Type outputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+ Type outputType,
+ const TypeConverter *converter);
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -938,10 +907,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// replacing a value with one of a different type.
ConversionValueMapping mapping;
- /// Ordered vector of all unresolved type conversion materializations during
- /// conversion.
- SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
-
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
@@ -1129,26 +1094,15 @@ void CreateOperationRewrite::rollback() {
eraseOp(op);
}
-void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
- // if (erasedIR.erasedOps.contains(op)) return;
-
- for (Region ®ion : op->getRegions()) {
- for (Block &block : region.getBlocks()) {
- while (!block.getOperations().empty())
- block.getOperations().remove(block.getOperations().begin());
- block.dropAllDefinedValueUses();
- }
+void UnresolvedMaterializationRewrite::rollback() {
+ if (getMaterializationKind() == MaterializationKind::Target) {
+ for (Value input : op->getOperands())
+ rewriterImpl.mapping.erase(input);
}
- eraseRewriter.eraseOp(op);
+ eraseOp(op);
}
-void ConversionPatternRewriterImpl::discardRewrites() {
- undoRewrites();
-
- // Remove any newly created ops.
- for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
- detachNestedAndErase(materialization.getOp());
-}
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
@@ -1156,39 +1110,20 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrite->commit();
for (auto &rewrite : rewrites)
rewrite->cleanup();
-
- // Drop all of the unresolved materialization operations created during
- // conversion.
- for (auto &mat : unresolvedMaterializations)
- eraseRewriter.eraseOp(mat.getOp());
}
//===----------------------------------------------------------------------===//
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
- ignoredOps.size(), eraseRewriter.erased.size());
+ return RewriterState(rewrites.size(), ignoredOps.size(),
+ eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Undo any rewrites.
undoRewrites(state.numRewrites);
- // Pop all of the newly inserted materializations.
- while (unresolvedMaterializations.size() !=
- state.numUnresolvedMaterializations) {
- UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
- UnrealizedConversionCastOp op = mat.getOp();
-
- // If this was a target materialization, drop the mapping that was inserted.
- if (mat.getKind() == UnresolvedMaterialization::Target) {
- for (Value input : op->getOperands())
- mapping.erase(input);
- }
- detachNestedAndErase(op);
- }
-
// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
@@ -1249,8 +1184,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
Value castValue = buildUnresolvedTargetMaterialization(
- operandLoc, newOperand, desiredType, currentTypeConverter,
- unresolvedMaterializations);
+ operandLoc, newOperand, desiredType, currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
@@ -1432,7 +1366,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
newArg = buildUnresolvedArgumentMaterialization(
rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
+ converter);
}
mapping.map(origArg, newArg);
@@ -1445,6 +1379,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
return newBlock;
}
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+ Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+ const TypeConverter *converter) {
+ // Avoid materializing an unnecessary cast.
+ if (inputs.size() == 1 && inputs.front().getType() == outputType)
+ return inputs.front();
+
+ // Create an unresolved materialization. We use a new OpBuilder to avoid
+ // tracking the materialization like we do for other operations.
+ OpBuilder builder(insertBlock, insertPt);
+ auto convertOp =
+ builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ origOutputType);
+ return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+ PatternRewriter &rewriter, Location loc, ValueRange inputs,
+ Type origOutputType, Type outputType, const TypeConverter *converter) {
+ return buildUnresolvedMaterialization(
+ MaterializationKind::Argument, rewriter.getInsertionBlock(),
+ rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+ converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+ Location loc, Value input, Type outputType,
+ const TypeConverter *converter) {
+ Block *insertBlock = input.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(input))
+ insertPt = ++inputRes.getOwner()->getIterator();
+
+ return buildUnresolvedMaterialization(MaterializationKind::Target,
+ insertBlock, insertPt, loc, input,
+ outputType, outputType, converter);
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -2497,18 +2475,18 @@ LogicalResult OperationConverter::convertOperations(
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// Now that all of the operations have been converted, finalize the conversion
// process to ensure any lingering conversion artifacts are cleaned up and
// legalized.
if (failed(finalize(rewriter)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// After a successful conversion, apply rewrites if this is not an analysis
// conversion.
if (mode == OpConversionMode::Analysis) {
- rewriterImpl.discardRewrites();
+ rewriterImpl.undoRewrites();
} else {
rewriterImpl.applyRewrites();
}
@@ -2613,11 +2591,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
/// Compute all of the unresolved materializations that will persist beyond the
/// conversion process, and require inserting a proper user materialization for.
static void computeNecessaryMaterializations(
- DenseMap<Operation *, UnresolvedMaterialization *> &materializati...
[truncated]
|
IRRewrite
sIRRewrite
s
6701034
to
d15c439
Compare
51a7e12
to
6d02f6d
Compare
e4e7d7c
to
8859a66
Compare
BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
6d02f6d
to
658d828
Compare
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
This commit turns the creation of unresolved materializations (
unrealized_conversion_cast
) intoIRRewrite
objects. After this commit, all steps inapplyRewrites
anddiscardRewrites
are calls toIRRewrite::commit
andIRRewrite::rollback
.