Skip to content

[mlir][python] remove various caching mechanisms #70831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/docs/Bindings/Python.md
Original file line number Diff line number Diff line change
Expand Up @@ -945,10 +945,11 @@ When the python bindings need to locate a wrapper module, they consult the
`dialect_search_path` and use it to find an appropriately named module. For the
main repository, this search path is hard-coded to include the `mlir.dialects`
module, which is where wrappers are emitted by the above build rule. Out of tree
dialects and add their modules to the search path by calling:
dialects can add their modules to the search path by calling:

```python
mlir._cext.append_dialect_search_prefix("myproject.mlir.dialects")
from mlir.dialects._ods_common import _cext
_cext.globals.append_dialect_search_prefix("myproject.mlir.dialects")
```

### Wrapper module code organization
Expand Down
24 changes: 7 additions & 17 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H

#include <optional>
#include <string>
#include <vector>

#include "PybindUtils.h"

#include "mlir-c/IR.h"
Expand All @@ -21,6 +17,10 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

#include <optional>
#include <string>
#include <vector>

namespace mlir {
namespace python {

Expand All @@ -45,17 +45,13 @@ class PyGlobals {
dialectSearchPrefixes.swap(newValues);
}

/// Clears positive and negative caches regarding what implementations are
/// available. Future lookups will do more expensive existence checks.
void clearImportCache();

/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
/// an error on any evaluation issues.
/// Note that the module is not returned: it is expected that the module
/// contains calls to decorators and helpers that register the salient
/// entities.
void loadDialectModule(llvm::StringRef dialectNamespace);
/// entities. Returns true if the dialect module is loaded (or was already
/// loaded) and false if no module was found.
bool loadDialectModule(llvm::StringRef dialectNamespace);

/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists and replace == false.
Expand Down Expand Up @@ -113,16 +109,10 @@ class PyGlobals {
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
/// Cache for map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;

/// Set of dialect namespaces whose implementation modules have been
/// successfully imported.
llvm::StringSet<> loadedDialectModulesCache;
/// Cache of operation name to external operation class object. This is
/// maintained on lookup as a shadow of operationClassMap in order for repeat
/// lookups of the classes to only incur the cost of one hashtable lookup.
llvm::StringMap<pybind11::object> operationClassMapCache;
llvm::StringSet<> loadedDialectModules;
};

} // namespace python
Expand Down
131 changes: 38 additions & 93 deletions mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
#include "Globals.h"
#include "PybindUtils.h"

#include <optional>
#include <vector>

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"

#include <optional>
#include <vector>

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
Expand All @@ -36,12 +36,12 @@ PyGlobals::PyGlobals() {

PyGlobals::~PyGlobals() { instance = nullptr; }

void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
if (loadedDialectModulesCache.contains(dialectNamespace))
return;
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
if (loadedDialectModules.contains(dialectNamespace))
return true;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
py::object loaded;
py::object loaded = py::none();
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
Expand All @@ -57,15 +57,18 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
break;
}

if (loaded.is_none())
return false;
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
loadedDialectModulesCache.insert(dialectNamespace);
loadedDialectModules.insert(dialectNamespace);
return true;
}

void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
py::function pyFunc, bool replace) {
py::object &found = attributeBuilderMap[attributeKind];
if (found && !found.is_none() && !replace) {
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind +
"' is already registered with func: " +
Expand All @@ -79,13 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
pybind11::function typeCaster,
bool replace) {
pybind11::object &found = typeCasterMap[mlirTypeID];
if (found && !found.is_none() && !replace)
throw std::runtime_error("Type caster is already registered");
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
py::str(found).operator std::string());
found = std::move(typeCaster);
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
typeCasterMapCache[mlirTypeID] = found;
}
}

void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
Expand All @@ -108,114 +108,59 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
.str());
}
found = std::move(pyClass);
auto foundIt = operationClassMapCache.find(operationName);
if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
operationClassMapCache[operationName] = found;
}
}

std::optional<py::function>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
// Fast match against the class map first (common case).
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::function is defined");
assert(foundIt->second && "attribute builder is defined");
return foundIt->second;
}

// Not found and loading did not yield a registration. Negative cache.
attributeBuilderMap[attributeKind] = py::none();
return std::nullopt;
}

std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
{
// Fast match against the class map first (common case).
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
if (foundIt != typeCasterMapCache.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::function is defined");
return foundIt->second;
}
}

// Not found. Load the dialect namespace.
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));

// Attempt to find from the canonical map and cache.
{
const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::object is defined");
// Positive cache.
typeCasterMapCache[mlirTypeID] = foundIt->second;
return foundIt->second;
}
// Negative cache.
typeCasterMap[mlirTypeID] = py::none();
// Make sure dialect module is loaded.
if (!loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))))
return std::nullopt;

const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
assert(foundIt->second && "type caster is defined");
return foundIt->second;
}
return std::nullopt;
}

std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);
// Fast match against the class map first (common case).
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::object is defined");
assert(foundIt->second && "dialect class is defined");
return foundIt->second;
}

// Not found and loading did not yield a registration. Negative cache.
dialectClassMap[dialectNamespace] = py::none();
// Not found and loading did not yield a registration.
return std::nullopt;
}

std::optional<pybind11::object>
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
{
auto foundIt = operationClassMapCache.find(operationName);
if (foundIt != operationClassMapCache.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::object is defined");
return foundIt->second;
}
}

// Not found. Load the dialect namespace.
// Make sure dialect module is loaded.
auto split = operationName.split('.');
llvm::StringRef dialectNamespace = split.first;
loadDialectModule(dialectNamespace);

// Attempt to find from the canonical map and cache.
{
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::object is defined");
// Positive cache.
operationClassMapCache[operationName] = foundIt->second;
return foundIt->second;
}
// Negative cache.
operationClassMap[operationName] = py::none();
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
}
}

void PyGlobals::clearImportCache() {
loadedDialectModulesCache.clear();
operationClassMapCache.clear();
typeCasterMapCache.clear();
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
assert(foundIt->second && "OpView is defined");
return foundIt->second;
}
// Not found and loading did not yield a registration.
return std::nullopt;
}
11 changes: 8 additions & 3 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//

#include <tuple>

#include "PybindUtils.h"

#include "Globals.h"
#include "IRModule.h"
#include "Pass.h"

#include <tuple>

namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
Expand All @@ -34,9 +34,14 @@ PYBIND11_MODULE(_mlir, m) {
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
self.clearImportCache();
},
"module_name"_a)
.def(
"_check_dialect_module_loaded",
[](PyGlobals &self, const std::string &dialectNamespace) {
return self.loadDialectModule(dialectNamespace);
},
"dialect_namespace"_a)
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class _Globals:
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
def append_dialect_search_prefix(self, module_name: str) -> None: ...
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...

def register_dialect(dialect_class: type) -> object: ...
def register_operation(dialect_class: type) -> object: ...
Empty file.
4 changes: 4 additions & 0 deletions mlir/test/python/ir/custom_dialect/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# The purpose of this empty dialect module is to enable successfully loading the "custom" dialect.
# Without this file here (and a corresponding _cext.globals.append_dialect_search_prefix("custom_dialect")),
# PyGlobals::loadDialectModule would search and fail to find the "custom" dialect for each Operation.create("custom.op")
# (amongst other things).
2 changes: 2 additions & 0 deletions mlir/test/python/ir/custom_dialect/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
config.excludes.add("__init__.py")
config.excludes.add("custom.py")
17 changes: 17 additions & 0 deletions mlir/test/python/ir/dialects.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import sys
from mlir.ir import *
from mlir.dialects._ods_common import _cext


def run(f):
Expand Down Expand Up @@ -104,3 +106,18 @@ def testIsRegisteredOperation():
print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
# CHECK: func.not_existing: False
print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")


# CHECK-LABEL: TEST: testAppendPrefixSearchPath
@run
def testAppendPrefixSearchPath():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
assert not _cext.globals._check_dialect_module_loaded("custom")
Operation.create("custom.op")
assert not _cext.globals._check_dialect_module_loaded("custom")

sys.path.append(".")
_cext.globals.append_dialect_search_prefix("custom_dialect")
assert _cext.globals._check_dialect_module_loaded("custom")
2 changes: 0 additions & 2 deletions mlir/test/python/ir/insertion_point.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


Expand Down