-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Improve MLIR attribute get() method efficiency when complex members are involved #68067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir ChangesThis ensures that the proper forward/move are involved, we go from 6 copy-construction to 0 (!) when building without assertions. Full diff: https://github.com/llvm/llvm-project/pull/68067.diff 7 Files Affected:
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index c466e230d341d3e..982d5220ab52ce9 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -175,11 +175,11 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// function is guaranteed to return a non null object and will assert if
/// the arguments provided are invalid.
template <typename... Args>
- static ConcreteT get(MLIRContext *ctx, Args... args) {
+ static ConcreteT get(MLIRContext *ctx, Args &&...args) {
// Ensure that the invariants are correct for construction.
assert(
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
- return UniquerT::template get<ConcreteT>(ctx, args...);
+ return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}
/// Get or create a new ConcreteT instance within the ctx, defined at
@@ -187,8 +187,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// invalid, errors are emitted using the provided location and a null object
/// is returned.
template <typename... Args>
- static ConcreteT getChecked(const Location &loc, Args... args) {
- return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
+ static ConcreteT getChecked(const Location &loc, Args &&...args) {
+ return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
+ std::forward<Args>(args)...);
}
/// Get or create a new ConcreteT instance within the ctx. If the arguments
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 13359bf91f40d17..baaedc47dcb2cd5 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -16,6 +16,7 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Allocator.h"
+#include <utility>
namespace mlir {
namespace detail {
@@ -300,9 +301,9 @@ class StorageUniquer {
static typename ImplTy::KeyTy getKey(Args &&...args) {
if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy,
Args...>::value)
- return ImplTy::getKey(args...);
+ return ImplTy::getKey(std::forward<Args>(args)...);
else
- return typename ImplTy::KeyTy(args...);
+ return typename ImplTy::KeyTy(std::forward<Args>(args)...);
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ec0a5548a160338..945c54c04d47ce8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -323,5 +323,16 @@ def Test_IteratorTypeArrayAttr
: TypedArrayAttrBase<Test_IteratorTypeEnum,
"Iterator type should be an enum.">;
+def TestParamCopyCount : AttrParameter<"CopyCount", "", "const CopyCount &"> {}
+
+// Test overridding attribute builders with a custom builder.
+def TestCopyCount : Test_Attr<"TestCopyCount"> {
+ let mnemonic = "copy_count";
+ let parameters = (ins TestParamCopyCount:$copy_count);
+ let assemblyFormat = "`<` $copy_count `>`";
+}
+
+
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 7fc2e6ab3ec0a0a..c240354e5d99044 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace test;
@@ -175,6 +176,45 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
p << (*result ? "true" : "false");
}
+//===----------------------------------------------------------------------===//
+// CopyCountAttr Implementation
+//===----------------------------------------------------------------------===//
+
+CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
+ CopyCount::counter++;
+}
+
+CopyCount &CopyCount::operator=(const CopyCount &rhs) {
+ CopyCount::counter++;
+ value = rhs.value;
+ return *this;
+}
+
+int CopyCount::counter;
+
+static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
+ return lhs.value == rhs.value;
+}
+
+llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
+ const test::CopyCount &value) {
+ return os << value.value;
+}
+
+template <>
+struct mlir::FieldParser<test::CopyCount> {
+ static FailureOr<test::CopyCount> parse(AsmParser &parser) {
+ std::string value;
+ if (parser.parseKeyword(value))
+ return failure();
+ return test::CopyCount(value);
+ }
+};
+namespace test {
+llvm::hash_code hash_value(const test::CopyCount ©Count) {
+ return llvm::hash_value(copyCount.value);
+}
+} // namespace test
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index cc73e078bf7e20b..ef6eae51fdd628a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -29,6 +29,19 @@
namespace test {
class TestDialect;
+// Payload class for the CopyCountAttr.
+class CopyCount {
+public:
+ CopyCount(std::string value) : value(value) {}
+ CopyCount(const CopyCount &rhs);
+ CopyCount &operator=(const CopyCount &rhs);
+ CopyCount(CopyCount &&rhs) = default;
+ CopyCount &operator=(CopyCount &&rhs) = default;
+ static int counter;
+ std::string value;
+};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const test::CopyCount &value);
/// A handle used to reference external elements instances.
using TestDialectResourceBlobHandle =
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index f6e43d42d29f069..f14d33c7d13d310 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -353,7 +353,7 @@ void DefGen::emitDefaultBuilder() {
MethodBody &body = m->body().indent();
auto scope = body.scope("return Base::get(context", ");");
for (const auto ¶m : params)
- body << ", " << param.getName();
+ body << ", std::move(" << param.getName() << ")";
}
void DefGen::emitCheckedBuilder() {
@@ -474,8 +474,10 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
void DefGen::emitStorageConstructor() {
Constructor *ctor =
storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
- for (auto ¶m : params)
- ctor->addMemberInitializer(param.getName(), param.getName());
+ for (auto ¶m : params) {
+ std::string movedValue = ("std::move(" + param.getName() + ")").str();
+ ctor->addMemberInitializer(param.getName(), movedValue);
+ }
}
void DefGen::emitKeyType() {
@@ -525,11 +527,11 @@ void DefGen::emitConstruct() {
: Method::Static,
MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
"allocator"),
- MethodParameter("const KeyTy &", "tblgenKey"));
+ MethodParameter("KeyTy &", "tblgenKey"));
if (!def.hasStorageCustomConstructor()) {
auto &body = construct->body().indent();
for (const auto &it : llvm::enumerate(params)) {
- body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
+ body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
it.value().getName(), it.index());
}
// Use the parameters' custom allocator code, if provided.
@@ -544,8 +546,9 @@ void DefGen::emitConstruct() {
body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
def.getStorageClassName()),
");");
- llvm::interleaveComma(params, body,
- [&](auto ¶m) { body << param.getName(); });
+ llvm::interleaveComma(params, body, [&](auto ¶m) {
+ body << "std::move(" << param.getName() << ")";
+ });
}
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 9afbce037b408c0..6307a10bad4cd93 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -13,6 +13,8 @@
#include "gtest/gtest.h"
#include <optional>
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
using namespace mlir;
using namespace mlir::detail;
@@ -459,4 +461,25 @@ TEST(SubElementTest, Nested) {
ArrayRef<Attribute>(
{strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
}
+
+// Test how many times we call copy-ctor when building an attribute.
+TEST(CopyCountAttr, CopyCount) {
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+
+ test::CopyCount::counter = 0;
+ test::CopyCount copyCount("hello");
+ test::TestCopyCountAttr::get(&context, std::move(copyCount));
+ int counter1 = test::CopyCount::counter;
+ test::CopyCount::counter = 0;
+ test::TestCopyCountAttr::get(&context, std::move(copyCount));
+#ifndef NDEBUG
+ EXPECT_EQ(counter1, 1);
+ EXPECT_EQ(test::CopyCount::counter, 1);
+#else
+ EXPECT_EQ(counter1, 0);
+ EXPECT_EQ(test::CopyCount::counter, 0);
+#endif
+}
+
} // namespace
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing. LGTM with Alex's comments
…re involved This ensures that the proper forward/move are involved, we go from 6 copy-construction to 0 (!) when building without assertions.
Nice! |
This ensures that the proper forward/move are involved, we go from 6 copy-constructions to 0 (!) when building without assertions.