Skip to content

Commit 77eee57

Browse files
committed
[mlir] Refactor DialectRegistry delayed interface support into a general DialectExtension mechanism
The current dialect registry allows for attaching delayed interfaces, that are added to attrs/dialects/ops/etc. when the owning dialect gets loaded. This is clunky for quite a few reasons, e.g. each interface type has a separate tracking structure, and is also quite limiting. This commit refactors this delayed mutation of dialect constructs into a more general DialectExtension mechanism. This mechanism is essentially a registration callback that is invoked when a set of dialects have been loaded. This allows for attaching interfaces directly on the loaded constructs, and also allows for loading new dependent dialects. The latter of which is extremely useful as it will now enable dependent dialects to only apply in the contexts in which they are necessary. For example, a dialect dependency can now be conditional on if a user actually needs the interface that relies on it. Differential Revision: https://reviews.llvm.org/D120367
1 parent 8212b41 commit 77eee57

25 files changed

+481
-388
lines changed

mlir/include/mlir/IR/Dialect.h

Lines changed: 14 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_IR_DIALECT_H
1414
#define MLIR_IR_DIALECT_H
1515

16+
#include "mlir/IR/DialectRegistry.h"
1617
#include "mlir/IR/OperationSupport.h"
1718
#include "mlir/Support/TypeID.h"
1819

@@ -26,11 +27,9 @@ class DialectInterface;
2627
class OpBuilder;
2728
class Type;
2829

29-
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
30-
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
31-
using DialectInterfaceAllocatorFunction =
32-
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
33-
using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
30+
//===----------------------------------------------------------------------===//
31+
// Dialect
32+
//===----------------------------------------------------------------------===//
3433

3534
/// Dialects are groups of MLIR operations, types and attributes, as well as
3635
/// behavior associated with the entire group. For example, hooks into other
@@ -180,6 +179,16 @@ class Dialect {
180179
getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
181180
}
182181

182+
/// Register a dialect interface with this dialect instance.
183+
void addInterface(std::unique_ptr<DialectInterface> interface);
184+
185+
/// Register a set of dialect interfaces with this dialect instance.
186+
template <typename... Args>
187+
void addInterfaces() {
188+
(void)std::initializer_list<int>{
189+
0, (addInterface(std::make_unique<Args>(this)), 0)...};
190+
}
191+
183192
protected:
184193
/// The constructor takes a unique namespace for this dialect as well as the
185194
/// context to bind to.
@@ -218,15 +227,6 @@ class Dialect {
218227
/// Enable support for unregistered types.
219228
void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
220229

221-
/// Register a dialect interface with this dialect instance.
222-
void addInterface(std::unique_ptr<DialectInterface> interface);
223-
224-
/// Register a set of dialect interfaces with this dialect instance.
225-
template <typename... Args> void addInterfaces() {
226-
(void)std::initializer_list<int>{
227-
0, (addInterface(std::make_unique<Args>(this)), 0)...};
228-
}
229-
230230
private:
231231
Dialect(const Dialect &) = delete;
232232
void operator=(Dialect &) = delete;
@@ -274,168 +274,6 @@ class Dialect {
274274
friend class MLIRContext;
275275
};
276276

277-
/// The DialectRegistry maps a dialect namespace to a constructor for the
278-
/// matching dialect.
279-
/// This allows for decoupling the list of dialects "available" from the
280-
/// dialects loaded in the Context. The parser in particular will lazily load
281-
/// dialects in the Context as operations are encountered.
282-
class DialectRegistry {
283-
/// Lists of interfaces that need to be registered when the dialect is loaded.
284-
struct DelayedInterfaces {
285-
/// Dialect interfaces.
286-
SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
287-
dialectInterfaces;
288-
/// Attribute/Operation/Type interfaces.
289-
SmallVector<std::tuple<TypeID, TypeID, ObjectInterfaceAllocatorFunction>, 2>
290-
objectInterfaces;
291-
};
292-
293-
using MapTy =
294-
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
295-
using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
296-
297-
public:
298-
explicit DialectRegistry();
299-
300-
template <typename ConcreteDialect> void insert() {
301-
insert(TypeID::get<ConcreteDialect>(),
302-
ConcreteDialect::getDialectNamespace(),
303-
static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
304-
// Just allocate the dialect, the context
305-
// takes ownership of it.
306-
return ctx->getOrLoadDialect<ConcreteDialect>();
307-
})));
308-
}
309-
310-
template <typename ConcreteDialect, typename OtherDialect,
311-
typename... MoreDialects>
312-
void insert() {
313-
insert<ConcreteDialect>();
314-
insert<OtherDialect, MoreDialects...>();
315-
}
316-
317-
/// Add a new dialect constructor to the registry. The constructor must be
318-
/// calling MLIRContext::getOrLoadDialect in order for the context to take
319-
/// ownership of the dialect and for delayed interface registration to happen.
320-
void insert(TypeID typeID, StringRef name,
321-
const DialectAllocatorFunction &ctor);
322-
323-
/// Return an allocation function for constructing the dialect identified by
324-
/// its namespace, or nullptr if the namespace is not in this registry.
325-
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
326-
327-
// Register all dialects available in the current registry with the registry
328-
// in the provided context.
329-
void appendTo(DialectRegistry &destination) const {
330-
for (const auto &nameAndRegistrationIt : registry)
331-
destination.insert(nameAndRegistrationIt.second.first,
332-
nameAndRegistrationIt.first,
333-
nameAndRegistrationIt.second.second);
334-
// Merge interfaces.
335-
for (auto it : interfaces) {
336-
TypeID dialect = it.first;
337-
auto destInterfaces = destination.interfaces.find(dialect);
338-
if (destInterfaces == destination.interfaces.end()) {
339-
destination.interfaces[dialect] = it.second;
340-
continue;
341-
}
342-
// The destination already has delayed interface registrations for this
343-
// dialect. Merge registrations into the destination registry.
344-
destInterfaces->second.dialectInterfaces.append(
345-
it.second.dialectInterfaces.begin(),
346-
it.second.dialectInterfaces.end());
347-
destInterfaces->second.objectInterfaces.append(
348-
it.second.objectInterfaces.begin(), it.second.objectInterfaces.end());
349-
}
350-
}
351-
352-
/// Return the names of dialects known to this registry.
353-
auto getDialectNames() const {
354-
return llvm::map_range(
355-
registry,
356-
[](const MapTy::value_type &item) -> StringRef { return item.first; });
357-
}
358-
359-
/// Add an interface constructed with the given allocation function to the
360-
/// dialect provided as template parameter. The dialect must be present in
361-
/// the registry.
362-
template <typename DialectTy>
363-
void addDialectInterface(TypeID interfaceTypeID,
364-
DialectInterfaceAllocatorFunction allocator) {
365-
addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
366-
allocator);
367-
}
368-
369-
/// Add an interface to the dialect, both provided as template parameter. The
370-
/// dialect must be present in the registry.
371-
template <typename DialectTy, typename InterfaceTy>
372-
void addDialectInterface() {
373-
addDialectInterface<DialectTy>(
374-
InterfaceTy::getInterfaceID(), [](Dialect *dialect) {
375-
return std::make_unique<InterfaceTy>(dialect);
376-
});
377-
}
378-
379-
/// Add an external op interface model for an op that belongs to a dialect,
380-
/// both provided as template parameters. The dialect must be present in the
381-
/// registry.
382-
template <typename OpTy, typename ModelTy> void addOpInterface() {
383-
StringRef opName = OpTy::getOperationName();
384-
StringRef dialectName = opName.split('.').first;
385-
addObjectInterface(dialectName, TypeID::get<OpTy>(),
386-
ModelTy::Interface::getInterfaceID(),
387-
[](MLIRContext *context) {
388-
OpTy::template attachInterface<ModelTy>(*context);
389-
});
390-
}
391-
392-
/// Add an external attribute interface model for an attribute type `AttrTy`
393-
/// that is going to belong to `DialectTy`. The dialect must be present in the
394-
/// registry.
395-
template <typename DialectTy, typename AttrTy, typename ModelTy>
396-
void addAttrInterface() {
397-
addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
398-
}
399-
400-
/// Add an external type interface model for an type class `TypeTy` that is
401-
/// going to belong to `DialectTy`. The dialect must be present in the
402-
/// registry.
403-
template <typename DialectTy, typename TypeTy, typename ModelTy>
404-
void addTypeInterface() {
405-
addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
406-
}
407-
408-
/// Register any interfaces required for the given dialect (based on its
409-
/// TypeID). Users are not expected to call this directly.
410-
void registerDelayedInterfaces(Dialect *dialect) const;
411-
412-
private:
413-
/// Add an interface constructed with the given allocation function to the
414-
/// dialect identified by its namespace.
415-
void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
416-
const DialectInterfaceAllocatorFunction &allocator);
417-
418-
/// Add an attribute/operation/type interface constructible with the given
419-
/// allocation function to the dialect identified by its namespace.
420-
void addObjectInterface(StringRef dialectName, TypeID objectID,
421-
TypeID interfaceTypeID,
422-
const ObjectInterfaceAllocatorFunction &allocator);
423-
424-
/// Add an external model for an attribute/type interface to the dialect
425-
/// identified by its namespace.
426-
template <typename ObjectTy, typename ModelTy>
427-
void addStorageUserInterface(StringRef dialectName) {
428-
addObjectInterface(dialectName, TypeID::get<ObjectTy>(),
429-
ModelTy::Interface::getInterfaceID(),
430-
[](MLIRContext *context) {
431-
ObjectTy::template attachInterface<ModelTy>(*context);
432-
});
433-
}
434-
435-
MapTy registry;
436-
InterfaceMapTy interfaces;
437-
};
438-
439277
} // namespace mlir
440278

441279
namespace llvm {

0 commit comments

Comments
 (0)