Skip to content

Commit 5fc28e7

Browse files
authored
Improve MLIR Attribute::get() method efficiency by reducing the amount of argument copies (#68067)
This ensures that the proper std::forward/std::move are involved, we go from 6 copy-constructions to 0 (!) on Attribute creation in release builds.
1 parent 6a621ed commit 5fc28e7

File tree

8 files changed

+109
-16
lines changed

8 files changed

+109
-16
lines changed

mlir/include/mlir/IR/StorageUniquerSupport.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,20 +175,21 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
175175
/// function is guaranteed to return a non null object and will assert if
176176
/// the arguments provided are invalid.
177177
template <typename... Args>
178-
static ConcreteT get(MLIRContext *ctx, Args... args) {
178+
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
179179
// Ensure that the invariants are correct for construction.
180180
assert(
181181
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
182-
return UniquerT::template get<ConcreteT>(ctx, args...);
182+
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
183183
}
184184

185185
/// Get or create a new ConcreteT instance within the ctx, defined at
186186
/// the given, potentially unknown, location. If the arguments provided are
187187
/// invalid, errors are emitted using the provided location and a null object
188188
/// is returned.
189189
template <typename... Args>
190-
static ConcreteT getChecked(const Location &loc, Args... args) {
191-
return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
190+
static ConcreteT getChecked(const Location &loc, Args &&...args) {
191+
return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
192+
std::forward<Args>(args)...);
192193
}
193194

194195
/// Get or create a new ConcreteT instance within the ctx. If the arguments

mlir/include/mlir/Support/StorageUniquer.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/ADT/DenseSet.h"
1717
#include "llvm/ADT/StringRef.h"
1818
#include "llvm/Support/Allocator.h"
19+
#include <utility>
1920

2021
namespace mlir {
2122
namespace detail {
@@ -207,7 +208,7 @@ class StorageUniquer {
207208

208209
// Generate a constructor function for the derived storage.
209210
auto ctorFn = [&](StorageAllocator &allocator) {
210-
auto *storage = Storage::construct(allocator, derivedKey);
211+
auto *storage = Storage::construct(allocator, std::move(derivedKey));
211212
if (initFn)
212213
initFn(storage);
213214
return storage;
@@ -300,9 +301,9 @@ class StorageUniquer {
300301
static typename ImplTy::KeyTy getKey(Args &&...args) {
301302
if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy,
302303
Args...>::value)
303-
return ImplTy::getKey(args...);
304+
return ImplTy::getKey(std::forward<Args>(args)...);
304305
else
305-
return typename ImplTy::KeyTy(args...);
306+
return typename ImplTy::KeyTy(std::forward<Args>(args)...);
306307
}
307308

308309
//===--------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,5 +323,16 @@ def Test_IteratorTypeArrayAttr
323323
: TypedArrayAttrBase<Test_IteratorTypeEnum,
324324
"Iterator type should be an enum.">;
325325

326+
def TestParamCopyCount : AttrParameter<"CopyCount", "", "const CopyCount &"> {}
327+
328+
// Test overridding attribute builders with a custom builder.
329+
def TestCopyCount : Test_Attr<"TestCopyCount"> {
330+
let mnemonic = "copy_count";
331+
let parameters = (ins TestParamCopyCount:$copy_count);
332+
let assemblyFormat = "`<` $copy_count `>`";
333+
}
334+
335+
336+
326337

327338
#endif // TEST_ATTRDEFS

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/ADT/TypeSwitch.h"
2323
#include "llvm/ADT/bit.h"
2424
#include "llvm/Support/ErrorHandling.h"
25+
#include "llvm/Support/raw_ostream.h"
2526

2627
using namespace mlir;
2728
using namespace test;
@@ -175,6 +176,45 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
175176
p << (*result ? "true" : "false");
176177
}
177178

179+
//===----------------------------------------------------------------------===//
180+
// CopyCountAttr Implementation
181+
//===----------------------------------------------------------------------===//
182+
183+
CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
184+
CopyCount::counter++;
185+
}
186+
187+
CopyCount &CopyCount::operator=(const CopyCount &rhs) {
188+
CopyCount::counter++;
189+
value = rhs.value;
190+
return *this;
191+
}
192+
193+
int CopyCount::counter;
194+
195+
static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
196+
return lhs.value == rhs.value;
197+
}
198+
199+
llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
200+
const test::CopyCount &value) {
201+
return os << value.value;
202+
}
203+
204+
template <>
205+
struct mlir::FieldParser<test::CopyCount> {
206+
static FailureOr<test::CopyCount> parse(AsmParser &parser) {
207+
std::string value;
208+
if (parser.parseKeyword(value))
209+
return failure();
210+
return test::CopyCount(value);
211+
}
212+
};
213+
namespace test {
214+
llvm::hash_code hash_value(const test::CopyCount &copyCount) {
215+
return llvm::hash_value(copyCount.value);
216+
}
217+
} // namespace test
178218
//===----------------------------------------------------------------------===//
179219
// Tablegen Generated Definitions
180220
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestAttributes.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@
2929

3030
namespace test {
3131
class TestDialect;
32+
// Payload class for the CopyCountAttr.
33+
class CopyCount {
34+
public:
35+
CopyCount(std::string value) : value(value) {}
36+
CopyCount(const CopyCount &rhs);
37+
CopyCount &operator=(const CopyCount &rhs);
38+
CopyCount(CopyCount &&rhs) = default;
39+
CopyCount &operator=(CopyCount &&rhs) = default;
40+
static int counter;
41+
std::string value;
42+
};
43+
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
44+
const test::CopyCount &value);
3245

3346
/// A handle used to reference external elements instances.
3447
using TestDialectResourceBlobHandle =

mlir/test/mlir-tblgen/attrdefs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
8282
// Check that AttributeSelfTypeParameter is handled properly.
8383
// DEF-LABEL: struct CompoundAAttrStorage
8484
// DEF: CompoundAAttrStorage(
85-
// DEF-SAME: inner(inner)
85+
// DEF-SAME: inner(std::move(inner))
8686

8787
// DEF: bool operator==(const KeyTy &tblgenKey) const {
8888
// DEF-NEXT: return
@@ -94,7 +94,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
9494

9595
// DEF: static CompoundAAttrStorage *construct
9696
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
97-
// DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
97+
// DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
9898

9999
// DEF: ::mlir::Type CompoundAAttr::getInner() const {
100100
// DEF-NEXT: return getImpl()->inner;

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ void DefGen::emitDefaultBuilder() {
353353
MethodBody &body = m->body().indent();
354354
auto scope = body.scope("return Base::get(context", ");");
355355
for (const auto &param : params)
356-
body << ", " << param.getName();
356+
body << ", std::move(" << param.getName() << ")";
357357
}
358358

359359
void DefGen::emitCheckedBuilder() {
@@ -474,8 +474,10 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
474474
void DefGen::emitStorageConstructor() {
475475
Constructor *ctor =
476476
storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
477-
for (auto &param : params)
478-
ctor->addMemberInitializer(param.getName(), param.getName());
477+
for (auto &param : params) {
478+
std::string movedValue = ("std::move(" + param.getName() + ")").str();
479+
ctor->addMemberInitializer(param.getName(), movedValue);
480+
}
479481
}
480482

481483
void DefGen::emitKeyType() {
@@ -525,11 +527,11 @@ void DefGen::emitConstruct() {
525527
: Method::Static,
526528
MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
527529
"allocator"),
528-
MethodParameter("const KeyTy &", "tblgenKey"));
530+
MethodParameter("KeyTy &&", "tblgenKey"));
529531
if (!def.hasStorageCustomConstructor()) {
530532
auto &body = construct->body().indent();
531533
for (const auto &it : llvm::enumerate(params)) {
532-
body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
534+
body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
533535
it.value().getName(), it.index());
534536
}
535537
// Use the parameters' custom allocator code, if provided.
@@ -544,8 +546,9 @@ void DefGen::emitConstruct() {
544546
body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
545547
def.getStorageClassName()),
546548
");");
547-
llvm::interleaveComma(params, body,
548-
[&](auto &param) { body << param.getName(); });
549+
llvm::interleaveComma(params, body, [&](auto &param) {
550+
body << "std::move(" << param.getName() << ")";
551+
});
549552
}
550553
}
551554

mlir/unittests/IR/AttributeTest.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "gtest/gtest.h"
1414
#include <optional>
1515

16+
#include "../../test/lib/Dialect/Test/TestDialect.h"
17+
1618
using namespace mlir;
1719
using namespace mlir::detail;
1820

@@ -459,4 +461,26 @@ TEST(SubElementTest, Nested) {
459461
ArrayRef<Attribute>(
460462
{strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
461463
}
464+
465+
// Test how many times we call copy-ctor when building an attribute.
466+
TEST(CopyCountAttr, CopyCount) {
467+
MLIRContext context;
468+
context.loadDialect<test::TestDialect>();
469+
470+
test::CopyCount::counter = 0;
471+
test::CopyCount copyCount("hello");
472+
test::TestCopyCountAttr::get(&context, std::move(copyCount));
473+
int counter1 = test::CopyCount::counter;
474+
test::CopyCount::counter = 0;
475+
test::TestCopyCountAttr::get(&context, std::move(copyCount));
476+
#ifndef NDEBUG
477+
// One verification enabled only in assert-mode requires a copy.
478+
EXPECT_EQ(counter1, 1);
479+
EXPECT_EQ(test::CopyCount::counter, 1);
480+
#else
481+
EXPECT_EQ(counter1, 0);
482+
EXPECT_EQ(test::CopyCount::counter, 0);
483+
#endif
484+
}
485+
462486
} // namespace

0 commit comments

Comments
 (0)