Skip to content

Commit b091701

Browse files
authored
[mlir] Add a method on MLIRContext to retrieve the operations for a given dialect (#112344)
Currently we have `MLIRContext::getRegisteredOperations` which returns all operations for the given context, with the addition of `MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the same for a given dialect class. Closes #111591
1 parent 4091bc6 commit b091701

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ class MLIRContext {
197197
/// operations.
198198
ArrayRef<RegisteredOperationName> getRegisteredOperations();
199199

200+
/// Return a sorted array containing the information for registered operations
201+
/// filtered by dialect name.
202+
ArrayRef<RegisteredOperationName>
203+
getRegisteredOperationsByDialect(StringRef dialectName);
204+
200205
/// Return true if this operation name is registered in this context.
201206
bool isOperationRegistered(StringRef name);
202207

mlir/lib/IR/MLIRContext.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,30 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
711711
return impl->sortedRegisteredOperations;
712712
}
713713

714+
/// Return information for registered operations by dialect.
715+
ArrayRef<RegisteredOperationName>
716+
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
717+
auto lowerBound =
718+
std::lower_bound(impl->sortedRegisteredOperations.begin(),
719+
impl->sortedRegisteredOperations.end(), dialectName,
720+
[](auto &lhs, auto &rhs) {
721+
return lhs.getDialect().getNamespace().compare(rhs);
722+
});
723+
724+
if (lowerBound == impl->sortedRegisteredOperations.end() ||
725+
lowerBound->getDialect().getNamespace() != dialectName)
726+
return ArrayRef<RegisteredOperationName>();
727+
728+
auto upperBound =
729+
std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(),
730+
dialectName, [](auto &lhs, auto &rhs) {
731+
return lhs.compare(rhs.getDialect().getNamespace());
732+
});
733+
734+
size_t count = std::distance(lowerBound, upperBound);
735+
return ArrayRef(&*lowerBound, count);
736+
}
737+
714738
bool MLIRContext::isOperationRegistered(StringRef name) {
715739
return RegisteredOperationName::lookup(name, this).has_value();
716740
}

0 commit comments

Comments
 (0)