@@ -187,7 +187,8 @@ class DialectRegistry {
187
187
nameAndRegistrationIt.second .second );
188
188
// Merge the extensions.
189
189
for (const auto &extension : extensions)
190
- destination.extensions .push_back (extension->clone ());
190
+ destination.extensions .emplace_back (extension.first ,
191
+ extension.second ->clone ());
191
192
}
192
193
193
194
// / Return the names of dialects known to this registry.
@@ -206,47 +207,56 @@ class DialectRegistry {
206
207
void applyExtensions (MLIRContext *ctx) const ;
207
208
208
209
// / Add the given extension to the registry.
209
- void addExtension (std::unique_ptr<DialectExtensionBase> extension) {
210
- extensions.push_back (std::move (extension));
210
+ void addExtension (StringRef extensionID,
211
+ std::unique_ptr<DialectExtensionBase> extension) {
212
+ extensions.emplace_back (extensionID, std::move (extension));
211
213
}
212
214
213
215
// / Add the given extensions to the registry.
214
216
template <typename ... ExtensionsT>
215
217
void addExtensions () {
216
- (addExtension (std::make_unique<ExtensionsT>()), ...);
218
+ (addExtension (ExtensionsT::extensionID, std::make_unique<ExtensionsT>()),
219
+ ...);
217
220
}
218
221
219
222
// / Add an extension function that requires the given dialects.
220
223
// / Note: This bare functor overload is provided in addition to the
221
224
// / std::function variant to enable dialect type deduction, e.g.:
222
- // / registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
225
+ // / registry.addExtension("ID", +[](MLIRContext *ctx, MyDialect *dialect) {
226
+ // / ... })
223
227
// /
224
228
// / is equivalent to:
225
229
// / registry.addExtension<MyDialect>(
230
+ // / "ID",
226
231
// / [](MLIRContext *ctx, MyDialect *dialect){ ... }
227
232
// / )
228
233
template <typename ... DialectsT>
229
- void addExtension (void (*extensionFn)(MLIRContext *, DialectsT *...)) {
234
+ void addExtension (StringRef extensionID,
235
+ void (*extensionFn)(MLIRContext *, DialectsT *...)) {
230
236
addExtension<DialectsT...>(
237
+ extensionID,
231
238
std::function<void (MLIRContext *, DialectsT * ...)>(extensionFn));
232
239
}
233
240
template <typename ... DialectsT>
234
241
void
235
- addExtension (std::function<void (MLIRContext *, DialectsT *...)> extensionFn) {
242
+ addExtension (StringRef extensionID,
243
+ std::function<void (MLIRContext *, DialectsT *...)> extensionFn) {
236
244
using ExtensionFnT = std::function<void (MLIRContext *, DialectsT * ...)>;
237
245
238
246
struct Extension : public DialectExtension <Extension, DialectsT...> {
239
247
Extension (const Extension &) = default ;
240
248
Extension (ExtensionFnT extensionFn)
241
- : extensionFn(std::move(extensionFn)) {}
249
+ : DialectExtension<Extension, DialectsT...>(),
250
+ extensionFn (std::move(extensionFn)) {}
242
251
~Extension () override = default ;
243
252
244
253
void apply (MLIRContext *context, DialectsT *...dialects) const final {
245
254
extensionFn (context, dialects...);
246
255
}
247
256
ExtensionFnT extensionFn;
248
257
};
249
- addExtension (std::make_unique<Extension>(std::move (extensionFn)));
258
+ addExtension (extensionID,
259
+ std::make_unique<Extension>(std::move (extensionFn)));
250
260
}
251
261
252
262
// / Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
@@ -255,7 +265,9 @@ class DialectRegistry {
255
265
256
266
private:
257
267
MapTy registry;
258
- std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
268
+ using KeyExtensionPair =
269
+ std::pair<llvm::StringRef, std::unique_ptr<DialectExtensionBase>>;
270
+ llvm::SmallVector<KeyExtensionPair> extensions;
259
271
};
260
272
261
273
} // namespace mlir
0 commit comments