Skip to content

Revert "[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig #83662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 2, 2024

Conversation

joker-eph
Copy link
Collaborator

This reverts commit 5f1319b.

A FIR test is broken on Windows

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 2, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 2, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Changes

This reverts commit 5f1319b.

A FIR test is broken on Windows


Patch is 20.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83662.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+28-47)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+73-61)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+4-10)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..8c12cdd9be3696 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,7 +24,6 @@ namespace mlir {
 // Forward declarations.
 class Attribute;
 class Block;
-struct ConversionConfig;
 class ConversionPatternRewriter;
 class MLIRContext;
 class Operation;
@@ -768,8 +767,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Conversion pattern rewriters must not be used outside of dialect
   /// conversions. They apply some IR rewrites in a delayed fashion and could
   /// bring the IR into an inconsistent state when used standalone.
-  explicit ConversionPatternRewriter(MLIRContext *ctx,
-                                     const ConversionConfig &config);
+  explicit ConversionPatternRewriter(MLIRContext *ctx);
 
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
@@ -1069,30 +1067,6 @@ class PDLConversionConfig final {
 
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
-//===----------------------------------------------------------------------===//
-// ConversionConfig
-//===----------------------------------------------------------------------===//
-
-/// Dialect conversion configuration.
-struct ConversionConfig {
-  /// An optional callback used to notify about match failure diagnostics during
-  /// the conversion. Diagnostics reported to this callback may only be
-  /// available in debug mode.
-  function_ref<void(Diagnostic &)> notifyCallback = nullptr;
-
-  /// Partial conversion only. All operations that are found not to be
-  /// legalizable are placed in this set. (Note that if there is an op
-  /// explicitly marked as illegal, the conversion terminates and the set will
-  /// not necessarily be complete.)
-  DenseSet<Operation *> *unlegalizedOps = nullptr;
-
-  /// Analysis conversion only. All operations that are found to be legalizable
-  /// are placed in this set. Note that no actual rewrites are applied to the
-  /// IR during an analysis conversion and only pre-existing operations are
-  /// added to the set.
-  DenseSet<Operation *> *legalizableOps = nullptr;
-};
-
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
@@ -1106,16 +1080,19 @@ struct ConversionConfig {
 /// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal.
+/// returns failure if there ops explicitly marked as illegal. If an
+/// `unconvertedOps` set is provided, all operations that are found not to be
+/// legalizable to the given `target` are placed within that set. (Note that if
+/// there is an op explicitly marked as illegal, the conversion terminates and
+/// the `unconvertedOps` set will not necessarily be complete.)
 LogicalResult
-applyPartialConversion(ArrayRef<Operation *> ops,
-                       const ConversionTarget &target,
+applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       ConversionConfig config = ConversionConfig());
+                       DenseSet<Operation *> *unconvertedOps = nullptr);
 LogicalResult
 applyPartialConversion(Operation *op, const ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
-                       ConversionConfig config = ConversionConfig());
+                       DenseSet<Operation *> *unconvertedOps = nullptr);
 
 /// Apply a complete conversion on the given operations, and all nested
 /// operations. This method returns failure if the conversion of any operation
@@ -1123,27 +1100,31 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
 /// within 'ops'.
 LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                   const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns,
-                                  ConversionConfig config = ConversionConfig());
+                                  const FrozenRewritePatternSet &patterns);
 LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
-                                  const FrozenRewritePatternSet &patterns,
-                                  ConversionConfig config = ConversionConfig());
+                                  const FrozenRewritePatternSet &patterns);
 
 /// Apply an analysis conversion on the given operations, and all nested
 /// operations. This method analyzes which operations would be successfully
 /// converted to the target if a conversion was applied. All operations that
 /// were found to be legalizable to the given 'target' are placed within the
-/// provided 'config.legalizableOps' set; note that no actual rewrites are
-/// applied to the operations on success. This method only returns failure if
-/// there are unreachable blocks in any of the regions nested within 'ops'.
-LogicalResult
-applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                        const FrozenRewritePatternSet &patterns,
-                        ConversionConfig config = ConversionConfig());
-LogicalResult
-applyAnalysisConversion(Operation *op, ConversionTarget &target,
-                        const FrozenRewritePatternSet &patterns,
-                        ConversionConfig config = ConversionConfig());
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+/// This method only returns failure if there are unreachable blocks in any of
+/// the regions nested within 'ops'. There's an additional argument
+/// `notifyCallback` which is used for collecting match failure diagnostics
+/// generated during the conversion. Diagnostics are only reported to this
+/// callback may only be available in debug mode.
+LogicalResult applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+LogicalResult applyAnalysisConversion(
+    Operation *op, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 26899301eb742e..ffdb442033d323 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -230,8 +230,6 @@ class IRRewrite {
   /// Erase the given block (unless it was already erased).
   void eraseBlock(Block *block);
 
-  const ConversionConfig &getConfig() const;
-
   const Kind kind;
   ConversionPatternRewriterImpl &rewriterImpl;
 };
@@ -734,9 +732,8 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
-                                         const ConversionConfig &config)
-      : eraseRewriter(ctx), config(config) {}
+  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+      : eraseRewriter(rewriter.getContext()) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -936,8 +933,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// converting the arguments of blocks within that region.
   DenseMap<Region *, const TypeConverter *> regionToConverter;
 
-  /// Dialect conversion configuration.
-  const ConversionConfig &config;
+  /// This allows the user to collect the match failure message.
+  function_ref<void(Diagnostic &)> notifyCallback;
+
+  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+  /// this is populated with ops found to be legalizable to the target.
+  /// When mode == OpConversionMode::Partial, this is populated with ops found
+  /// *not* to be legalizable to the target.
+  DenseSet<Operation *> *trackedOps = nullptr;
 
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
@@ -960,10 +963,6 @@ void IRRewrite::eraseBlock(Block *block) {
   rewriterImpl.eraseRewriter.eraseBlock(block);
 }
 
-const ConversionConfig &IRRewrite::getConfig() const {
-  return rewriterImpl.config;
-}
-
 void BlockTypeConversionRewrite::commit() {
   // Process the remapping for each of the original arguments.
   for (auto [origArg, info] :
@@ -1081,8 +1080,8 @@ void ReplaceOperationRewrite::commit() {
     if (Value newValue =
             rewriterImpl.mapping.lookupOrNull(result, result.getType()))
       result.replaceAllUsesWith(newValue);
-  if (getConfig().unlegalizedOps)
-    getConfig().unlegalizedOps->erase(op);
+  if (rewriterImpl.trackedOps)
+    rewriterImpl.trackedOps->erase(op);
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   op->getBlock()->getOperations().remove(op);
 }
@@ -1505,8 +1504,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
-    if (config.notifyCallback)
-      config.notifyCallback(diag);
+    if (notifyCallback)
+      notifyCallback(diag);
   });
 }
 
@@ -1514,10 +1513,9 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
 
-ConversionPatternRewriter::ConversionPatternRewriter(
-    MLIRContext *ctx, const ConversionConfig &config)
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+      impl(new detail::ConversionPatternRewriterImpl(*this)) {
   setListener(impl.get());
 }
 
@@ -1986,12 +1984,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
-      if (rewriterImpl.config.notifyCallback) {
+      if (rewriterImpl.notifyCallback) {
         Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
         diag << "Failed to apply pattern \"" << pattern.getDebugName()
              << "\" on op:\n"
              << *op;
-        rewriterImpl.config.notifyCallback(diag);
+        rewriterImpl.notifyCallback(diag);
       }
     });
     rewriterImpl.resetState(curState);
@@ -2379,12 +2377,14 @@ namespace mlir {
 struct OperationConverter {
   explicit OperationConverter(const ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              const ConversionConfig &config,
-                              OpConversionMode mode)
-      : opLegalizer(target, patterns), config(config), mode(mode) {}
+                              OpConversionMode mode,
+                              DenseSet<Operation *> *trackedOps = nullptr)
+      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops);
+  LogicalResult
+  convertOperations(ArrayRef<Operation *> ops,
+                    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 
 private:
   /// Converts an operation with the given rewriter.
@@ -2421,11 +2421,14 @@ struct OperationConverter {
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
-  /// Dialect conversion configuration.
-  ConversionConfig config;
-
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
+
+  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+  /// this is populated with ops found to be legalizable to the target.
+  /// When mode == OpConversionMode::Partial, this is populated with ops found
+  /// *not* to be legalizable to the target.
+  DenseSet<Operation *> *trackedOps;
 };
 } // namespace mlir
 
@@ -2439,27 +2442,28 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       return op->emitError()
              << "failed to legalize operation '" << op->getName() << "'";
     // Partial conversions allow conversions to fail iff the operation was not
-    // explicitly marked as illegal. If the user provided a `unlegalizedOps`
-    // set, non-legalizable ops are added to that set.
+    // explicitly marked as illegal. If the user provided a nonlegalizableOps
+    // set, non-legalizable ops are included.
     if (mode == OpConversionMode::Partial) {
       if (opLegalizer.isIllegal(op))
         return op->emitError()
                << "failed to legalize operation '" << op->getName()
                << "' that was explicitly marked illegal";
-      if (config.unlegalizedOps)
-        config.unlegalizedOps->insert(op);
+      if (trackedOps)
+        trackedOps->insert(op);
     }
   } else if (mode == OpConversionMode::Analysis) {
     // Analysis conversions don't fail if any operations fail to legalize,
     // they are only interested in the operations that were successfully
     // legalized.
-    if (config.legalizableOps)
-      config.legalizableOps->insert(op);
+    trackedOps->insert(op);
   }
   return success();
 }
 
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult OperationConverter::convertOperations(
+    ArrayRef<Operation *> ops,
+    function_ref<void(Diagnostic &)> notifyCallback) {
   if (ops.empty())
     return success();
   const ConversionTarget &target = opLegalizer.getTarget();
@@ -2480,8 +2484,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   }
 
   // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
+  ConversionPatternRewriter rewriter(ops.front()->getContext());
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  rewriterImpl.notifyCallback = notifyCallback;
+  rewriterImpl.trackedOps = trackedOps;
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
@@ -3468,51 +3474,57 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
 //===----------------------------------------------------------------------===//
 // Partial Conversion
 
-LogicalResult mlir::applyPartialConversion(
-    ArrayRef<Operation *> ops, const ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
-  OperationConverter opConverter(target, patterns, config,
-                                 OpConversionMode::Partial);
+LogicalResult
+mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+                             const ConversionTarget &target,
+                             const FrozenRewritePatternSet &patterns,
+                             DenseSet<Operation *> *unconvertedOps) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
+                                 unconvertedOps);
   return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
                              const FrozenRewritePatternSet &patterns,
-                             ConversionConfig config) {
-  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
+                             DenseSet<Operation *> *unconvertedOps) {
+  return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
+                                unconvertedOps);
 }
 
 //===----------------------------------------------------------------------===//
 // Full Conversion
 
-LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
-                                        const ConversionTarget &target,
-                                        const FrozenRewritePatternSet &patterns,
-                                        ConversionConfig config) {
-  OperationConverter opConverter(target, patterns, config,
-                                 OpConversionMode::Full);
+LogicalResult
+mlir::applyFullConversion(ArrayRef<Operation *> ops,
+                          const ConversionTarget &target,
+                          const FrozenRewritePatternSet &patterns) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
   return opConverter.convertOperations(ops);
 }
-LogicalResult mlir::applyFullConversion(Operation *op,
-                                        const ConversionTarget &target,
-                                        const FrozenRewritePatternSet &patterns,
-                                        ConversionConfig config) {
-  return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
+LogicalResult
+mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
+                          const FrozenRewritePatternSet &patterns) {
+  return applyFullConversion(llvm::ArrayRef(op), target, patterns);
 }
 
 //===----------------------------------------------------------------------===//
 // Analysis Conversion
 
-LogicalResult mlir::applyAnalysisConversion(
-    ArrayRef<Operation *> ops, ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns, ConversionConfig config) {
-  OperationConverter opConverter(target, patterns, config,
-                                 OpConversionMode::Analysis);
-  return opConverter.convertOperations(ops);
+LogicalResult
+mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
+                              ConversionTarget &target,
+                              const FrozenRewritePatternSet &patterns,
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
+                                 &convertedOps);
+  return opConverter.convertOperations(ops, notifyCallback);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              ConversionConfig config) {
-  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
+  return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
+                                 convertedOps, notifyCallback);
 }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index abc0e43c7b7f2d..157bfcc1eb23be 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1152,10 +1152,8 @@ struct TestLegalizePatternDriver
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
-      ConversionConfig config;
-      config.unlegalizedOps = &unlegalizedOps;
-      if (failed(applyPartialConversion(getOperation(), target,
-                                        std::move(patterns), config))) {
+      if (failed(applyPartialConversion(
+              getOperation(), target, std::move(patterns), &unlegalizedOps))) {
         getOperation()->emitRemark() << "applyPartialConversion failed";
       }
       // Emit remarks for each legalizable operation.
@@ -1183,10 +1181,8 @@ struct ...
[truncated]

Copy link

github-actions bot commented Mar 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

…`ConversionConfig` (llvm#82250)"

This reverts commit 5f1319b.

A FIR test is broken on Windows
@joker-eph joker-eph merged commit 60fbd60 into llvm:main Mar 2, 2024
@joker-eph joker-eph deleted the fix-fir-windows branch March 2, 2024 22:41
@joker-eph joker-eph added the skip-precommit-approval PR for CI feedback, not intended for review label Mar 2, 2024
@mmilanifard
Copy link
Contributor

This revert is breaking downstream builds. New struct is used here (merged 5 days ago).

@joker-eph
Copy link
Collaborator Author

This is a problem for downstream, I'm not sure what's the consideration here?

@mmilanifard
Copy link
Contributor

Merge conflict. Investigating a fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir skip-precommit-approval PR for CI feedback, not intended for review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants