Skip to content

[mlir][Transforms][NFC] Turn unresolved materializations into IRRewrites #81761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Feb 14, 2024

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

This commit turns the creation of unresolved materializations (unrealized_conversion_cast) into IRRewrite objects. After this commit, all steps in applyRewrites and discardRewrites are calls to IRRewrite::commit and IRRewrite::rollback.

@matthias-springer matthias-springer force-pushed the users/matthias_springer/op_creation branch from 3873a3e to 6701034 Compare February 16, 2024 15:22
@matthias-springer matthias-springer force-pushed the users/matthias-springer/unresolved_materialization branch 2 times, most recently from ab0cd8c to 51a7e12 Compare February 19, 2024 10:32
@matthias-springer matthias-springer marked this pull request as ready for review February 19, 2024 10:32
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 19, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

This commit turns the creation of unresolved materializations (unrealized_conversion_cast) into IRRewrite objects. After this commit, all steps in applyRewrites and discardRewrites are calls to IRRewrite::commit and IRRewrite::rollback.


Patch is 25.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81761.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+176-195)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5b7ad4e7b8e281..4ef26a739e4ea1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,15 +152,11 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
-                unsigned numIgnoredOperations, unsigned numErased)
-      : numUnresolvedMaterializations(numUnresolvedMaterializations),
-        numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+                unsigned numErased)
+      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of unresolved materializations.
-  unsigned numUnresolvedMaterializations;
-
   /// The current number of rewrites performed.
   unsigned numRewrites;
 
@@ -171,109 +167,10 @@ struct RewriterState {
   unsigned numErased;
 };
 
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
-  /// The type of materialization.
-  enum Kind {
-    /// This materialization materializes a conversion for an illegal block
-    /// argument type, to a legal one.
-    Argument,
-
-    /// This materialization materializes a conversion from an illegal type to a
-    /// legal one.
-    Target
-  };
-
-  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
-                            const TypeConverter *converter = nullptr,
-                            Kind kind = Target, Type origOutputType = nullptr)
-      : op(op), converterAndKind(converter, kind),
-        origOutputType(origOutputType) {}
-
-  /// Return the temporary conversion operation inserted for this
-  /// materialization.
-  UnrealizedConversionCastOp getOp() const { return op; }
-
-  /// Return the type converter of this materialization (which may be null).
-  const TypeConverter *getConverter() const {
-    return converterAndKind.getPointer();
-  }
-
-  /// Return the kind of this materialization.
-  Kind getKind() const { return converterAndKind.getInt(); }
-
-  /// Set the kind of this materialization.
-  void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
-
-private:
-  /// The unresolved materialization operation created during conversion.
-  UnrealizedConversionCastOp op;
-
-  /// The corresponding type converter to use when resolving this
-  /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
-  /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
-};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
-    UnresolvedMaterialization::Kind kind, Block *insertBlock,
-    Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
-    Type origOutputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  // Avoid materializing an unnecessary cast.
-  if (inputs.size() == 1 && inputs.front().getType() == outputType)
-    return inputs.front();
-
-  // Create an unresolved materialization. We use a new OpBuilder to avoid
-  // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
-  auto convertOp =
-      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
-                                          origOutputType);
-  return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
-    PatternRewriter &rewriter, Location loc, ValueRange inputs,
-    Type origOutputType, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
-      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
-      converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
-    Location loc, Value input, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  Block *insertBlock = input.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(input))
-    insertPt = ++inputRes.getOwner()->getIterator();
-
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
-      outputType, outputType, converter, unresolvedMaterializations);
-}
-
 //===----------------------------------------------------------------------===//
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -295,7 +192,8 @@ class IRRewrite {
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
-    CreateOperation
+    CreateOperation,
+    UnresolvedMaterialization
   };
 
   virtual ~IRRewrite() = default;
@@ -602,7 +500,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::CreateOperation;
+           rewrite->getKind() <= Kind::UnresolvedMaterialization;
   }
 
 protected:
@@ -721,6 +619,70 @@ class CreateOperationRewrite : public OperationRewrite {
 
   void rollback() override;
 };
+
+/// The type of materialization.
+enum MaterializationKind {
+  /// This materialization materializes a conversion for an illegal block
+  /// argument type, to a legal one.
+  Argument,
+
+  /// This materialization materializes a conversion from an illegal type to a
+  /// legal one.
+  Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+  UnresolvedMaterializationRewrite(
+      ConversionPatternRewriterImpl &rewriterImpl,
+      UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+      MaterializationKind kind = MaterializationKind::Target,
+      Type origOutputType = nullptr)
+      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::UnresolvedMaterialization;
+  }
+
+  UnrealizedConversionCastOp getOperation() const {
+    return cast<UnrealizedConversionCastOp>(op);
+  }
+
+  void rollback() override;
+
+  void cleanup() override;
+
+  /// Return the type converter of this materialization (which may be null).
+  const TypeConverter *getConverter() const {
+    return converterAndKind.getPointer();
+  }
+
+  /// Return the kind of this materialization.
+  MaterializationKind getMaterializationKind() const {
+    return converterAndKind.getInt();
+  }
+
+  /// Set the kind of this materialization.
+  void setMaterializationKind(MaterializationKind kind) {
+    converterAndKind.setInt(kind);
+  }
+
+  /// Return the original illegal output type of the input values.
+  Type getOrigOutputType() const { return origOutputType; }
+
+private:
+  /// The corresponding type converter to use when resolving this
+  /// materialization, and the kind of this materialization.
+  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+      converterAndKind;
+
+  /// The original output type. This is only used for argument conversions.
+  Type origOutputType;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -763,14 +725,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
         notifyCallback(nullptr) {}
 
-  /// Cleanup and destroy any generated rewrite operations. This method is
-  /// invoked when the conversion process fails.
-  void discardRewrites();
-
-  /// Apply all requested operation rewrites. This method is invoked when the
-  /// conversion process succeeds.
-  void applyRewrites();
-
   //===--------------------------------------------------------------------===//
   // State Management
   //===--------------------------------------------------------------------===//
@@ -778,6 +732,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Return the current state of the rewriter.
   RewriterState getCurrentState();
 
+  /// Apply all requested operation rewrites. This method is invoked when the
+  /// conversion process succeeds.
+  void applyRewrites();
+
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
@@ -810,17 +768,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// removes them from being considered for legalization.
   void markNestedOpsIgnored(Operation *op);
 
-  /// Detach any operations nested in the given operation from their parent
-  /// blocks, and erase the given operation. This can be used when the nested
-  /// operations are scheduled for erasure themselves, so deleting the regions
-  /// of the given operation together with their content would result in
-  /// double-free. This happens, for example, when rolling back op creation in
-  /// the reverse order and if the nested ops were created before the parent op.
-  /// This function does not need to collect nested ops recursively because it
-  /// is expected to also be called for each nested op when it is about to be
-  /// deleted.
-  void detachNestedAndErase(Operation *op);
-
   //===--------------------------------------------------------------------===//
   // Type Conversion
   //===--------------------------------------------------------------------===//
@@ -859,6 +806,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
+  //===--------------------------------------------------------------------===//
+  // Materializations
+  //===--------------------------------------------------------------------===//
+  /// Build an unresolved materialization operation given an output type and set
+  /// of input operands.
+  Value buildUnresolvedMaterialization(MaterializationKind kind,
+                                       Block *insertBlock,
+                                       Block::iterator insertPt, Location loc,
+                                       ValueRange inputs, Type outputType,
+                                       Type origOutputType,
+                                       const TypeConverter *converter);
+
+  Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
+                                               Location loc, ValueRange inputs,
+                                               Type origOutputType,
+                                               Type outputType,
+                                               const TypeConverter *converter);
+
+  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+                                             Type outputType,
+                                             const TypeConverter *converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -938,10 +907,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all unresolved type conversion materializations during
-  /// conversion.
-  SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
-
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
@@ -1129,26 +1094,15 @@ void CreateOperationRewrite::rollback() {
   eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
-  // if (erasedIR.erasedOps.contains(op)) return;
-
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region.getBlocks()) {
-      while (!block.getOperations().empty())
-        block.getOperations().remove(block.getOperations().begin());
-      block.dropAllDefinedValueUses();
-    }
+void UnresolvedMaterializationRewrite::rollback() {
+  if (getMaterializationKind() == MaterializationKind::Target) {
+    for (Value input : op->getOperands())
+      rewriterImpl.mapping.erase(input);
   }
-  eraseRewriter.eraseOp(op);
+  eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::discardRewrites() {
-  undoRewrites();
-
-  // Remove any newly created ops.
-  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
-    detachNestedAndErase(materialization.getOp());
-}
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
@@ -1156,39 +1110,20 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit();
   for (auto &rewrite : rewrites)
     rewrite->cleanup();
-
-  // Drop all of the unresolved materialization operations created during
-  // conversion.
-  for (auto &mat : unresolvedMaterializations)
-    eraseRewriter.eraseOp(mat.getOp());
 }
 
 //===----------------------------------------------------------------------===//
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
-                       ignoredOps.size(), eraseRewriter.erased.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(),
+                       eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   // Undo any rewrites.
   undoRewrites(state.numRewrites);
 
-  // Pop all of the newly inserted materializations.
-  while (unresolvedMaterializations.size() !=
-         state.numUnresolvedMaterializations) {
-    UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
-    UnrealizedConversionCastOp op = mat.getOp();
-
-    // If this was a target materialization, drop the mapping that was inserted.
-    if (mat.getKind() == UnresolvedMaterialization::Target) {
-      for (Value input : op->getOperands())
-        mapping.erase(input);
-    }
-    detachNestedAndErase(op);
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1249,8 +1184,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
       Value castValue = buildUnresolvedTargetMaterialization(
-          operandLoc, newOperand, desiredType, currentTypeConverter,
-          unresolvedMaterializations);
+          operandLoc, newOperand, desiredType, currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1432,7 +1366,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
       newArg = buildUnresolvedArgumentMaterialization(
           rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter, unresolvedMaterializations);
+          converter);
     }
 
     mapping.map(origArg, newArg);
@@ -1445,6 +1379,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   return newBlock;
 }
 
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    const TypeConverter *converter) {
+  // Avoid materializing an unnecessary cast.
+  if (inputs.size() == 1 && inputs.front().getType() == outputType)
+    return inputs.front();
+
+  // Create an unresolved materialization. We use a new OpBuilder to avoid
+  // tracking the materialization like we do for other operations.
+  OpBuilder builder(insertBlock, insertPt);
+  auto convertOp =
+      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  origOutputType);
+  return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+    PatternRewriter &rewriter, Location loc, ValueRange inputs,
+    Type origOutputType, Type outputType, const TypeConverter *converter) {
+  return buildUnresolvedMaterialization(
+      MaterializationKind::Argument, rewriter.getInsertionBlock(),
+      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+      converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+    Location loc, Value input, Type outputType,
+    const TypeConverter *converter) {
+  Block *insertBlock = input.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(input))
+    insertPt = ++inputRes.getOwner()->getIterator();
+
+  return buildUnresolvedMaterialization(MaterializationKind::Target,
+                                        insertBlock, insertPt, loc, input,
+                                        outputType, outputType, converter);
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -2497,18 +2475,18 @@ LogicalResult OperationConverter::convertOperations(
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
-      return rewriterImpl.discardRewrites(), failure();
+      return rewriterImpl.undoRewrites(), failure();
 
   // Now that all of the operations have been converted, finalize the conversion
   // process to ensure any lingering conversion artifacts are cleaned up and
   // legalized.
   if (failed(finalize(rewriter)))
-    return rewriterImpl.discardRewrites(), failure();
+    return rewriterImpl.undoRewrites(), failure();
 
   // After a successful conversion, apply rewrites if this is not an analysis
   // conversion.
   if (mode == OpConversionMode::Analysis) {
-    rewriterImpl.discardRewrites();
+    rewriterImpl.undoRewrites();
   } else {
     rewriterImpl.applyRewrites();
   }
@@ -2613,11 +2591,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
 /// Compute all of the unresolved materializations that will persist beyond the
 /// conversion process, and require inserting a proper user materialization for.
 static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterialization *> &materializati...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

This commit turns the creation of unresolved materializations (unrealized_conversion_cast) into IRRewrite objects. After this commit, all steps in applyRewrites and discardRewrites are calls to IRRewrite::commit and IRRewrite::rollback.


Patch is 25.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81761.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+176-195)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5b7ad4e7b8e281..4ef26a739e4ea1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,15 +152,11 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
-                unsigned numIgnoredOperations, unsigned numErased)
-      : numUnresolvedMaterializations(numUnresolvedMaterializations),
-        numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+                unsigned numErased)
+      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of unresolved materializations.
-  unsigned numUnresolvedMaterializations;
-
   /// The current number of rewrites performed.
   unsigned numRewrites;
 
@@ -171,109 +167,10 @@ struct RewriterState {
   unsigned numErased;
 };
 
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
-  /// The type of materialization.
-  enum Kind {
-    /// This materialization materializes a conversion for an illegal block
-    /// argument type, to a legal one.
-    Argument,
-
-    /// This materialization materializes a conversion from an illegal type to a
-    /// legal one.
-    Target
-  };
-
-  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
-                            const TypeConverter *converter = nullptr,
-                            Kind kind = Target, Type origOutputType = nullptr)
-      : op(op), converterAndKind(converter, kind),
-        origOutputType(origOutputType) {}
-
-  /// Return the temporary conversion operation inserted for this
-  /// materialization.
-  UnrealizedConversionCastOp getOp() const { return op; }
-
-  /// Return the type converter of this materialization (which may be null).
-  const TypeConverter *getConverter() const {
-    return converterAndKind.getPointer();
-  }
-
-  /// Return the kind of this materialization.
-  Kind getKind() const { return converterAndKind.getInt(); }
-
-  /// Set the kind of this materialization.
-  void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
-
-private:
-  /// The unresolved materialization operation created during conversion.
-  UnrealizedConversionCastOp op;
-
-  /// The corresponding type converter to use when resolving this
-  /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
-  /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
-};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
-    UnresolvedMaterialization::Kind kind, Block *insertBlock,
-    Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
-    Type origOutputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  // Avoid materializing an unnecessary cast.
-  if (inputs.size() == 1 && inputs.front().getType() == outputType)
-    return inputs.front();
-
-  // Create an unresolved materialization. We use a new OpBuilder to avoid
-  // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
-  auto convertOp =
-      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
-                                          origOutputType);
-  return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
-    PatternRewriter &rewriter, Location loc, ValueRange inputs,
-    Type origOutputType, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
-      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
-      converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
-    Location loc, Value input, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  Block *insertBlock = input.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(input))
-    insertPt = ++inputRes.getOwner()->getIterator();
-
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
-      outputType, outputType, converter, unresolvedMaterializations);
-}
-
 //===----------------------------------------------------------------------===//
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -295,7 +192,8 @@ class IRRewrite {
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
-    CreateOperation
+    CreateOperation,
+    UnresolvedMaterialization
   };
 
   virtual ~IRRewrite() = default;
@@ -602,7 +500,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::CreateOperation;
+           rewrite->getKind() <= Kind::UnresolvedMaterialization;
   }
 
 protected:
@@ -721,6 +619,70 @@ class CreateOperationRewrite : public OperationRewrite {
 
   void rollback() override;
 };
+
+/// The type of materialization.
+enum MaterializationKind {
+  /// This materialization materializes a conversion for an illegal block
+  /// argument type, to a legal one.
+  Argument,
+
+  /// This materialization materializes a conversion from an illegal type to a
+  /// legal one.
+  Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+  UnresolvedMaterializationRewrite(
+      ConversionPatternRewriterImpl &rewriterImpl,
+      UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+      MaterializationKind kind = MaterializationKind::Target,
+      Type origOutputType = nullptr)
+      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::UnresolvedMaterialization;
+  }
+
+  UnrealizedConversionCastOp getOperation() const {
+    return cast<UnrealizedConversionCastOp>(op);
+  }
+
+  void rollback() override;
+
+  void cleanup() override;
+
+  /// Return the type converter of this materialization (which may be null).
+  const TypeConverter *getConverter() const {
+    return converterAndKind.getPointer();
+  }
+
+  /// Return the kind of this materialization.
+  MaterializationKind getMaterializationKind() const {
+    return converterAndKind.getInt();
+  }
+
+  /// Set the kind of this materialization.
+  void setMaterializationKind(MaterializationKind kind) {
+    converterAndKind.setInt(kind);
+  }
+
+  /// Return the original illegal output type of the input values.
+  Type getOrigOutputType() const { return origOutputType; }
+
+private:
+  /// The corresponding type converter to use when resolving this
+  /// materialization, and the kind of this materialization.
+  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+      converterAndKind;
+
+  /// The original output type. This is only used for argument conversions.
+  Type origOutputType;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -763,14 +725,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
         notifyCallback(nullptr) {}
 
-  /// Cleanup and destroy any generated rewrite operations. This method is
-  /// invoked when the conversion process fails.
-  void discardRewrites();
-
-  /// Apply all requested operation rewrites. This method is invoked when the
-  /// conversion process succeeds.
-  void applyRewrites();
-
   //===--------------------------------------------------------------------===//
   // State Management
   //===--------------------------------------------------------------------===//
@@ -778,6 +732,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Return the current state of the rewriter.
   RewriterState getCurrentState();
 
+  /// Apply all requested operation rewrites. This method is invoked when the
+  /// conversion process succeeds.
+  void applyRewrites();
+
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
@@ -810,17 +768,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// removes them from being considered for legalization.
   void markNestedOpsIgnored(Operation *op);
 
-  /// Detach any operations nested in the given operation from their parent
-  /// blocks, and erase the given operation. This can be used when the nested
-  /// operations are scheduled for erasure themselves, so deleting the regions
-  /// of the given operation together with their content would result in
-  /// double-free. This happens, for example, when rolling back op creation in
-  /// the reverse order and if the nested ops were created before the parent op.
-  /// This function does not need to collect nested ops recursively because it
-  /// is expected to also be called for each nested op when it is about to be
-  /// deleted.
-  void detachNestedAndErase(Operation *op);
-
   //===--------------------------------------------------------------------===//
   // Type Conversion
   //===--------------------------------------------------------------------===//
@@ -859,6 +806,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
+  //===--------------------------------------------------------------------===//
+  // Materializations
+  //===--------------------------------------------------------------------===//
+  /// Build an unresolved materialization operation given an output type and set
+  /// of input operands.
+  Value buildUnresolvedMaterialization(MaterializationKind kind,
+                                       Block *insertBlock,
+                                       Block::iterator insertPt, Location loc,
+                                       ValueRange inputs, Type outputType,
+                                       Type origOutputType,
+                                       const TypeConverter *converter);
+
+  Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
+                                               Location loc, ValueRange inputs,
+                                               Type origOutputType,
+                                               Type outputType,
+                                               const TypeConverter *converter);
+
+  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+                                             Type outputType,
+                                             const TypeConverter *converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -938,10 +907,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all unresolved type conversion materializations during
-  /// conversion.
-  SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
-
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
@@ -1129,26 +1094,15 @@ void CreateOperationRewrite::rollback() {
   eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
-  // if (erasedIR.erasedOps.contains(op)) return;
-
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region.getBlocks()) {
-      while (!block.getOperations().empty())
-        block.getOperations().remove(block.getOperations().begin());
-      block.dropAllDefinedValueUses();
-    }
+void UnresolvedMaterializationRewrite::rollback() {
+  if (getMaterializationKind() == MaterializationKind::Target) {
+    for (Value input : op->getOperands())
+      rewriterImpl.mapping.erase(input);
   }
-  eraseRewriter.eraseOp(op);
+  eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::discardRewrites() {
-  undoRewrites();
-
-  // Remove any newly created ops.
-  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
-    detachNestedAndErase(materialization.getOp());
-}
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
@@ -1156,39 +1110,20 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit();
   for (auto &rewrite : rewrites)
     rewrite->cleanup();
-
-  // Drop all of the unresolved materialization operations created during
-  // conversion.
-  for (auto &mat : unresolvedMaterializations)
-    eraseRewriter.eraseOp(mat.getOp());
 }
 
 //===----------------------------------------------------------------------===//
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
-                       ignoredOps.size(), eraseRewriter.erased.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(),
+                       eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   // Undo any rewrites.
   undoRewrites(state.numRewrites);
 
-  // Pop all of the newly inserted materializations.
-  while (unresolvedMaterializations.size() !=
-         state.numUnresolvedMaterializations) {
-    UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
-    UnrealizedConversionCastOp op = mat.getOp();
-
-    // If this was a target materialization, drop the mapping that was inserted.
-    if (mat.getKind() == UnresolvedMaterialization::Target) {
-      for (Value input : op->getOperands())
-        mapping.erase(input);
-    }
-    detachNestedAndErase(op);
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1249,8 +1184,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
       Value castValue = buildUnresolvedTargetMaterialization(
-          operandLoc, newOperand, desiredType, currentTypeConverter,
-          unresolvedMaterializations);
+          operandLoc, newOperand, desiredType, currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1432,7 +1366,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
       newArg = buildUnresolvedArgumentMaterialization(
           rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter, unresolvedMaterializations);
+          converter);
     }
 
     mapping.map(origArg, newArg);
@@ -1445,6 +1379,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   return newBlock;
 }
 
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    const TypeConverter *converter) {
+  // Avoid materializing an unnecessary cast.
+  if (inputs.size() == 1 && inputs.front().getType() == outputType)
+    return inputs.front();
+
+  // Create an unresolved materialization. We use a new OpBuilder to avoid
+  // tracking the materialization like we do for other operations.
+  OpBuilder builder(insertBlock, insertPt);
+  auto convertOp =
+      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  origOutputType);
+  return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+    PatternRewriter &rewriter, Location loc, ValueRange inputs,
+    Type origOutputType, Type outputType, const TypeConverter *converter) {
+  return buildUnresolvedMaterialization(
+      MaterializationKind::Argument, rewriter.getInsertionBlock(),
+      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+      converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+    Location loc, Value input, Type outputType,
+    const TypeConverter *converter) {
+  Block *insertBlock = input.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(input))
+    insertPt = ++inputRes.getOwner()->getIterator();
+
+  return buildUnresolvedMaterialization(MaterializationKind::Target,
+                                        insertBlock, insertPt, loc, input,
+                                        outputType, outputType, converter);
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -2497,18 +2475,18 @@ LogicalResult OperationConverter::convertOperations(
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
-      return rewriterImpl.discardRewrites(), failure();
+      return rewriterImpl.undoRewrites(), failure();
 
   // Now that all of the operations have been converted, finalize the conversion
   // process to ensure any lingering conversion artifacts are cleaned up and
   // legalized.
   if (failed(finalize(rewriter)))
-    return rewriterImpl.discardRewrites(), failure();
+    return rewriterImpl.undoRewrites(), failure();
 
   // After a successful conversion, apply rewrites if this is not an analysis
   // conversion.
   if (mode == OpConversionMode::Analysis) {
-    rewriterImpl.discardRewrites();
+    rewriterImpl.undoRewrites();
   } else {
     rewriterImpl.applyRewrites();
   }
@@ -2613,11 +2591,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
 /// Compute all of the unresolved materializations that will persist beyond the
 /// conversion process, and require inserting a proper user materialization for.
 static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterialization *> &materializati...
[truncated]

@matthias-springer matthias-springer changed the title [mlir][Transforms][NFC][WIP] Turn unresolved materializations into IRRewrites [mlir][Transforms][NFC] Turn unresolved materializations into IRRewrites Feb 19, 2024
@matthias-springer matthias-springer force-pushed the users/matthias_springer/op_creation branch from 6701034 to d15c439 Compare February 22, 2024 10:11
@matthias-springer matthias-springer force-pushed the users/matthias-springer/unresolved_materialization branch from 51a7e12 to 6d02f6d Compare February 22, 2024 10:17
@matthias-springer matthias-springer force-pushed the users/matthias_springer/op_creation branch 2 times, most recently from e4e7d7c to 8859a66 Compare February 23, 2024 09:02
Base automatically changed from users/matthias_springer/op_creation to main February 23, 2024 09:03
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
@matthias-springer matthias-springer force-pushed the users/matthias-springer/unresolved_materialization branch from 6d02f6d to 658d828 Compare February 23, 2024 09:04
@matthias-springer matthias-springer merged commit 59ff4d1 into main Feb 23, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/unresolved_materialization branch February 23, 2024 09:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants