Skip to content

Commit 5192e29

Browse files
authored
[mlir][python] remove various caching mechanisms (#70831)
This PR removes the various caching mechanisms currently in the python bindings - both positive caching and negative caching.
1 parent bcb685e commit 5192e29

File tree

10 files changed

+80
-117
lines changed

10 files changed

+80
-117
lines changed

mlir/docs/Bindings/Python.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -945,10 +945,11 @@ When the python bindings need to locate a wrapper module, they consult the
945945
`dialect_search_path` and use it to find an appropriately named module. For the
946946
main repository, this search path is hard-coded to include the `mlir.dialects`
947947
module, which is where wrappers are emitted by the above build rule. Out of tree
948-
dialects and add their modules to the search path by calling:
948+
dialects can add their modules to the search path by calling:
949949

950950
```python
951-
mlir._cext.append_dialect_search_prefix("myproject.mlir.dialects")
951+
from mlir.dialects._ods_common import _cext
952+
_cext.globals.append_dialect_search_prefix("myproject.mlir.dialects")
952953
```
953954

954955
### Wrapper module code organization

mlir/lib/Bindings/Python/Globals.h

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
1010
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
1111

12-
#include <optional>
13-
#include <string>
14-
#include <vector>
15-
1612
#include "PybindUtils.h"
1713

1814
#include "mlir-c/IR.h"
@@ -21,6 +17,10 @@
2117
#include "llvm/ADT/StringRef.h"
2218
#include "llvm/ADT/StringSet.h"
2319

20+
#include <optional>
21+
#include <string>
22+
#include <vector>
23+
2424
namespace mlir {
2525
namespace python {
2626

@@ -45,17 +45,13 @@ class PyGlobals {
4545
dialectSearchPrefixes.swap(newValues);
4646
}
4747

48-
/// Clears positive and negative caches regarding what implementations are
49-
/// available. Future lookups will do more expensive existence checks.
50-
void clearImportCache();
51-
5248
/// Loads a python module corresponding to the given dialect namespace.
5349
/// No-ops if the module has already been loaded or is not found. Raises
5450
/// an error on any evaluation issues.
5551
/// Note that this returns void because it is expected that the module
5652
/// contains calls to decorators and helpers that register the salient
57-
/// entities.
58-
void loadDialectModule(llvm::StringRef dialectNamespace);
53+
/// entities. Returns true if dialect is successfully loaded.
54+
bool loadDialectModule(llvm::StringRef dialectNamespace);
5955

6056
/// Adds a user-friendly Attribute builder.
6157
/// Raises an exception if the mapping already exists and replace == false.
@@ -113,16 +109,10 @@ class PyGlobals {
113109
llvm::StringMap<pybind11::object> attributeBuilderMap;
114110
/// Map of MlirTypeID to custom type caster.
115111
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
116-
/// Cache for map of MlirTypeID to custom type caster.
117-
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
118112

119113
/// Set of dialect namespaces that we have attempted to import implementation
120114
/// modules for.
121-
llvm::StringSet<> loadedDialectModulesCache;
122-
/// Cache of operation name to external operation class object. This is
123-
/// maintained on lookup as a shadow of operationClassMap in order for repeat
124-
/// lookups of the classes to only incur the cost of one hashtable lookup.
125-
llvm::StringMap<pybind11::object> operationClassMapCache;
115+
llvm::StringSet<> loadedDialectModules;
126116
};
127117

128118
} // namespace python

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 38 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
#include "Globals.h"
1111
#include "PybindUtils.h"
1212

13-
#include <optional>
14-
#include <vector>
15-
1613
#include "mlir-c/Bindings/Python/Interop.h"
1714
#include "mlir-c/Support.h"
1815

16+
#include <optional>
17+
#include <vector>
18+
1919
namespace py = pybind11;
2020
using namespace mlir;
2121
using namespace mlir::python;
@@ -36,12 +36,12 @@ PyGlobals::PyGlobals() {
3636

3737
PyGlobals::~PyGlobals() { instance = nullptr; }
3838

39-
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
40-
if (loadedDialectModulesCache.contains(dialectNamespace))
41-
return;
39+
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
40+
if (loadedDialectModules.contains(dialectNamespace))
41+
return true;
4242
// Since re-entrancy is possible, make a copy of the search prefixes.
4343
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
44-
py::object loaded;
44+
py::object loaded = py::none();
4545
for (std::string moduleName : localSearchPrefixes) {
4646
moduleName.push_back('.');
4747
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
@@ -57,15 +57,18 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
5757
break;
5858
}
5959

60+
if (loaded.is_none())
61+
return false;
6062
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
6163
// may have occurred, which may do anything.
62-
loadedDialectModulesCache.insert(dialectNamespace);
64+
loadedDialectModules.insert(dialectNamespace);
65+
return true;
6366
}
6467

6568
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
6669
py::function pyFunc, bool replace) {
6770
py::object &found = attributeBuilderMap[attributeKind];
68-
if (found && !found.is_none() && !replace) {
71+
if (found && !replace) {
6972
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
7073
attributeKind +
7174
"' is already registered with func: " +
@@ -79,13 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
7982
pybind11::function typeCaster,
8083
bool replace) {
8184
pybind11::object &found = typeCasterMap[mlirTypeID];
82-
if (found && !found.is_none() && !replace)
83-
throw std::runtime_error("Type caster is already registered");
85+
if (found && !replace)
86+
throw std::runtime_error("Type caster is already registered with caster: " +
87+
py::str(found).operator std::string());
8488
found = std::move(typeCaster);
85-
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
86-
if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
87-
typeCasterMapCache[mlirTypeID] = found;
88-
}
8989
}
9090

9191
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -108,114 +108,59 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
108108
.str());
109109
}
110110
found = std::move(pyClass);
111-
auto foundIt = operationClassMapCache.find(operationName);
112-
if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
113-
operationClassMapCache[operationName] = found;
114-
}
115111
}
116112

117113
std::optional<py::function>
118114
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
119-
// Fast match against the class map first (common case).
120115
const auto foundIt = attributeBuilderMap.find(attributeKind);
121116
if (foundIt != attributeBuilderMap.end()) {
122-
if (foundIt->second.is_none())
123-
return std::nullopt;
124-
assert(foundIt->second && "py::function is defined");
117+
assert(foundIt->second && "attribute builder is defined");
125118
return foundIt->second;
126119
}
127-
128-
// Not found and loading did not yield a registration. Negative cache.
129-
attributeBuilderMap[attributeKind] = py::none();
130120
return std::nullopt;
131121
}
132122

133123
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
134124
MlirDialect dialect) {
135-
{
136-
// Fast match against the class map first (common case).
137-
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
138-
if (foundIt != typeCasterMapCache.end()) {
139-
if (foundIt->second.is_none())
140-
return std::nullopt;
141-
assert(foundIt->second && "py::function is defined");
142-
return foundIt->second;
143-
}
144-
}
145-
146-
// Not found. Load the dialect namespace.
147-
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
148-
149-
// Attempt to find from the canonical map and cache.
150-
{
151-
const auto foundIt = typeCasterMap.find(mlirTypeID);
152-
if (foundIt != typeCasterMap.end()) {
153-
if (foundIt->second.is_none())
154-
return std::nullopt;
155-
assert(foundIt->second && "py::object is defined");
156-
// Positive cache.
157-
typeCasterMapCache[mlirTypeID] = foundIt->second;
158-
return foundIt->second;
159-
}
160-
// Negative cache.
161-
typeCasterMap[mlirTypeID] = py::none();
125+
// Make sure dialect module is loaded.
126+
if (!loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))))
162127
return std::nullopt;
128+
129+
const auto foundIt = typeCasterMap.find(mlirTypeID);
130+
if (foundIt != typeCasterMap.end()) {
131+
assert(foundIt->second && "type caster is defined");
132+
return foundIt->second;
163133
}
134+
return std::nullopt;
164135
}
165136

166137
std::optional<py::object>
167138
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
168-
loadDialectModule(dialectNamespace);
169-
// Fast match against the class map first (common case).
139+
// Make sure dialect module is loaded.
140+
if (!loadDialectModule(dialectNamespace))
141+
return std::nullopt;
170142
const auto foundIt = dialectClassMap.find(dialectNamespace);
171143
if (foundIt != dialectClassMap.end()) {
172-
if (foundIt->second.is_none())
173-
return std::nullopt;
174-
assert(foundIt->second && "py::object is defined");
144+
assert(foundIt->second && "dialect class is defined");
175145
return foundIt->second;
176146
}
177-
178-
// Not found and loading did not yield a registration. Negative cache.
179-
dialectClassMap[dialectNamespace] = py::none();
147+
// Not found and loading did not yield a registration.
180148
return std::nullopt;
181149
}
182150

183151
std::optional<pybind11::object>
184152
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
185-
{
186-
auto foundIt = operationClassMapCache.find(operationName);
187-
if (foundIt != operationClassMapCache.end()) {
188-
if (foundIt->second.is_none())
189-
return std::nullopt;
190-
assert(foundIt->second && "py::object is defined");
191-
return foundIt->second;
192-
}
193-
}
194-
195-
// Not found. Load the dialect namespace.
153+
// Make sure dialect module is loaded.
196154
auto split = operationName.split('.');
197155
llvm::StringRef dialectNamespace = split.first;
198-
loadDialectModule(dialectNamespace);
199-
200-
// Attempt to find from the canonical map and cache.
201-
{
202-
auto foundIt = operationClassMap.find(operationName);
203-
if (foundIt != operationClassMap.end()) {
204-
if (foundIt->second.is_none())
205-
return std::nullopt;
206-
assert(foundIt->second && "py::object is defined");
207-
// Positive cache.
208-
operationClassMapCache[operationName] = foundIt->second;
209-
return foundIt->second;
210-
}
211-
// Negative cache.
212-
operationClassMap[operationName] = py::none();
156+
if (!loadDialectModule(dialectNamespace))
213157
return std::nullopt;
214-
}
215-
}
216158

217-
void PyGlobals::clearImportCache() {
218-
loadedDialectModulesCache.clear();
219-
operationClassMapCache.clear();
220-
typeCasterMapCache.clear();
159+
auto foundIt = operationClassMap.find(operationName);
160+
if (foundIt != operationClassMap.end()) {
161+
assert(foundIt->second && "OpView is defined");
162+
return foundIt->second;
163+
}
164+
// Not found and loading did not yield a registration.
165+
return std::nullopt;
221166
}

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include <tuple>
10-
119
#include "PybindUtils.h"
1210

1311
#include "Globals.h"
1412
#include "IRModule.h"
1513
#include "Pass.h"
1614

15+
#include <tuple>
16+
1717
namespace py = pybind11;
1818
using namespace mlir;
1919
using namespace py::literals;
@@ -34,9 +34,14 @@ PYBIND11_MODULE(_mlir, m) {
3434
"append_dialect_search_prefix",
3535
[](PyGlobals &self, std::string moduleName) {
3636
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
37-
self.clearImportCache();
3837
},
3938
"module_name"_a)
39+
.def(
40+
"_check_dialect_module_loaded",
41+
[](PyGlobals &self, const std::string &dialectNamespace) {
42+
return self.loadDialectModule(dialectNamespace);
43+
},
44+
"dialect_namespace"_a)
4045
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
4146
"dialect_namespace"_a, "dialect_class"_a,
4247
"Testing hook for directly registering a dialect")

mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class _Globals:
77
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
88
def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
99
def append_dialect_search_prefix(self, module_name: str) -> None: ...
10+
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
1011

1112
def register_dialect(dialect_class: type) -> object: ...
1213
def register_operation(dialect_class: type) -> object: ...

mlir/test/python/ir/custom_dialect/__init__.py

Whitespace-only changes.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# The purpose of this empty dialect module is to enable successfully loading the "custom" dialect.
2+
# Without this file here (and a corresponding _cext.globals.append_dialect_search_prefix("custom_dialect")),
3+
# PyGlobals::loadDialectModule would search and fail to find the "custom" dialect for each Operation.create("custom.op")
4+
# (amongst other things).
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
config.excludes.add("__init__.py")
2+
config.excludes.add("custom.py")

mlir/test/python/ir/dialects.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4+
import sys
45
from mlir.ir import *
6+
from mlir.dialects._ods_common import _cext
57

68

79
def run(f):
@@ -104,3 +106,18 @@ def testIsRegisteredOperation():
104106
print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
105107
# CHECK: func.not_existing: False
106108
print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")
109+
110+
111+
# CHECK-LABEL: TEST: testAppendPrefixSearchPath
112+
@run
113+
def testAppendPrefixSearchPath():
114+
ctx = Context()
115+
ctx.allow_unregistered_dialects = True
116+
with Location.unknown(ctx):
117+
assert not _cext.globals._check_dialect_module_loaded("custom")
118+
Operation.create("custom.op")
119+
assert not _cext.globals._check_dialect_module_loaded("custom")
120+
121+
sys.path.append(".")
122+
_cext.globals.append_dialect_search_prefix("custom_dialect")
123+
assert _cext.globals._check_dialect_module_loaded("custom")

mlir/test/python/ir/insertion_point.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4-
import io
5-
import itertools
64
from mlir.ir import *
75

86

0 commit comments

Comments
 (0)