10
10
#include " Globals.h"
11
11
#include " PybindUtils.h"
12
12
13
- #include < optional>
14
- #include < vector>
15
-
16
13
#include " mlir-c/Bindings/Python/Interop.h"
17
14
#include " mlir-c/Support.h"
18
15
16
+ #include < optional>
17
+ #include < vector>
18
+
19
19
namespace py = pybind11;
20
20
using namespace mlir ;
21
21
using namespace mlir ::python;
@@ -36,12 +36,12 @@ PyGlobals::PyGlobals() {
36
36
37
37
PyGlobals::~PyGlobals () { instance = nullptr ; }
38
38
39
- void PyGlobals::loadDialectModule (llvm::StringRef dialectNamespace) {
40
- if (loadedDialectModulesCache .contains (dialectNamespace))
41
- return ;
39
+ bool PyGlobals::loadDialectModule (llvm::StringRef dialectNamespace) {
40
+ if (loadedDialectModules .contains (dialectNamespace))
41
+ return true ;
42
42
// Since re-entrancy is possible, make a copy of the search prefixes.
43
43
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
44
- py::object loaded;
44
+ py::object loaded = py::none () ;
45
45
for (std::string moduleName : localSearchPrefixes) {
46
46
moduleName.push_back (' .' );
47
47
moduleName.append (dialectNamespace.data (), dialectNamespace.size ());
@@ -57,15 +57,18 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
57
57
break ;
58
58
}
59
59
60
+ if (loaded.is_none ())
61
+ return false ;
60
62
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
61
63
// may have occurred, which may do anything.
62
- loadedDialectModulesCache.insert (dialectNamespace);
64
+ loadedDialectModules.insert (dialectNamespace);
65
+ return true ;
63
66
}
64
67
65
68
void PyGlobals::registerAttributeBuilder (const std::string &attributeKind,
66
69
py::function pyFunc, bool replace) {
67
70
py::object &found = attributeBuilderMap[attributeKind];
68
- if (found && !found. is_none () && ! replace) {
71
+ if (found && !replace) {
69
72
throw std::runtime_error ((llvm::Twine (" Attribute builder for '" ) +
70
73
attributeKind +
71
74
" ' is already registered with func: " +
@@ -79,13 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
79
82
pybind11::function typeCaster,
80
83
bool replace) {
81
84
pybind11::object &found = typeCasterMap[mlirTypeID];
82
- if (found && !found.is_none () && !replace)
83
- throw std::runtime_error (" Type caster is already registered" );
85
+ if (found && !replace)
86
+ throw std::runtime_error (" Type caster is already registered with caster: " +
87
+ py::str (found).operator std::string ());
84
88
found = std::move (typeCaster);
85
- const auto foundIt = typeCasterMapCache.find (mlirTypeID);
86
- if (foundIt != typeCasterMapCache.end () && !foundIt->second .is_none ()) {
87
- typeCasterMapCache[mlirTypeID] = found;
88
- }
89
89
}
90
90
91
91
void PyGlobals::registerDialectImpl (const std::string &dialectNamespace,
@@ -108,114 +108,59 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
108
108
.str ());
109
109
}
110
110
found = std::move (pyClass);
111
- auto foundIt = operationClassMapCache.find (operationName);
112
- if (foundIt != operationClassMapCache.end () && !foundIt->second .is_none ()) {
113
- operationClassMapCache[operationName] = found;
114
- }
115
111
}
116
112
117
113
std::optional<py::function>
118
114
PyGlobals::lookupAttributeBuilder (const std::string &attributeKind) {
119
- // Fast match against the class map first (common case).
120
115
const auto foundIt = attributeBuilderMap.find (attributeKind);
121
116
if (foundIt != attributeBuilderMap.end ()) {
122
- if (foundIt->second .is_none ())
123
- return std::nullopt;
124
- assert (foundIt->second && " py::function is defined" );
117
+ assert (foundIt->second && " attribute builder is defined" );
125
118
return foundIt->second ;
126
119
}
127
-
128
- // Not found and loading did not yield a registration. Negative cache.
129
- attributeBuilderMap[attributeKind] = py::none ();
130
120
return std::nullopt;
131
121
}
132
122
133
123
std::optional<py::function> PyGlobals::lookupTypeCaster (MlirTypeID mlirTypeID,
134
124
MlirDialect dialect) {
135
- {
136
- // Fast match against the class map first (common case).
137
- const auto foundIt = typeCasterMapCache.find (mlirTypeID);
138
- if (foundIt != typeCasterMapCache.end ()) {
139
- if (foundIt->second .is_none ())
140
- return std::nullopt;
141
- assert (foundIt->second && " py::function is defined" );
142
- return foundIt->second ;
143
- }
144
- }
145
-
146
- // Not found. Load the dialect namespace.
147
- loadDialectModule (unwrap (mlirDialectGetNamespace (dialect)));
148
-
149
- // Attempt to find from the canonical map and cache.
150
- {
151
- const auto foundIt = typeCasterMap.find (mlirTypeID);
152
- if (foundIt != typeCasterMap.end ()) {
153
- if (foundIt->second .is_none ())
154
- return std::nullopt;
155
- assert (foundIt->second && " py::object is defined" );
156
- // Positive cache.
157
- typeCasterMapCache[mlirTypeID] = foundIt->second ;
158
- return foundIt->second ;
159
- }
160
- // Negative cache.
161
- typeCasterMap[mlirTypeID] = py::none ();
125
+ // Make sure dialect module is loaded.
126
+ if (!loadDialectModule (unwrap (mlirDialectGetNamespace (dialect))))
162
127
return std::nullopt;
128
+
129
+ const auto foundIt = typeCasterMap.find (mlirTypeID);
130
+ if (foundIt != typeCasterMap.end ()) {
131
+ assert (foundIt->second && " type caster is defined" );
132
+ return foundIt->second ;
163
133
}
134
+ return std::nullopt;
164
135
}
165
136
166
137
std::optional<py::object>
167
138
PyGlobals::lookupDialectClass (const std::string &dialectNamespace) {
168
- loadDialectModule (dialectNamespace);
169
- // Fast match against the class map first (common case).
139
+ // Make sure dialect module is loaded.
140
+ if (!loadDialectModule (dialectNamespace))
141
+ return std::nullopt;
170
142
const auto foundIt = dialectClassMap.find (dialectNamespace);
171
143
if (foundIt != dialectClassMap.end ()) {
172
- if (foundIt->second .is_none ())
173
- return std::nullopt;
174
- assert (foundIt->second && " py::object is defined" );
144
+ assert (foundIt->second && " dialect class is defined" );
175
145
return foundIt->second ;
176
146
}
177
-
178
- // Not found and loading did not yield a registration. Negative cache.
179
- dialectClassMap[dialectNamespace] = py::none ();
147
+ // Not found and loading did not yield a registration.
180
148
return std::nullopt;
181
149
}
182
150
183
151
std::optional<pybind11::object>
184
152
PyGlobals::lookupOperationClass (llvm::StringRef operationName) {
185
- {
186
- auto foundIt = operationClassMapCache.find (operationName);
187
- if (foundIt != operationClassMapCache.end ()) {
188
- if (foundIt->second .is_none ())
189
- return std::nullopt;
190
- assert (foundIt->second && " py::object is defined" );
191
- return foundIt->second ;
192
- }
193
- }
194
-
195
- // Not found. Load the dialect namespace.
153
+ // Make sure dialect module is loaded.
196
154
auto split = operationName.split (' .' );
197
155
llvm::StringRef dialectNamespace = split.first ;
198
- loadDialectModule (dialectNamespace);
199
-
200
- // Attempt to find from the canonical map and cache.
201
- {
202
- auto foundIt = operationClassMap.find (operationName);
203
- if (foundIt != operationClassMap.end ()) {
204
- if (foundIt->second .is_none ())
205
- return std::nullopt;
206
- assert (foundIt->second && " py::object is defined" );
207
- // Positive cache.
208
- operationClassMapCache[operationName] = foundIt->second ;
209
- return foundIt->second ;
210
- }
211
- // Negative cache.
212
- operationClassMap[operationName] = py::none ();
156
+ if (!loadDialectModule (dialectNamespace))
213
157
return std::nullopt;
214
- }
215
- }
216
158
217
- void PyGlobals::clearImportCache () {
218
- loadedDialectModulesCache.clear ();
219
- operationClassMapCache.clear ();
220
- typeCasterMapCache.clear ();
159
+ auto foundIt = operationClassMap.find (operationName);
160
+ if (foundIt != operationClassMap.end ()) {
161
+ assert (foundIt->second && " OpView is defined" );
162
+ return foundIt->second ;
163
+ }
164
+ // Not found and loading did not yield a registration.
165
+ return std::nullopt;
221
166
}
0 commit comments