Skip to content

Commit c653279

Browse files
committed
DialectRegistry: add extension ID
1 parent cfb92be commit c653279

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

mlir/include/mlir/IR/DialectRegistry.h

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ class DialectRegistry {
187187
nameAndRegistrationIt.second.second);
188188
// Merge the extensions.
189189
for (const auto &extension : extensions)
190-
destination.extensions.push_back(extension->clone());
190+
destination.extensions.emplace_back(extension.first,
191+
extension.second->clone());
191192
}
192193

193194
/// Return the names of dialects known to this registry.
@@ -206,47 +207,56 @@ class DialectRegistry {
206207
void applyExtensions(MLIRContext *ctx) const;
207208

208209
/// Add the given extension to the registry.
209-
void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
210-
extensions.push_back(std::move(extension));
210+
void addExtension(StringRef extensionID,
211+
std::unique_ptr<DialectExtensionBase> extension) {
212+
extensions.emplace_back(extensionID, std::move(extension));
211213
}
212214

213215
/// Add the given extensions to the registry.
214216
template <typename... ExtensionsT>
215217
void addExtensions() {
216-
(addExtension(std::make_unique<ExtensionsT>()), ...);
218+
(addExtension(ExtensionsT::extensionID, std::make_unique<ExtensionsT>()),
219+
...);
217220
}
218221

219222
/// Add an extension function that requires the given dialects.
220223
/// Note: This bare functor overload is provided in addition to the
221224
/// std::function variant to enable dialect type deduction, e.g.:
222-
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
225+
/// registry.addExtension("ID", +[](MLIRContext *ctx, MyDialect *dialect) {
226+
/// ... })
223227
///
224228
/// is equivalent to:
225229
/// registry.addExtension<MyDialect>(
230+
/// "ID",
226231
/// [](MLIRContext *ctx, MyDialect *dialect){ ... }
227232
/// )
228233
template <typename... DialectsT>
229-
void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
234+
void addExtension(StringRef extensionID,
235+
void (*extensionFn)(MLIRContext *, DialectsT *...)) {
230236
addExtension<DialectsT...>(
237+
extensionID,
231238
std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
232239
}
233240
template <typename... DialectsT>
234241
void
235-
addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
242+
addExtension(StringRef extensionID,
243+
std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
236244
using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
237245

238246
struct Extension : public DialectExtension<Extension, DialectsT...> {
239247
Extension(const Extension &) = default;
240248
Extension(ExtensionFnT extensionFn)
241-
: extensionFn(std::move(extensionFn)) {}
249+
: DialectExtension<Extension, DialectsT...>(),
250+
extensionFn(std::move(extensionFn)) {}
242251
~Extension() override = default;
243252

244253
void apply(MLIRContext *context, DialectsT *...dialects) const final {
245254
extensionFn(context, dialects...);
246255
}
247256
ExtensionFnT extensionFn;
248257
};
249-
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
258+
addExtension(extensionID,
259+
std::make_unique<Extension>(std::move(extensionFn)));
250260
}
251261

252262
/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
@@ -255,7 +265,9 @@ class DialectRegistry {
255265

256266
private:
257267
MapTy registry;
258-
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
268+
using KeyExtensionPair =
269+
std::pair<llvm::StringRef, std::unique_ptr<DialectExtensionBase>>;
270+
llvm::SmallVector<KeyExtensionPair> extensions;
259271
};
260272

261273
} // namespace mlir

mlir/lib/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
260260

261261
// Note: Additional extensions may be added while applying an extension.
262262
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
263-
applyExtension(*extensions[i]);
263+
applyExtension(*extensions[i].second);
264264
}
265265

266266
void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -287,7 +287,7 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
287287

288288
// Note: Additional extensions may be added while applying an extension.
289289
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
290-
applyExtension(*extensions[i]);
290+
applyExtension(*extensions[i].second);
291291
}
292292

293293
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {

0 commit comments

Comments
 (0)