[MLIR] Add a second map for registered OperationName in MLIRContext (NFC) #87170

Merged
joker-eph merged 1 commit into llvm:main from typeid-opname on Mar 31, 2024

Conversation

joker-eph
Collaborator

@joker-eph joker-eph commented Mar 30, 2024

This speeds up registered op creation by 10-11% by allowing lookup by
TypeID instead of StringRef.

This can break your build/tests at runtime with an error that you're creating
an unregistered operation that you have registered. If so you are likely using
a class inheriting from the "real" operation. See for example in this patch the
case of:

class ConstantIndexOp : public arith::ConstantOp {

If one is using builder.create<ConstantIndexOp>() they actually create an
arith.constant operation, but the builder will fetch the TypeID for
the ConstantIndexOp class which does not correspond to any registered
operation. To fix it the ConstantIndexOp class got this addition:

static ::mlir::TypeID resolveTypeID() { return TypeID::get<arith::ConstantOp>(); }
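
The failure mode described above can be reproduced without MLIR. The following is a minimal sketch under stated assumptions: `ToyTypeID`, `getToyTypeID`, `PlainIndexOp`, `FixedIndexOp`, and `Registry` are hypothetical stand-ins invented for illustration, not MLIR APIs. It shows why a subclass gets a TypeID that was never registered, and how a static `resolveTypeID()` hook forwards the lookup to the parent class instead.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <unordered_map>

// Toy stand-in for mlir::TypeID (hypothetical): one unique address per type.
using ToyTypeID = const void *;

template <typename T>
ToyTypeID uniqueAddress() {
  static const char tag = 0; // a distinct object for every instantiation
  return &tag;
}

// Detects whether T exposes a static resolveTypeID() hook.
template <typename T, typename = void>
struct HasResolveHook : std::false_type {};
template <typename T>
struct HasResolveHook<T, std::void_t<decltype(T::resolveTypeID())>>
    : std::true_type {};

// Uses the hook when the class provides one, otherwise the class's own ID.
template <typename T>
ToyTypeID getToyTypeID() {
  if constexpr (HasResolveHook<T>::value)
    return T::resolveTypeID();
  else
    return uniqueAddress<T>();
}

// Plays the role of the "real" registered operation (arith::ConstantOp).
struct ConstantOp {
  static std::string getOperationName() { return "arith.constant"; }
};

// Convenience subclass without the hook: it gets its own, unregistered ID.
struct PlainIndexOp : ConstantOp {};

// Convenience subclass with the hook: it reuses the parent's ID.
struct FixedIndexOp : ConstantOp {
  static ToyTypeID resolveTypeID() { return getToyTypeID<ConstantOp>(); }
};

// Toy registry keyed by type ID, mapping to the operation name.
using Registry = std::unordered_map<ToyTypeID, std::string>;

template <typename OpT>
void registerOp(Registry &registry) {
  bool inserted =
      registry.emplace(getToyTypeID<OpT>(), OpT::getOperationName()).second;
  assert(inserted && "operation already registered");
  (void)inserted;
}

template <typename OpT>
bool isRegistered(const Registry &registry) {
  return registry.count(getToyTypeID<OpT>()) != 0;
}
```

After registering only `ConstantOp`, `isRegistered<PlainIndexOp>` is false even though both classes report the name `arith.constant`, while `isRegistered<FixedIndexOp>` is true because its hook returns the parent's ID.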

@joker-eph joker-eph requested a review from Mogball March 30, 2024 21:08
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 30, 2024
@llvmbot
Member

llvmbot commented Mar 30, 2024

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

This speeds up registered op creation by 10-11% by allowing lookup by TypeID instead of StringRef.


Full diff: https://github.com/llvm/llvm-project/pull/87170.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/Builders.h (+1-1)
  • (modified) mlir/include/mlir/IR/OperationSupport.h (+5)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+21-7)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 43b6d2b3841690..3beade017d1ab9 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -490,7 +490,7 @@ class OpBuilder : public Builder {
   template <typename OpT>
   RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
     std::optional<RegisteredOperationName> opName =
-        RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
+        RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx);
     if (LLVM_UNLIKELY(!opName)) {
       llvm::report_fatal_error(
           "Building op `" + OpT::getOperationName() +
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index f2aa6cee840308..90e63ff8fcb38f 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -676,6 +676,11 @@ class RegisteredOperationName : public OperationName {
   static std::optional<RegisteredOperationName> lookup(StringRef name,
                                                        MLIRContext *ctx);
 
+  /// Lookup the registered operation information for the given operation.
+  /// Returns std::nullopt if the operation isn't registered.
+  static std::optional<RegisteredOperationName> lookup(TypeID typeID,
+                                                       MLIRContext *ctx);
+
   /// Register a new operation in a Dialect object.
   /// This constructor is used by Dialect objects when they register the list
   /// of operations they contain.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e1e6d14231d9f1..8a63f7598c90c5 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -183,7 +183,8 @@ class MLIRContextImpl {
   llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
 
   /// A vector of operation info specifically for registered operations.
-  llvm::StringMap<RegisteredOperationName> registeredOperations;
+  llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
+  llvm::StringMap<RegisteredOperationName> registeredOperationsByName;
 
   /// This is a sorted container of registered operations for a deterministic
   /// and efficient `getRegisteredOperations` implementation.
@@ -780,8 +781,8 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
     // Check the registered info map first. In the overwhelmingly common case,
     // the entry will be in here and it also removes the need to acquire any
     // locks.
-    auto registeredIt = ctxImpl.registeredOperations.find(name);
-    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
+    auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
+    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
       impl = registeredIt->second.impl;
       return;
     }
@@ -909,10 +910,19 @@ OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
 //===----------------------------------------------------------------------===//
 
 std::optional<RegisteredOperationName>
-RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
   auto &impl = ctx->getImpl();
-  auto it = impl.registeredOperations.find(name);
+  auto it = impl.registeredOperations.find(typeID);
   if (it != impl.registeredOperations.end())
+    return it->second;
+  return std::nullopt;
+}
+
+std::optional<RegisteredOperationName>
+RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+  auto &impl = ctx->getImpl();
+  auto it = impl.registeredOperationsByName.find(name);
+  if (it != impl.registeredOperationsByName.end())
     return it->getValue();
   return std::nullopt;
 }
@@ -945,11 +955,15 @@ void RegisteredOperationName::insert(
 
   // Update the registered info for this operation.
   auto emplaced = ctxImpl.registeredOperations.try_emplace(
-      name, RegisteredOperationName(impl));
+      impl->getTypeID(), RegisteredOperationName(impl));
   assert(emplaced.second && "operation name registration must be successful");
+  auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
+      name, RegisteredOperationName(impl));
+  assert(emplacedByName.second &&
+         "operation name registration must be successful");
 
   // Add emplaced operation name to the sorted operations container.
-  RegisteredOperationName &value = emplaced.first->getValue();
+  RegisteredOperationName &value = emplaced.first->second;
   ctxImpl.sortedRegisteredOperations.insert(
       llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
                         [](auto &lhs, auto &rhs) {
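
The shape of the change in the diff above, one registration that fills two indexes, can be sketched in isolation. This is a toy model using standard containers rather than `llvm::DenseMap` and `llvm::StringMap`; `ToyContext`, `ToyOpInfo`, `toyTypeID`, `ToyAddOp`, and `ToyMulOp` are hypothetical names. The two lookups are given distinct names here (instead of the overloaded `lookup` in the patch) because the toy ID is a raw pointer, and a string literal would otherwise convert to it.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>

// Toy stand-in for mlir::TypeID (hypothetical): one unique address per type.
using ToyTypeID = const void *;

template <typename T>
ToyTypeID toyTypeID() {
  static const char tag = 0;
  return &tag;
}

// Toy stand-in for RegisteredOperationName.
struct ToyOpInfo {
  std::string name;
  ToyTypeID typeID;
};

// Two example operation classes used only as type tags.
struct ToyAddOp {};
struct ToyMulOp {};

class ToyContext {
public:
  // Registration fills both maps, mirroring the two try_emplace calls.
  void insert(const std::string &name, ToyTypeID typeID) {
    auto byID = byTypeID.try_emplace(typeID, ToyOpInfo{name, typeID});
    assert(byID.second && "operation type registered twice");
    auto byNm = byName.try_emplace(name, ToyOpInfo{name, typeID});
    assert(byNm.second && "operation name registered twice");
    (void)byID;
    (void)byNm;
  }

  // Fast path for builders that know the C++ class: hashes a pointer.
  std::optional<ToyOpInfo> lookupByTypeID(ToyTypeID typeID) const {
    auto it = byTypeID.find(typeID);
    if (it != byTypeID.end())
      return it->second;
    return std::nullopt;
  }

  // Path for callers that only have a string, such as name-based queries.
  std::optional<ToyOpInfo> lookupByName(const std::string &name) const {
    auto it = byName.find(name);
    if (it != byName.end())
      return it->second;
    return std::nullopt;
  }

private:
  std::unordered_map<ToyTypeID, ToyOpInfo> byTypeID;
  std::unordered_map<std::string, ToyOpInfo> byName;
};
```

Both lookups return the same record for a registered operation, and `std::nullopt` for anything that was never inserted.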

@joker-eph joker-eph force-pushed the typeid-opname branch 2 times, most recently from 7744629 to ec2c4c4 Compare March 30, 2024 21:28
@joker-eph joker-eph merged commit 82c6eee into llvm:main Mar 31, 2024
@joker-eph joker-eph deleted the typeid-opname branch March 31, 2024 19:28
@stellaraccident
Contributor

Nice, thanks!

@jpienaar
Member

jpienaar commented Apr 1, 2024

Larger speedup than expected. What is size cost?

OOC could the slower string lookup one be completely avoided/localized? I'm guessing not without iteration for the case where you only have the string (and not the class, and potentially could be dynamically constructing the string), and in that case the lookup is linear in terms of all op types registered. But that also feels like it should be an exceptional case/one where one could derive a map only if and when needed rather than carrying both. Or am I underestimating the frequency of use?

@joker-eph
Collaborator Author

> Larger speedup than expected. What is size cost?

Less than the StringMap :)
I looked into removing the StringMap but it's annoyingly used by the C API lookup (queries for "is this op name registered").

The DenseMap storage is a std::pair, so basically 8 bytes for the TypeID, and the OperationName is also just an 8-byte pointer to the actual Impl class.
So it's 16B per entry. The map doubles in size when it fills up to 3/4, so for N registered operations the total is bounded by 16*8N/3 bytes (to be more exact, it'll be somewhere between 21B and 42B per registered operation).
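
The 21B to 42B range follows from the load factor staying between 3/8 (just after a doubling) and 3/4 (just before one): 16 / (3/4) ≈ 21.3 and 16 / (3/8) ≈ 42.7. A small sketch of that arithmetic, under a simplified growth model that is an assumption and not the actual `llvm::DenseMap` implementation (power-of-two bucket counts starting at 64, doubling once more than 3/4 of the buckets would be occupied); `bucketsFor` and `bytesPerOp` are hypothetical helper names:

```cpp
#include <cassert>
#include <cstddef>

// Assumed toy growth model: start at 64 buckets and double while the
// entries would occupy more than 3/4 of the buckets.
std::size_t bucketsFor(std::size_t numEntries) {
  std::size_t buckets = 64;
  while (numEntries * 4 > buckets * 3)
    buckets *= 2;
  return buckets;
}

// Bytes of bucket storage per registered operation, assuming each bucket
// holds an 8-byte type ID plus an 8-byte pointer (16 bytes in total).
double bytesPerOp(std::size_t numEntries) {
  assert(numEntries > 0 && "need at least one registered operation");
  constexpr std::size_t kBucketBytes = 16;
  return static_cast<double>(bucketsFor(numEntries) * kBucketBytes) /
         static_cast<double>(numEntries);
}
```

In this model, 48 entries still fit in 64 buckets (about 21.3 bytes per operation), while a 49th entry forces 128 buckets (about 41.8 bytes per operation), which matches the two ends of the range quoted above.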

Labels
mlir:arith mlir:core MLIR Core Infrastructure mlir