Skip to content

Commit 11067d7

Browse files
committed
[mlir] Optimize OperationName construction and usage
When constructing an OperationName, the overwhelming majority of cases are from registered operations. This revision adds a non-locked lookup into the currently registered operations, which prevents locking in the common case. This revision also optimizes several uses of RegisteredOperationName that expect the operation to be registered, e.g. such as in OpBuilder. These changes provides a reasonable speedup (5-10%) in some compilations, especially on platforms where locking is expensive. Differential Revision: https://reviews.llvm.org/D117187
1 parent a97e20a commit 11067d7

File tree

6 files changed

+55
-23
lines changed

6 files changed

+55
-23
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,22 +406,27 @@ class OpBuilder : public Builder {
406406

407407
private:
408408
/// Helper for sanity checking preconditions for create* methods below.
409-
void checkHasRegisteredInfo(const OperationName &name) {
410-
if (LLVM_UNLIKELY(!name.isRegistered()))
409+
template <typename OpT>
410+
RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
411+
Optional<RegisteredOperationName> opName =
412+
RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
413+
if (LLVM_UNLIKELY(!opName)) {
411414
llvm::report_fatal_error(
412-
"Building op `" + name.getStringRef() +
415+
"Building op `" + OpT::getOperationName() +
413416
"` but it isn't registered in this MLIRContext: the dialect may not "
414417
"be loaded or this operation isn't registered by the dialect. See "
415418
"also https://mlir.llvm.org/getting_started/Faq/"
416419
"#registered-loaded-dependent-whats-up-with-dialects-management");
420+
}
421+
return *opName;
417422
}
418423

419424
public:
420425
/// Create an operation of specific op type at the current insertion point.
421426
template <typename OpTy, typename... Args>
422427
OpTy create(Location location, Args &&...args) {
423-
OperationState state(location, OpTy::getOperationName());
424-
checkHasRegisteredInfo(state.name);
428+
OperationState state(location,
429+
getCheckRegisteredInfo<OpTy>(location.getContext()));
425430
OpTy::build(*this, state, std::forward<Args>(args)...);
426431
auto *op = createOperation(state);
427432
auto result = dyn_cast<OpTy>(op);
@@ -437,8 +442,8 @@ class OpBuilder : public Builder {
437442
Args &&...args) {
438443
// Create the operation without using 'createOperation' as we don't want to
439444
// insert it yet.
440-
OperationState state(location, OpTy::getOperationName());
441-
checkHasRegisteredInfo(state.name);
445+
OperationState state(location,
446+
getCheckRegisteredInfo<OpTy>(location.getContext()));
442447
OpTy::build(*this, state, std::forward<Args>(args)...);
443448
Operation *op = Operation::create(state);
444449

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,7 @@ class RegisteredOperationName : public OperationName {
231231
/// Lookup the registered operation information for the given operation.
232232
/// Returns None if the operation isn't registered.
233233
static Optional<RegisteredOperationName> lookup(StringRef name,
234-
MLIRContext *ctx) {
235-
return OperationName(name, ctx).getRegisteredInfo();
236-
}
234+
MLIRContext *ctx);
237235

238236
/// Register a new operation in a Dialect object.
239237
/// This constructor is used by Dialect objects when they register the list of
@@ -582,9 +580,12 @@ struct OperationState {
582580

583581
public:
584582
OperationState(Location location, StringRef name);
585-
586583
OperationState(Location location, OperationName name);
587584

585+
OperationState(Location location, OperationName name, ValueRange operands,
586+
TypeRange types, ArrayRef<NamedAttribute> attributes,
587+
BlockRange successors = {},
588+
MutableArrayRef<std::unique_ptr<Region>> regions = {});
588589
OperationState(Location location, StringRef name, ValueRange operands,
589590
TypeRange types, ArrayRef<NamedAttribute> attributes,
590591
BlockRange successors = {},

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,9 +1406,9 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
14061406
// name that works both in scalar mode and vector mode.
14071407
// TODO: Is it worth considering an Operation.clone operation which
14081408
// changes the type so we can promote an Operation with less boilerplate?
1409-
OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
1410-
vectorOperands, vectorTypes, op->getAttrs(),
1411-
/*successors=*/{}, /*regions=*/{});
1409+
OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
1410+
vectorTypes, op->getAttrs(), /*successors=*/{},
1411+
/*regions=*/{});
14121412
Operation *vecOp = state.builder.createOperation(vecOpState);
14131413
state.registerOpVectorReplacement(op, vecOp);
14141414
return vecOp;

mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,7 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
7070
Operation *op,
7171
ArrayRef<Value> operands,
7272
ArrayRef<Type> resultTypes) {
73-
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
74-
op->getAttrs());
73+
OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
7574
return builder.createOperation(res);
7675
}
7776

mlir/lib/IR/MLIRContext.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class MLIRContextImpl {
182182
llvm::StringMap<OperationName::Impl> operations;
183183

184184
/// A vector of operation info specifically for registered operations.
185-
SmallVector<RegisteredOperationName> registeredOperations;
185+
llvm::StringMap<RegisteredOperationName> registeredOperations;
186186

187187
/// A mutex used when accessing operation information.
188188
llvm::sys::SmartRWMutex<true> operationInfoMutex;
@@ -576,8 +576,9 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
576576
// We just have the operations in a non-deterministic hash table order. Dump
577577
// into a temporary array, then sort it by operation name to get a stable
578578
// ordering.
579-
std::vector<RegisteredOperationName> result(
580-
impl->registeredOperations.begin(), impl->registeredOperations.end());
579+
auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
580+
std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
581+
unwrappedNames.end());
581582
llvm::array_pod_sort(result.begin(), result.end(),
582583
[](const RegisteredOperationName *lhs,
583584
const RegisteredOperationName *rhs) {
@@ -589,7 +590,7 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
589590
}
590591

591592
bool MLIRContext::isOperationRegistered(StringRef name) {
592-
return OperationName(name, this).isRegistered();
593+
return RegisteredOperationName::lookup(name, this).hasValue();
593594
}
594595

595596
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
@@ -649,6 +650,15 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
649650
// Check for an existing name in read-only mode.
650651
bool isMultithreadingEnabled = context->isMultithreadingEnabled();
651652
if (isMultithreadingEnabled) {
653+
// Check the registered info map first. In the overwhelmingly common case,
654+
// the entry will be in here and it also removes the need to acquire any
655+
// locks.
656+
auto registeredIt = ctxImpl.registeredOperations.find(name);
657+
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
658+
impl = registeredIt->second.impl;
659+
return;
660+
}
661+
652662
llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
653663
auto it = ctxImpl.operations.find(name);
654664
if (it != ctxImpl.operations.end()) {
@@ -676,6 +686,15 @@ StringRef OperationName::getDialectNamespace() const {
676686
// RegisteredOperationName
677687
//===----------------------------------------------------------------------===//
678688

689+
Optional<RegisteredOperationName>
690+
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
691+
auto &impl = ctx->getImpl();
692+
auto it = impl.registeredOperations.find(name);
693+
if (it != impl.registeredOperations.end())
694+
return it->getValue();
695+
return llvm::None;
696+
}
697+
679698
ParseResult
680699
RegisteredOperationName::parseAssembly(OpAsmParser &parser,
681700
OperationState &result) const {
@@ -717,7 +736,8 @@ void RegisteredOperationName::insert(
717736
<< "' is already registered.\n";
718737
abort();
719738
}
720-
ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));
739+
ctxImpl.registeredOperations.try_emplace(name,
740+
RegisteredOperationName(&impl));
721741

722742
// Update the registered info for this operation.
723743
impl.dialect = &dialect;

mlir/lib/IR/OperationSupport.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,19 +170,26 @@ OperationState::OperationState(Location location, StringRef name)
170170
OperationState::OperationState(Location location, OperationName name)
171171
: location(location), name(name) {}
172172

173-
OperationState::OperationState(Location location, StringRef name,
173+
OperationState::OperationState(Location location, OperationName name,
174174
ValueRange operands, TypeRange types,
175175
ArrayRef<NamedAttribute> attributes,
176176
BlockRange successors,
177177
MutableArrayRef<std::unique_ptr<Region>> regions)
178-
: location(location), name(name, location->getContext()),
178+
: location(location), name(name),
179179
operands(operands.begin(), operands.end()),
180180
types(types.begin(), types.end()),
181181
attributes(attributes.begin(), attributes.end()),
182182
successors(successors.begin(), successors.end()) {
183183
for (std::unique_ptr<Region> &r : regions)
184184
this->regions.push_back(std::move(r));
185185
}
186+
OperationState::OperationState(Location location, StringRef name,
187+
ValueRange operands, TypeRange types,
188+
ArrayRef<NamedAttribute> attributes,
189+
BlockRange successors,
190+
MutableArrayRef<std::unique_ptr<Region>> regions)
191+
: OperationState(location, OperationName(name, location.getContext()),
192+
operands, types, attributes, successors, regions) {}
186193

187194
void OperationState::addOperands(ValueRange newOperands) {
188195
operands.append(newOperands.begin(), newOperands.end());

0 commit comments

Comments
 (0)