Skip to content

Improve MLIR attribute get() method efficiency when complex members are involved #68067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 4, 2023

Conversation

joker-eph
Copy link
Collaborator

@joker-eph joker-eph commented Oct 3, 2023

This ensures that the proper forward/move are involved, we go from 6 copy-constructions to 0 (!) when building without assertions.

@joker-eph joker-eph requested review from jpienaar and Mogball October 3, 2023 05:33
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 3, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2023

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Changes

This ensures that the proper forward/move are involved, we go from 6 copy-construction to 0 (!) when building without assertions.


Full diff: https://github.com/llvm/llvm-project/pull/68067.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/StorageUniquerSupport.h (+5-4)
  • (modified) mlir/include/mlir/Support/StorageUniquer.h (+3-2)
  • (modified) mlir/test/lib/Dialect/Test/TestAttrDefs.td (+11)
  • (modified) mlir/test/lib/Dialect/Test/TestAttributes.cpp (+40)
  • (modified) mlir/test/lib/Dialect/Test/TestAttributes.h (+13)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+10-7)
  • (modified) mlir/unittests/IR/AttributeTest.cpp (+23)
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index c466e230d341d3e..982d5220ab52ce9 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -175,11 +175,11 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// function is guaranteed to return a non null object and will assert if
   /// the arguments provided are invalid.
   template <typename... Args>
-  static ConcreteT get(MLIRContext *ctx, Args... args) {
+  static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
     assert(
         succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
-    return UniquerT::template get<ConcreteT>(ctx, args...);
+    return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx, defined at
@@ -187,8 +187,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// invalid, errors are emitted using the provided location and a null object
   /// is returned.
   template <typename... Args>
-  static ConcreteT getChecked(const Location &loc, Args... args) {
-    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
+  static ConcreteT getChecked(const Location &loc, Args &&...args) {
+    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
+                                 std::forward<Args>(args)...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx. If the arguments
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 13359bf91f40d17..baaedc47dcb2cd5 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -16,6 +16,7 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Allocator.h"
+#include <utility>
 
 namespace mlir {
 namespace detail {
@@ -300,9 +301,9 @@ class StorageUniquer {
   static typename ImplTy::KeyTy getKey(Args &&...args) {
     if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy,
                                     Args...>::value)
-      return ImplTy::getKey(args...);
+      return ImplTy::getKey(std::forward<Args>(args)...);
     else
-      return typename ImplTy::KeyTy(args...);
+      return typename ImplTy::KeyTy(std::forward<Args>(args)...);
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ec0a5548a160338..945c54c04d47ce8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -323,5 +323,16 @@ def Test_IteratorTypeArrayAttr
     : TypedArrayAttrBase<Test_IteratorTypeEnum,
   "Iterator type should be an enum.">;
 
+def TestParamCopyCount : AttrParameter<"CopyCount", "", "const CopyCount &"> {}
+
+// Test overridding attribute builders with a custom builder.
+def TestCopyCount : Test_Attr<"TestCopyCount"> {
+  let mnemonic = "copy_count";
+  let parameters = (ins TestParamCopyCount:$copy_count);
+  let assemblyFormat = "`<` $copy_count `>`";
+}
+
+
+
 
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 7fc2e6ab3ec0a0a..c240354e5d99044 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace test;
@@ -175,6 +176,45 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
   p << (*result ? "true" : "false");
 }
 
+//===----------------------------------------------------------------------===//
+// CopyCountAttr Implementation
+//===----------------------------------------------------------------------===//
+
+CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
+  CopyCount::counter++;
+}
+
+CopyCount &CopyCount::operator=(const CopyCount &rhs) {
+  CopyCount::counter++;
+  value = rhs.value;
+  return *this;
+}
+
+int CopyCount::counter;
+
+static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
+  return lhs.value == rhs.value;
+}
+
+llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
+                                    const test::CopyCount &value) {
+  return os << value.value;
+}
+
+template <>
+struct mlir::FieldParser<test::CopyCount> {
+  static FailureOr<test::CopyCount> parse(AsmParser &parser) {
+    std::string value;
+    if (parser.parseKeyword(value))
+      return failure();
+    return test::CopyCount(value);
+  }
+};
+namespace test {
+llvm::hash_code hash_value(const test::CopyCount &copyCount) {
+  return llvm::hash_value(copyCount.value);
+}
+} // namespace test
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index cc73e078bf7e20b..ef6eae51fdd628a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -29,6 +29,19 @@
 
 namespace test {
 class TestDialect;
+// Payload class for the CopyCountAttr.
+class CopyCount {
+public:
+  CopyCount(std::string value) : value(value) {}
+  CopyCount(const CopyCount &rhs);
+  CopyCount &operator=(const CopyCount &rhs);
+  CopyCount(CopyCount &&rhs) = default;
+  CopyCount &operator=(CopyCount &&rhs) = default;
+  static int counter;
+  std::string value;
+};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const test::CopyCount &value);
 
 /// A handle used to reference external elements instances.
 using TestDialectResourceBlobHandle =
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index f6e43d42d29f069..f14d33c7d13d310 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -353,7 +353,7 @@ void DefGen::emitDefaultBuilder() {
   MethodBody &body = m->body().indent();
   auto scope = body.scope("return Base::get(context", ");");
   for (const auto &param : params)
-    body << ", " << param.getName();
+    body << ", std::move(" << param.getName() << ")";
 }
 
 void DefGen::emitCheckedBuilder() {
@@ -474,8 +474,10 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
 void DefGen::emitStorageConstructor() {
   Constructor *ctor =
       storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
-  for (auto &param : params)
-    ctor->addMemberInitializer(param.getName(), param.getName());
+  for (auto &param : params) {
+    std::string movedValue = ("std::move(" + param.getName() + ")").str();
+    ctor->addMemberInitializer(param.getName(), movedValue);
+  }
 }
 
 void DefGen::emitKeyType() {
@@ -525,11 +527,11 @@ void DefGen::emitConstruct() {
                                         : Method::Static,
       MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
                       "allocator"),
-      MethodParameter("const KeyTy &", "tblgenKey"));
+      MethodParameter("KeyTy &", "tblgenKey"));
   if (!def.hasStorageCustomConstructor()) {
     auto &body = construct->body().indent();
     for (const auto &it : llvm::enumerate(params)) {
-      body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
+      body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
                       it.value().getName(), it.index());
     }
     // Use the parameters' custom allocator code, if provided.
@@ -544,8 +546,9 @@ void DefGen::emitConstruct() {
         body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
                           def.getStorageClassName()),
                    ");");
-    llvm::interleaveComma(params, body,
-                          [&](auto &param) { body << param.getName(); });
+    llvm::interleaveComma(params, body, [&](auto &param) {
+      body << "std::move(" << param.getName() << ")";
+    });
   }
 }
 
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 9afbce037b408c0..6307a10bad4cd93 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -13,6 +13,8 @@
 #include "gtest/gtest.h"
 #include <optional>
 
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
 using namespace mlir;
 using namespace mlir::detail;
 
@@ -459,4 +461,25 @@ TEST(SubElementTest, Nested) {
             ArrayRef<Attribute>(
                 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
 }
+
+// Test how many times we call copy-ctor when building an attribute.
+TEST(CopyCountAttr, CopyCount) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  test::CopyCount::counter = 0;
+  test::CopyCount copyCount("hello");
+  test::TestCopyCountAttr::get(&context, std::move(copyCount));
+  int counter1 = test::CopyCount::counter;
+  test::CopyCount::counter = 0;
+  test::TestCopyCountAttr::get(&context, std::move(copyCount));
+#ifndef NDEBUG
+  EXPECT_EQ(counter1, 1);
+  EXPECT_EQ(test::CopyCount::counter, 1);
+#else
+  EXPECT_EQ(counter1, 0);
+  EXPECT_EQ(test::CopyCount::counter, 0);
+#endif
+}
+
 } // namespace

Copy link
Contributor

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing. LGTM with Alex's comments

…re involved

This ensures that the proper forward/move are involved, we go from 6
copy-construction to 0 (!) when building without assertions.
@joker-eph joker-eph merged commit 5fc28e7 into llvm:main Oct 4, 2023
@joker-eph joker-eph deleted the copy-ctor branch October 4, 2023 01:07
@lattner
Copy link
Collaborator

lattner commented Oct 16, 2023

Nice!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants