Skip to content

[mlir][python] remove various caching mechanisms #70831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 3, 2023

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Oct 31, 2023

This PR removes the various caching mechanisms currently in the python bindings - both positive caching and negative caching.

@makslevental makslevental marked this pull request as ready for review October 31, 2023 17:14
@llvmbot llvmbot added the mlir label Oct 31, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

This PR removes the various caching mechanisms currently in the python bindings - both positive caching and negative caching.


Full diff: https://github.com/llvm/llvm-project/pull/70831.diff

4 Files Affected:

  • (modified) mlir/docs/Bindings/Python.md (+1-1)
  • (modified) mlir/lib/Bindings/Python/Globals.h (+5-15)
  • (modified) mlir/lib/Bindings/Python/IRModule.cpp (+28-76)
  • (modified) mlir/lib/Bindings/Python/MainModule.cpp (+2-3)
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bc2e676a878c0f4..ef984e2bed7ea3a 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -945,7 +945,7 @@ When the python bindings need to locate a wrapper module, they consult the
 `dialect_search_path` and use it to find an appropriately named module. For the
 main repository, this search path is hard-coded to include the `mlir.dialects`
 module, which is where wrappers are emitted by the above build rule. Out of tree
-dialects and add their modules to the search path by calling:
+dialects can add their modules to the search path by calling:
 
 ```python
 mlir._cext.append_dialect_search_prefix("myproject.mlir.dialects")
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 21899bdce22e810..4332954f8b6927c 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,10 +9,6 @@
 #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
 #define MLIR_BINDINGS_PYTHON_GLOBALS_H
 
-#include <optional>
-#include <string>
-#include <vector>
-
 #include "PybindUtils.h"
 
 #include "mlir-c/IR.h"
@@ -21,6 +17,10 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
 
+#include <optional>
+#include <string>
+#include <vector>
+
 namespace mlir {
 namespace python {
 
@@ -45,10 +45,6 @@ class PyGlobals {
     dialectSearchPrefixes.swap(newValues);
   }
 
-  /// Clears positive and negative caches regarding what implementations are
-  /// available. Future lookups will do more expensive existence checks.
-  void clearImportCache();
-
   /// Loads a python module corresponding to the given dialect namespace.
   /// No-ops if the module has already been loaded or is not found. Raises
   /// an error on any evaluation issues.
@@ -113,16 +109,10 @@ class PyGlobals {
   llvm::StringMap<pybind11::object> attributeBuilderMap;
   /// Map of MlirTypeID to custom type caster.
   llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
-  /// Cache for map of MlirTypeID to custom type caster.
-  llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
-  llvm::StringSet<> loadedDialectModulesCache;
-  /// Cache of operation name to external operation class object. This is
-  /// maintained on lookup as a shadow of operationClassMap in order for repeat
-  /// lookups of the classes to only incur the cost of one hashtable lookup.
-  llvm::StringMap<pybind11::object> operationClassMapCache;
+  llvm::StringSet<> loadedDialectModules;
 };
 
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index f8e22f7bb0c1ba7..598c41012b3663d 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -10,12 +10,12 @@
 #include "Globals.h"
 #include "PybindUtils.h"
 
-#include <optional>
-#include <vector>
-
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Support.h"
 
+#include <optional>
+#include <vector>
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
@@ -37,7 +37,7 @@ PyGlobals::PyGlobals() {
 PyGlobals::~PyGlobals() { instance = nullptr; }
 
 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
-  if (loadedDialectModulesCache.contains(dialectNamespace))
+  if (loadedDialectModules.contains(dialectNamespace))
     return;
   // Since re-entrancy is possible, make a copy of the search prefixes.
   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
@@ -59,13 +59,13 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
 
   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
   // may have occurred, which may do anything.
-  loadedDialectModulesCache.insert(dialectNamespace);
+  loadedDialectModules.insert(dialectNamespace);
 }
 
 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
                                          py::function pyFunc, bool replace) {
   py::object &found = attributeBuilderMap[attributeKind];
-  if (found && !found.is_none() && !replace) {
+  if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
                               attributeKind +
                               "' is already registered with func: " +
@@ -79,13 +79,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
                                    pybind11::function typeCaster,
                                    bool replace) {
   pybind11::object &found = typeCasterMap[mlirTypeID];
-  if (found && !found.is_none() && !replace)
-    throw std::runtime_error("Type caster is already registered");
+  if (found && !replace)
+    throw std::runtime_error("Type caster is already registered with caster: " +
+                             py::str(found).operator std::string());
   found = std::move(typeCaster);
-  const auto foundIt = typeCasterMapCache.find(mlirTypeID);
-  if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
-    typeCasterMapCache[mlirTypeID] = found;
-  }
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -108,86 +105,51 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
                                  .str());
   }
   found = std::move(pyClass);
-  auto foundIt = operationClassMapCache.find(operationName);
-  if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
-    operationClassMapCache[operationName] = found;
-  }
 }
 
 std::optional<py::function>
 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
-  // Fast match against the class map first (common case).
   const auto foundIt = attributeBuilderMap.find(attributeKind);
   if (foundIt != attributeBuilderMap.end()) {
-    if (foundIt->second.is_none())
-      return std::nullopt;
-    assert(foundIt->second && "py::function is defined");
+    assert(foundIt->second && "attribute builder is defined");
     return foundIt->second;
   }
-
-  // Not found and loading did not yield a registration. Negative cache.
-  attributeBuilderMap[attributeKind] = py::none();
   return std::nullopt;
 }
 
 std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
                                                         MlirDialect dialect) {
-  {
-    // Fast match against the class map first (common case).
-    const auto foundIt = typeCasterMapCache.find(mlirTypeID);
-    if (foundIt != typeCasterMapCache.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::function is defined");
-      return foundIt->second;
-    }
-  }
-
-  // Not found. Load the dialect namespace.
+  // Make sure dialect module is loaded.
   loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
-
-  // Attempt to find from the canonical map and cache.
-  {
-    const auto foundIt = typeCasterMap.find(mlirTypeID);
-    if (foundIt != typeCasterMap.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::object is defined");
-      // Positive cache.
-      typeCasterMapCache[mlirTypeID] = foundIt->second;
-      return foundIt->second;
-    }
-    // Negative cache.
-    typeCasterMap[mlirTypeID] = py::none();
-    return std::nullopt;
+  const auto foundIt = typeCasterMap.find(mlirTypeID);
+  if (foundIt != typeCasterMap.end()) {
+    assert(foundIt->second && "type caster is defined");
+    return foundIt->second;
   }
+  return std::nullopt;
 }
 
 std::optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
+  // Make sure dialect module is loaded.
   loadDialectModule(dialectNamespace);
-  // Fast match against the class map first (common case).
   const auto foundIt = dialectClassMap.find(dialectNamespace);
   if (foundIt != dialectClassMap.end()) {
-    if (foundIt->second.is_none())
-      return std::nullopt;
-    assert(foundIt->second && "py::object is defined");
+    assert(foundIt->second && "dialect class is defined");
     return foundIt->second;
   }
-
-  // Not found and loading did not yield a registration. Negative cache.
-  dialectClassMap[dialectNamespace] = py::none();
+  // Not found and loading did not yield a registration.
   return std::nullopt;
 }
 
 std::optional<pybind11::object>
 PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
+  // Fast match against the class map first (succeeds if second lookup, after
+  // successful dialect load).
   {
-    auto foundIt = operationClassMapCache.find(operationName);
-    if (foundIt != operationClassMapCache.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::object is defined");
+    auto foundIt = operationClassMap.find(operationName);
+    if (foundIt != operationClassMap.end()) {
+      assert(foundIt->second && "OpView is defined");
       return foundIt->second;
     }
   }
@@ -197,25 +159,15 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
   llvm::StringRef dialectNamespace = split.first;
   loadDialectModule(dialectNamespace);
 
-  // Attempt to find from the canonical map and cache.
+  // Try again to load from class map after successful dialect load.
   {
     auto foundIt = operationClassMap.find(operationName);
     if (foundIt != operationClassMap.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::object is defined");
-      // Positive cache.
-      operationClassMapCache[operationName] = foundIt->second;
+      assert(foundIt->second && "OpView is defined");
       return foundIt->second;
     }
-    // Negative cache.
-    operationClassMap[operationName] = py::none();
-    return std::nullopt;
   }
-}
 
-void PyGlobals::clearImportCache() {
-  loadedDialectModulesCache.clear();
-  operationClassMapCache.clear();
-  typeCasterMapCache.clear();
+  // Not found and loading did not yield a registration.
+  return std::nullopt;
 }
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index a936becf67bea75..2b6248321c1c110 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,14 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <tuple>
-
 #include "PybindUtils.h"
 
 #include "Globals.h"
 #include "IRModule.h"
 #include "Pass.h"
 
+#include <tuple>
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace py::literals;
@@ -34,7 +34,6 @@ PYBIND11_MODULE(_mlir, m) {
           "append_dialect_search_prefix",
           [](PyGlobals &self, std::string moduleName) {
             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
-            self.clearImportCache();
           },
           "module_name"_a)
       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,

Copy link
Member

@rkayaith rkayaith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, I never fully understood all those caches

@makslevental makslevental force-pushed the remove_pybind_caches branch 3 times, most recently from 50fbf23 to 384b222 Compare October 31, 2023 22:52
@llvmbot llvmbot added the mlir:python MLIR Python bindings label Oct 31, 2023
@makslevental makslevental force-pushed the remove_pybind_caches branch 2 times, most recently from 6371720 to 0c73805 Compare October 31, 2023 23:25
@makslevental makslevental force-pushed the remove_pybind_caches branch 2 times, most recently from d718bf7 to 07aeb5d Compare November 2, 2023 14:21
@makslevental makslevental changed the title [mlir][python] remove various caching mechanism [mlir][python] remove various caching mechanisms Nov 2, 2023
Copy link

github-actions bot commented Nov 2, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like agreement that this cache isn't doing much here. So LGTM to remove and could be revised if we see regression/get reports.

@makslevental makslevental merged commit 5192e29 into llvm:main Nov 3, 2023
@makslevental makslevental deleted the remove_pybind_caches branch November 3, 2023 18:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants