Skip to content

[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig #83754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 4, 2024

Conversation

matthias-springer
Copy link
Member

This commit adds a new ConversionConfig struct that allows users to customize the dialect conversion. This configuration is similar to GreedyRewriteConfig for the greedy pattern rewrite driver.

A few existing options are moved to this objects, simplifying the dialect conversion API.

This is a re-upload of #82250.

This reverts commit 60fbd60.

Copy link

github-actions bot commented Mar 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

`AllocaOpConversion` takes an `ArrayRef<Operation *>`, but the
underlying `SmallVector<Operation *>` was dead by the time the pattern
ran.
…ionConfig`

This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver.

A few existing options are moved to this objects, simplifying the dialect conversion API.

This is a re-upload of #82250.

This reverts commit 60fbd60.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conv_config branch from 642fcdf to 6899818 Compare March 4, 2024 05:59
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/fix_flang_1 March 4, 2024 05:59
Base automatically changed from users/matthias-springer/fix_flang_1 to main March 4, 2024 06:54
@matthias-springer matthias-springer marked this pull request as ready for review March 4, 2024 06:54
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir flang Flang issues not falling into any other category flang:fir-hlfir labels Mar 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 4, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new ConversionConfig struct that allows users to customize the dialect conversion. This configuration is similar to GreedyRewriteConfig for the greedy pattern rewrite driver.

A few existing options are moved to this objects, simplifying the dialect conversion API.

This is a re-upload of #82250.

This reverts commit 60fbd60.


Patch is 21.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83754.diff

4 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/MemoryAllocation.cpp (+2-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+45-27)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+61-73)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+10-4)
diff --git a/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp b/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp
index f0e201402fa79c..166a6b10def293 100644
--- a/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp
+++ b/flang/lib/Optimizer/Transforms/MemoryAllocation.cpp
@@ -200,7 +200,8 @@ class MemoryAllocationOpt
       return keepStackAllocation(alloca, &func.front(), options);
     });
 
-    patterns.insert<AllocaOpConversion>(context, analysis.getReturns(func));
+    llvm::SmallVector<mlir::Operation *> returnOps = analysis.getReturns(func);
+    patterns.insert<AllocaOpConversion>(context, returnOps);
     if (mlir::failed(
             mlir::applyPartialConversion(func, target, std::move(patterns)))) {
       mlir::emitError(func.getLoc(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 88eefa69a8003f..84396529eb7c2e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,6 +24,7 @@ namespace mlir {
 // Forward declarations.
 class Attribute;
 class Block;
+struct ConversionConfig;
 class ConversionPatternRewriter;
 class MLIRContext;
 class Operation;
@@ -767,7 +768,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Conversion pattern rewriters must not be used outside of dialect
   /// conversions. They apply some IR rewrites in a delayed fashion and could
   /// bring the IR into an inconsistent state when used standalone.
-  explicit ConversionPatternRewriter(MLIRContext *ctx);
+  explicit ConversionPatternRewriter(MLIRContext *ctx,
+                                     const ConversionConfig &config);
 
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
@@ -1067,6 +1069,30 @@ class PDLConversionConfig final {
 
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
+//===----------------------------------------------------------------------===//
+// ConversionConfig
+//===----------------------------------------------------------------------===//
+
+/// Dialect conversion configuration.
+struct ConversionConfig {
+  /// An optional callback used to notify about match failure diagnostics during
+  /// the conversion. Diagnostics reported to this callback may only be
+  /// available in debug mode.
+  function_ref<void(Diagnostic &)> notifyCallback = nullptr;
+
+  /// Partial conversion only. All operations that are found not to be
+  /// legalizable are placed in this set. (Note that if there is an op
+  /// explicitly marked as illegal, the conversion terminates and the set will
+  /// not necessarily be complete.)
+  DenseSet<Operation *> *unlegalizedOps = nullptr;
+
+  /// Analysis conversion only. All operations that are found to be legalizable
+  /// are placed in this set. Note that no actual rewrites are applied to the
+  /// IR during an analysis conversion and only pre-existing operations are
+  /// added to the set.
+  DenseSet<Operation *> *legalizableOps = nullptr;
+};
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
@@ -1080,20 +1106,16 @@ class PDLConversionConfig final {
 /// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If an
-/// `unconvertedOps` set is provided, all operations that are found not to be
-/// legalizable to the given `target` are placed within that set. (Note that if
-/// there is an op explicitly marked as illegal, the conversion terminates and
-/// the `unconvertedOps` set will not necessarily be complete.)
+/// returns failure if there ops explicitly marked as illegal.
 LogicalResult
 applyPartialConversion(ArrayRef<Operation *> ops,
                        const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       DenseSet<Operation *> *unconvertedOps = nullptr);
+                       ConversionConfig config = ConversionConfig());
 LogicalResult
 applyPartialConversion(Operation *op, const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       DenseSet<Operation *> *unconvertedOps = nullptr);
+                       ConversionConfig config = ConversionConfig());
 
 /// Apply a complete conversion on the given operations, and all nested
 /// operations. This method returns failure if the conversion of any operation
@@ -1101,31 +1123,27 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
 /// within 'ops'.
 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                   const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns);
+                                  const FrozenRewritePatternSet &patterns,
+                                  ConversionConfig config = ConversionConfig());
 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns);
+                                  const FrozenRewritePatternSet &patterns,
+                                  ConversionConfig config = ConversionConfig());
 
 /// Apply an analysis conversion on the given operations, and all nested
 /// operations. This method analyzes which operations would be successfully
 /// converted to the target if a conversion was applied. All operations that
 /// were found to be legalizable to the given 'target' are placed within the
-/// provided 'convertedOps' set; note that no actual rewrites are applied to the
-/// operations on success and only pre-existing operations are added to the set.
-/// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'. There's an additional argument
-/// `notifyCallback` which is used for collecting match failure diagnostics
-/// generated during the conversion. Diagnostics are only reported to this
-/// callback may only be available in debug mode.
-LogicalResult applyAnalysisConversion(
-    ArrayRef<Operation *> ops, ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns,
-    DenseSet<Operation *> &convertedOps,
-    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
-LogicalResult applyAnalysisConversion(
-    Operation *op, ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns,
-    DenseSet<Operation *> &convertedOps,
-    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+/// provided 'config.legalizableOps' set; note that no actual rewrites are
+/// applied to the operations on success. This method only returns failure if
+/// there are unreachable blocks in any of the regions nested within 'ops'.
+LogicalResult
+applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+                        const FrozenRewritePatternSet &patterns,
+                        ConversionConfig config = ConversionConfig());
+LogicalResult
+applyAnalysisConversion(Operation *op, ConversionTarget &target,
+                        const FrozenRewritePatternSet &patterns,
+                        ConversionConfig config = ConversionConfig());
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b4284ce8be8a1b..7846f1ab56811a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -230,6 +230,8 @@ class IRRewrite {
   /// Erase the given block (unless it was already erased).
   void eraseBlock(Block *block);
 
+  const ConversionConfig &getConfig() const;
+
   const Kind kind;
   ConversionPatternRewriterImpl &rewriterImpl;
 };
@@ -735,8 +737,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
-      : eraseRewriter(rewriter.getContext()) {}
+  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+                                         const ConversionConfig &config)
+      : eraseRewriter(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -936,14 +939,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// converting the arguments of blocks within that region.
   DenseMap<Region *, const TypeConverter *> regionToConverter;
 
-  /// This allows the user to collect the match failure message.
-  function_ref<void(Diagnostic &)> notifyCallback;
-
-  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
-  /// this is populated with ops found to be legalizable to the target.
-  /// When mode == OpConversionMode::Partial, this is populated with ops found
-  /// *not* to be legalizable to the target.
-  DenseSet<Operation *> *trackedOps = nullptr;
+  /// Dialect conversion configuration.
+  const ConversionConfig &config;
 
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
@@ -966,6 +963,10 @@ void IRRewrite::eraseBlock(Block *block) {
   rewriterImpl.eraseRewriter.eraseBlock(block);
 }
 
+const ConversionConfig &IRRewrite::getConfig() const {
+  return rewriterImpl.config;
+}
+
 void BlockTypeConversionRewrite::commit() {
   // Process the remapping for each of the original arguments.
   for (auto [origArg, info] :
@@ -1085,8 +1086,8 @@ void ReplaceOperationRewrite::commit() {
     if (Value newValue =
             rewriterImpl.mapping.lookupOrNull(result, result.getType()))
       result.replaceAllUsesWith(newValue);
-  if (rewriterImpl.trackedOps)
-    rewriterImpl.trackedOps->erase(op);
+  if (getConfig().unlegalizedOps)
+    getConfig().unlegalizedOps->erase(op);
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   op->getBlock()->getOperations().remove(op);
 }
@@ -1514,8 +1515,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
-    if (notifyCallback)
-      notifyCallback(diag);
+    if (config.notifyCallback)
+      config.notifyCallback(diag);
   });
 }
 
@@ -1523,9 +1524,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
 
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ConversionPatternRewriter::ConversionPatternRewriter(
+    MLIRContext *ctx, const ConversionConfig &config)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this)) {
+      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
   setListener(impl.get());
 }
 
@@ -1994,12 +1996,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
-      if (rewriterImpl.notifyCallback) {
+      if (rewriterImpl.config.notifyCallback) {
         Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
         diag << "Failed to apply pattern \"" << pattern.getDebugName()
              << "\" on op:\n"
              << *op;
-        rewriterImpl.notifyCallback(diag);
+        rewriterImpl.config.notifyCallback(diag);
       }
     });
     rewriterImpl.resetState(curState);
@@ -2387,14 +2389,12 @@ namespace mlir {
 struct OperationConverter {
   explicit OperationConverter(const ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              OpConversionMode mode,
-                              DenseSet<Operation *> *trackedOps = nullptr)
-      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
+                              const ConversionConfig &config,
+                              OpConversionMode mode)
+      : opLegalizer(target, patterns), config(config), mode(mode) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult
-  convertOperations(ArrayRef<Operation *> ops,
-                    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+  LogicalResult convertOperations(ArrayRef<Operation *> ops);
 
 private:
   /// Converts an operation with the given rewriter.
@@ -2431,14 +2431,11 @@ struct OperationConverter {
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
+  /// Dialect conversion configuration.
+  ConversionConfig config;
+
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
-
-  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
-  /// this is populated with ops found to be legalizable to the target.
-  /// When mode == OpConversionMode::Partial, this is populated with ops found
-  /// *not* to be legalizable to the target.
-  DenseSet<Operation *> *trackedOps;
 };
 } // namespace mlir
 
@@ -2452,28 +2449,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       return op->emitError()
              << "failed to legalize operation '" << op->getName() << "'";
     // Partial conversions allow conversions to fail iff the operation was not
-    // explicitly marked as illegal. If the user provided a nonlegalizableOps
-    // set, non-legalizable ops are included.
+    // explicitly marked as illegal. If the user provided a `unlegalizedOps`
+    // set, non-legalizable ops are added to that set.
     if (mode == OpConversionMode::Partial) {
       if (opLegalizer.isIllegal(op))
         return op->emitError()
                << "failed to legalize operation '" << op->getName()
                << "' that was explicitly marked illegal";
-      if (trackedOps)
-        trackedOps->insert(op);
+      if (config.unlegalizedOps)
+        config.unlegalizedOps->insert(op);
     }
   } else if (mode == OpConversionMode::Analysis) {
     // Analysis conversions don't fail if any operations fail to legalize,
     // they are only interested in the operations that were successfully
     // legalized.
-    trackedOps->insert(op);
+    if (config.legalizableOps)
+      config.legalizableOps->insert(op);
   }
   return success();
 }
 
-LogicalResult OperationConverter::convertOperations(
-    ArrayRef<Operation *> ops,
-    function_ref<void(Diagnostic &)> notifyCallback) {
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
   const ConversionTarget &target = opLegalizer.getTarget();
@@ -2494,10 +2490,8 @@ LogicalResult OperationConverter::convertOperations(
   }
 
   // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriter rewriter(ops.front()->getContext());
+  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  rewriterImpl.notifyCallback = notifyCallback;
-  rewriterImpl.trackedOps = trackedOps;
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
@@ -3484,57 +3478,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
 //===----------------------------------------------------------------------===//
 // Partial Conversion
 
-LogicalResult
-mlir::applyPartialConversion(ArrayRef<Operation *> ops,
-                             const ConversionTarget &target,
-                             const FrozenRewritePatternSet &patterns,
-                             DenseSet<Operation *> *unconvertedOps) {
-  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
-                                 unconvertedOps);
+LogicalResult mlir::applyPartialConversion(
+    ArrayRef<Operation *> ops, const ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+  OperationConverter opConverter(target, patterns, config,
+                                 OpConversionMode::Partial);
   return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
                              const FrozenRewritePatternSet &patterns,
-                             DenseSet<Operation *> *unconvertedOps) {
-  return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
-                                unconvertedOps);
+                             ConversionConfig config) {
+  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
 }
 
 //===----------------------------------------------------------------------===//
 // Full Conversion
 
-LogicalResult
-mlir::applyFullConversion(ArrayRef<Operation *> ops,
-                          const ConversionTarget &target,
-                          const FrozenRewritePatternSet &patterns) {
-  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
+                                        const ConversionTarget &target,
+                                        const FrozenRewritePatternSet &patterns,
+                                        ConversionConfig config) {
+  OperationConverter opConverter(target, patterns, config,
+                                 OpConversionMode::Full);
   return opConverter.convertOperations(ops);
 }
-LogicalResult
-mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
-                          const FrozenRewritePatternSet &patterns) {
-  return applyFullConversion(llvm::ArrayRef(op), target, patterns);
+LogicalResult mlir::applyFullConversion(Operation *op,
+                                        const ConversionTarget &target,
+                                        const FrozenRewritePatternSet &patterns,
+                                        ConversionConfig config) {
+  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
 }
 
 //===----------------------------------------------------------------------===//
 // Analysis Conversion
 
-LogicalResult
-mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
-                              ConversionTarget &target,
-                              const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps,
-                              function_ref<void(Diagnostic &)> notifyCallback) {
-  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
-                                 &convertedOps);
-  return opConverter.convertOperations(ops, notifyCallback);
+LogicalResult mlir::applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+  OperationConverter opConverter(target, patterns, config,
+                                 OpConversionMode::Analysis);
+  return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps,
-                              function_ref<void(Diagnostic &)> notifyCallback) {
-  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
-                                 convertedOps, notifyCallback);
+                              ConversionConfig config) {
+  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
 }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 157bfcc1eb23be..abc0e43c7b7f2d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1152,8 +1152,10 @@ struct...
[truncated]

@matthias-springer matthias-springer merged commit a282109 into main Mar 4, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conv_config branch March 4, 2024 06:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants