Skip to content

Commit c5f0c32

Browse files
joker-ephtru
authored andcommitted
Fix MLIR pass manager initialization: hash the pass pipeline to detect when initialization is needed
The current logic hashes the context to detect registration changes and re-run the pass initialization. However it wasn't checking for changes to the pipeline, so a pass that would get added after a first run would not be initialized during subsequent runs. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D158377
1 parent 0d8fd07 commit c5f0c32

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

mlir/include/mlir/Pass/PassManager.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ class OpPassManager {
172172
/// if a pass manager has already been initialized.
173173
LogicalResult initialize(MLIRContext *context, unsigned newInitGeneration);
174174

175+
/// Compute a hash of the pipeline, so that we can detect changes (a pass is
176+
/// added...).
177+
llvm::hash_code hash();
178+
175179
/// A pointer to an internal implementation instance.
176180
std::unique_ptr<detail::OpPassManagerImpl> impl;
177181

@@ -439,9 +443,11 @@ class PassManager : public OpPassManager {
439443
/// generate reproducers.
440444
std::unique_ptr<detail::PassCrashReproducerGenerator> crashReproGenerator;
441445

442-
/// A hash key used to detect when reinitialization is necessary.
446+
/// Hash keys used to detect when reinitialization is necessary.
443447
llvm::hash_code initializationKey =
444448
DenseMapInfo<llvm::hash_code>::getTombstoneKey();
449+
llvm::hash_code pipelineInitializationKey =
450+
DenseMapInfo<llvm::hash_code>::getTombstoneKey();
445451

446452
/// Flag that specifies if pass timing is enabled.
447453
bool passTiming : 1;

mlir/lib/Pass/Pass.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/Threading.h"
1919
#include "mlir/IR/Verifier.h"
2020
#include "mlir/Support/FileUtilities.h"
21+
#include "llvm/ADT/Hashing.h"
2122
#include "llvm/ADT/STLExtras.h"
2223
#include "llvm/ADT/ScopeExit.h"
2324
#include "llvm/Support/CommandLine.h"
@@ -424,6 +425,23 @@ LogicalResult OpPassManager::initialize(MLIRContext *context,
424425
return success();
425426
}
426427

428+
llvm::hash_code OpPassManager::hash() {
429+
llvm::hash_code hashCode;
430+
for (Pass &pass : getPasses()) {
431+
// If this pass isn't an adaptor, directly hash it.
432+
auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
433+
if (!adaptor) {
434+
hashCode = llvm::hash_combine(hashCode, &pass);
435+
continue;
436+
}
437+
// Otherwise, hash recursively each of the adaptors pass managers.
438+
for (OpPassManager &adaptorPM : adaptor->getPassManagers())
439+
llvm::hash_combine(hashCode, adaptorPM.hash());
440+
}
441+
return hashCode;
442+
}
443+
444+
427445
//===----------------------------------------------------------------------===//
428446
// OpToOpPassAdaptor
429447
//===----------------------------------------------------------------------===//
@@ -825,10 +843,12 @@ LogicalResult PassManager::run(Operation *op) {
825843

826844
// Initialize all of the passes within the pass manager with a new generation.
827845
llvm::hash_code newInitKey = context->getRegistryHash();
828-
if (newInitKey != initializationKey) {
846+
llvm::hash_code pipelineKey = hash();
847+
if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) {
829848
if (failed(initialize(context, impl->initializationGeneration + 1)))
830849
return failure();
831850
initializationKey = newInitKey;
851+
pipelineKey = pipelineInitializationKey;
832852
}
833853

834854
// Construct a top level analysis manager for the pipeline.

mlir/unittests/Pass/PassManagerTest.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Dialect/Func/IR/FuncOps.h"
1111
#include "mlir/IR/Builders.h"
1212
#include "mlir/IR/BuiltinOps.h"
13+
#include "mlir/IR/Diagnostics.h"
1314
#include "mlir/Pass/Pass.h"
1415
#include "gtest/gtest.h"
1516

@@ -144,4 +145,39 @@ TEST(PassManagerTest, InvalidPass) {
144145
"intend to nest?");
145146
}
146147

148+
/// Simple pass to annotate a func::FuncOp with the results of analysis.
149+
struct InitializeCheckingPass
150+
: public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
151+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
152+
LogicalResult initialize(MLIRContext *ctx) final {
153+
initialized = true;
154+
return success();
155+
}
156+
bool initialized = false;
157+
158+
void runOnOperation() override {
159+
if (!initialized) {
160+
getOperation()->emitError() << "Pass isn't initialized!";
161+
signalPassFailure();
162+
}
163+
}
164+
};
165+
166+
TEST(PassManagerTest, PassInitialization) {
167+
MLIRContext context;
168+
context.allowUnregisteredDialects();
169+
170+
// Create a module
171+
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
172+
173+
// Instantiate and run our pass.
174+
auto pm = PassManager::on<ModuleOp>(&context);
175+
pm.addPass(std::make_unique<InitializeCheckingPass>());
176+
EXPECT_TRUE(succeeded(pm.run(module.get())));
177+
178+
// Adding a second copy of the pass, we should also initialize it!
179+
pm.addPass(std::make_unique<InitializeCheckingPass>());
180+
EXPECT_TRUE(succeeded(pm.run(module.get())));
181+
}
182+
147183
} // namespace

0 commit comments

Comments
 (0)