Skip to content

Commit d0e6fd9

Browse files
committed
[mlir] Extend the promise interface mechanism
This patch pairs a promised interface with the object (Op/Attr/Type/Dialect) requesting the promise, ie: ``` declarePromisedInterface<MyAttr, MyInterface>(); ``` Allowing to make fine grained promises. It also adds a mechanism to query if `Op/Attr/Type` has an specific promise returning true if the promise is there or if an implementation has been added. Finally it adds a couple of `Attr|TypeConstraints` that can be used in ODS to query if the promise or an implementation is there. This patch tries to solve 2 issues: 1. Different entities cannot use the same promise. ``` declarePromisedInterface<MyInterface>(); // Resolves a promise. MyAttr1::attachInterface<MyInterface>(ctx); // Doesn't resolves a promise, as the previous attachment removed the promise. MyAttr2::attachInterface<MyInterface>(ctx); ``` 2. Is not possible to query if a promise has been declared. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D158464
1 parent 5857fe0 commit d0e6fd9

File tree

15 files changed

+168
-26
lines changed

15 files changed

+168
-26
lines changed

mlir/include/mlir/IR/Attributes.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ class Attribute {
8787

8888
friend ::llvm::hash_code hash_value(Attribute arg);
8989

90+
/// Returns true if `InterfaceT` has been promised by the dialect or
91+
/// implemented.
92+
template <typename InterfaceT>
93+
bool hasPromiseOrImplementsInterface() {
94+
return dialect_extension_detail::hasPromisedInterface(
95+
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
96+
mlir::isa<InterfaceT>(*this);
97+
}
98+
9099
/// Returns true if the type was registered with a particular trait.
91100
template <template <typename T> class Trait>
92101
bool hasTrait() {
@@ -289,7 +298,7 @@ class AttributeInterface
289298
// Check that the current interface isn't an unresolved promise for the
290299
// given attribute.
291300
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
292-
attr.getDialect(), ConcreteType::getInterfaceID(),
301+
attr.getDialect(), attr.getTypeID(), ConcreteType::getInterfaceID(),
293302
llvm::getTypeName<ConcreteType>());
294303
#endif
295304

mlir/include/mlir/IR/Dialect.h

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class Dialect {
160160
/// nullptr.
161161
DialectInterface *getRegisteredInterface(TypeID interfaceID) {
162162
#ifndef NDEBUG
163-
handleUseOfUndefinedPromisedInterface(interfaceID);
163+
handleUseOfUndefinedPromisedInterface(getTypeID(), interfaceID);
164164
#endif
165165

166166
auto it = registeredInterfaces.find(interfaceID);
@@ -169,7 +169,8 @@ class Dialect {
169169
template <typename InterfaceT>
170170
InterfaceT *getRegisteredInterface() {
171171
#ifndef NDEBUG
172-
handleUseOfUndefinedPromisedInterface(InterfaceT::getInterfaceID(),
172+
handleUseOfUndefinedPromisedInterface(getTypeID(),
173+
InterfaceT::getInterfaceID(),
173174
llvm::getTypeName<InterfaceT>());
174175
#endif
175176

@@ -209,18 +210,21 @@ class Dialect {
209210
/// registration. The promised interface type can be an interface of any type
210211
/// not just a dialect interface, i.e. it may also be an
211212
/// AttributeInterface/OpInterface/TypeInterface/etc.
212-
template <typename InterfaceT>
213+
template <typename ConcreteT, typename InterfaceT>
213214
void declarePromisedInterface() {
214-
unresolvedPromisedInterfaces.insert(InterfaceT::getInterfaceID());
215+
unresolvedPromisedInterfaces.insert(
216+
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
215217
}
216218

217219
/// Checks if the given interface, which is attempting to be used, is a
218220
/// promised interface of this dialect that has yet to be implemented. If so,
219221
/// emits a fatal error. `interfaceName` is an optional string that contains a
220222
/// more user readable name for the interface (such as the class name).
221-
void handleUseOfUndefinedPromisedInterface(TypeID interfaceID,
223+
void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
224+
TypeID interfaceID,
222225
StringRef interfaceName = "") {
223-
if (unresolvedPromisedInterfaces.count(interfaceID)) {
226+
if (unresolvedPromisedInterfaces.count(
227+
{interfaceRequestorID, interfaceID})) {
224228
llvm::report_fatal_error(
225229
"checking for an interface (`" + interfaceName +
226230
"`) that was promised by dialect '" + getNamespace() +
@@ -229,11 +233,27 @@ class Dialect {
229233
"registered.");
230234
}
231235
}
236+
232237
/// Checks if the given interface, which is attempting to be attached to a
233238
/// construct owned by this dialect, is a promised interface of this dialect
234239
/// that has yet to be implemented. If so, it resolves the interface promise.
235-
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceID) {
236-
unresolvedPromisedInterfaces.erase(interfaceID);
240+
void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
241+
TypeID interfaceID) {
242+
unresolvedPromisedInterfaces.erase({interfaceRequestorID, interfaceID});
243+
}
244+
245+
/// Checks if a promise has been made for the interface/requestor pair.
246+
bool hasPromisedInterface(TypeID interfaceRequestorID,
247+
TypeID interfaceID) const {
248+
return unresolvedPromisedInterfaces.count(
249+
{interfaceRequestorID, interfaceID});
250+
}
251+
252+
/// Checks if a promise has been made for the interface/requestor pair.
253+
template <typename ConcreteT, typename InterfaceT>
254+
bool hasPromisedInterface() const {
255+
return hasPromisedInterface(TypeID::get<ConcreteT>(),
256+
InterfaceT::getInterfaceID());
237257
}
238258

239259
protected:
@@ -332,7 +352,7 @@ class Dialect {
332352
/// A set of interfaces that the dialect (or its constructs, i.e.
333353
/// Attributes/Operations/Types/etc.) has promised to implement, but has yet
334354
/// to provide an implementation for.
335-
DenseSet<TypeID> unresolvedPromisedInterfaces;
355+
DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
336356

337357
friend class DialectRegistry;
338358
friend void registerDialect();

mlir/include/mlir/IR/DialectRegistry.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,29 @@ namespace dialect_extension_detail {
102102
/// Checks if the given interface, which is attempting to be used, is a
103103
/// promised interface of this dialect that has yet to be implemented. If so,
104104
/// emits a fatal error.
105-
void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceID,
105+
void handleUseOfUndefinedPromisedInterface(Dialect &dialect,
106+
TypeID interfaceRequestorID,
107+
TypeID interfaceID,
106108
StringRef interfaceName);
107109

108110
/// Checks if the given interface, which is attempting to be attached, is a
109111
/// promised interface of this dialect that has yet to be implemented. If so,
110112
/// the promised interface is marked as resolved.
111113
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
114+
TypeID interfaceRequestorID,
112115
TypeID interfaceID);
113116

117+
/// Checks if a promise has been made for the interface/requestor pair.
118+
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119+
TypeID interfaceID);
120+
121+
/// Checks if a promise has been made for the interface/requestor pair.
122+
template <typename ConcreteT, typename InterfaceT>
123+
bool hasPromisedInterface(Dialect &dialect) {
124+
return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
125+
InterfaceT::getInterfaceID());
126+
}
127+
114128
} // namespace dialect_extension_detail
115129

116130
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/OpBase.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,30 @@ class Results<dag rets> {
450450
dag results = rets;
451451
}
452452

453+
//===----------------------------------------------------------------------===//
454+
// Common promised interface constraints
455+
//===----------------------------------------------------------------------===//
456+
457+
// This constrait represents a promise or an implementation of an attr interface.
458+
class PromisedAttrInterface<AttrInterface interface> : AttrConstraint<
459+
CPred<"$_self.hasPromiseOrImplementsInterface<" #
460+
!if(!empty(interface.cppNamespace),
461+
"",
462+
interface.cppNamespace # "::") # interface.cppInterfaceName #">()">,
463+
"promising or implementing the `" # interface.cppInterfaceName # "` attr interface">;
464+
465+
// This predicate checks if the type promises or implementats a type interface.
466+
class HasPromiseOrImplementsTypeInterface<TypeInterface interface> :
467+
CPred<"$_self.hasPromiseOrImplementsInterface<" #
468+
!if(!empty(interface.cppNamespace),
469+
"",
470+
interface.cppNamespace # "::") # interface.cppInterfaceName #">()">;
471+
472+
// This constrait represents a promise or an implementation of a type interface.
473+
class PromisedTypeInterface<TypeInterface interface> : TypeConstraint<
474+
HasPromiseOrImplementsTypeInterface<interface>,
475+
"promising or implementing the `" # interface.cppInterfaceName # "` type interface">;
476+
453477
//===----------------------------------------------------------------------===//
454478
// Common op type constraints
455479
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2075,7 +2075,7 @@ class OpInterface
20752075
// given operation.
20762076
if (Dialect *dialect = name.getDialect()) {
20772077
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
2078-
*dialect, ConcreteType::getInterfaceID(),
2078+
*dialect, name.getTypeID(), ConcreteType::getInterfaceID(),
20792079
llvm::getTypeName<ConcreteType>());
20802080
}
20812081
#endif

mlir/include/mlir/IR/Operation.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,13 @@ class alignas(8) Operation final
698698
/// If folding was unsuccessful, this function returns "failure".
699699
LogicalResult fold(SmallVectorImpl<OpFoldResult> &results);
700700

701+
/// Returns true if `InterfaceT` has been promised by the dialect or
702+
/// implemented.
703+
template <typename InterfaceT>
704+
bool hasPromiseOrImplementsInterface() const {
705+
return name.hasPromiseOrImplementsInterface<InterfaceT>();
706+
}
707+
701708
/// Returns true if the operation was registered with a particular trait, e.g.
702709
/// hasTrait<OperandsAreSignlessIntegerLike>().
703710
template <template <typename T> class Trait>

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,12 +351,21 @@ class OperationName {
351351
void attachInterface() {
352352
// Handle the case where the models resolve a promised interface.
353353
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
354-
*getDialect(), Models::Interface::getInterfaceID()),
354+
*getDialect(), getTypeID(), Models::Interface::getInterfaceID()),
355355
...);
356356

357357
getImpl()->getInterfaceMap().insertModels<Models...>();
358358
}
359359

360+
/// Returns true if `InterfaceT` has been promised by the dialect or
361+
/// implemented.
362+
template <typename InterfaceT>
363+
bool hasPromiseOrImplementsInterface() const {
364+
return dialect_extension_detail::hasPromisedInterface(
365+
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
366+
hasInterface<InterfaceT>();
367+
}
368+
360369
/// Returns true if this operation has the given interface registered to it.
361370
template <typename T>
362371
bool hasInterface() const {

mlir/include/mlir/IR/StorageUniquerSupport.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
163163

164164
// Handle the case where the models resolve a promised interface.
165165
(dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
166-
abstract->getDialect(), IfaceModels::Interface::getInterfaceID()),
166+
abstract->getDialect(), abstract->getTypeID(),
167+
IfaceModels::Interface::getInterfaceID()),
167168
...);
168169

169170
(checkInterfaceTarget<IfaceModels>(), ...);

mlir/include/mlir/IR/Types.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,15 @@ class Type {
180180
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
181181
}
182182

183+
/// Returns true if `InterfaceT` has been promised by the dialect or
184+
/// implemented.
185+
template <typename InterfaceT>
186+
bool hasPromiseOrImplementsInterface() {
187+
return dialect_extension_detail::hasPromisedInterface(
188+
getDialect(), getTypeID(), InterfaceT::getInterfaceID()) ||
189+
mlir::isa<InterfaceT>(*this);
190+
}
191+
183192
/// Returns true if the type was registered with a particular trait.
184193
template <template <typename T> class Trait>
185194
bool hasTrait() {
@@ -274,7 +283,7 @@ class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
274283
// Check that the current interface isn't an unresolved promise for the
275284
// given type.
276285
dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
277-
type.getDialect(), ConcreteType::getInterfaceID(),
286+
type.getDialect(), type.getTypeID(), ConcreteType::getInterfaceID(),
278287
llvm::getTypeName<ConcreteType>());
279288
#endif
280289

mlir/lib/Dialect/Func/IR/FuncOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void FuncDialect::initialize() {
4040
#define GET_OP_LIST
4141
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
4242
>();
43-
declarePromisedInterface<DialectInlinerInterface>();
43+
declarePromisedInterface<FuncDialect, DialectInlinerInterface>();
4444
}
4545

4646
/// Materialize a single constant operation from a given attribute value with

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -994,8 +994,8 @@ void NVVMDialect::initialize() {
994994
// Support unknown operations because not all NVVM operations are
995995
// registered.
996996
allowUnknownOperations();
997-
declarePromisedInterface<ConvertToLLVMPatternInterface>();
998-
declarePromisedInterface<gpu::TargetAttrInterface>();
997+
declarePromisedInterface<NVVMDialect, ConvertToLLVMPatternInterface>();
998+
declarePromisedInterface<NVVMTargetAttr, gpu::TargetAttrInterface>();
999999
}
10001000

10011001
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,

mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ void ROCDLDialect::initialize() {
247247

248248
// Support unknown operations because not all ROCDL operations are registered.
249249
allowUnknownOperations();
250-
declarePromisedInterface<gpu::TargetAttrInterface>();
250+
declarePromisedInterface<ROCDLTargetAttr, gpu::TargetAttrInterface>();
251251
}
252252

253253
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,

mlir/lib/IR/Dialect.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ bool Dialect::isValidNamespace(StringRef str) {
9797
/// Register a set of dialect interfaces with this dialect instance.
9898
void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
9999
// Handle the case where the models resolve a promised interface.
100-
handleAdditionOfUndefinedPromisedInterface(interface->getID());
100+
handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID());
101101

102102
auto it = registeredInterfaces.try_emplace(interface->getID(),
103103
std::move(interface));
@@ -125,8 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
125125
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
126126
for (auto *dialect : ctx->getLoadedDialects()) {
127127
#ifndef NDEBUG
128-
dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
129-
interfaceName);
128+
dialect->handleUseOfUndefinedPromisedInterface(
129+
dialect->getTypeID(), interfaceKind, interfaceName);
130130
#endif
131131
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
132132
interfaces.insert(interface);
@@ -151,13 +151,22 @@ DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
151151
DialectExtensionBase::~DialectExtensionBase() = default;
152152

153153
void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
154-
Dialect &dialect, TypeID interfaceID, StringRef interfaceName) {
155-
dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName);
154+
Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
155+
StringRef interfaceName) {
156+
dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
157+
interfaceID, interfaceName);
156158
}
157159

158160
void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
159-
Dialect &dialect, TypeID interfaceID) {
160-
dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID);
161+
Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
162+
dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
163+
interfaceID);
164+
}
165+
166+
bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
167+
TypeID interfaceRequestorID,
168+
TypeID interfaceID) {
169+
return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
161170
}
162171

163172
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,20 @@ def DenseArrayNonNegativeOp : TEST_Op<"confined_non_negative_attr"> {
368368
);
369369
}
370370

371+
//===----------------------------------------------------------------------===//
372+
// Test Promised Interfaces Constraints
373+
//===----------------------------------------------------------------------===//
374+
375+
def PromisedInterfacesOp : TEST_Op<"promised_interfaces"> {
376+
let arguments = (ins
377+
ConfinedAttr<AnyAttr,
378+
[PromisedAttrInterface<TestExternalAttrInterface>]>:$promisedAttr,
379+
ConfinedType<AnyType,
380+
[HasPromiseOrImplementsTypeInterface<TestExternalTypeInterface>]
381+
>:$promisedType
382+
);
383+
}
384+
371385
//===----------------------------------------------------------------------===//
372386
// Test Enum Attributes
373387
//===----------------------------------------------------------------------===//

mlir/unittests/IR/InterfaceAttachmentTest.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,4 +417,30 @@ TEST(InterfaceAttachment, OperationDelayedContextAppend) {
417417
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
418418
}
419419

420+
TEST(InterfaceAttachmentTest, PromisedInterfaces) {
421+
// Attribute interfaces use the exact same mechanism as types, so just check
422+
// that the promise mechanism works for attributes.
423+
MLIRContext context;
424+
auto testDialect = context.getOrLoadDialect<test::TestDialect>();
425+
auto attr = test::SimpleAAttr::get(&context);
426+
427+
// `SimpleAAttr` doesn't implement nor promises the
428+
// `TestExternalAttrInterface` interface.
429+
EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
430+
EXPECT_FALSE(
431+
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
432+
433+
// Add a promise `TestExternalAttrInterface`.
434+
testDialect->declarePromisedInterface<test::SimpleAAttr,
435+
TestExternalAttrInterface>();
436+
EXPECT_TRUE(
437+
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
438+
439+
// Attach the interface.
440+
test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
441+
EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
442+
EXPECT_TRUE(
443+
attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
444+
}
445+
420446
} // namespace

0 commit comments

Comments
 (0)