diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -172,6 +172,10 @@
   /// if a pass manager has already been initialized.
   LogicalResult initialize(MLIRContext *context, unsigned newInitGeneration);
 
+  /// Compute a hash of the pipeline, recursing into nested pass managers, so
+  /// that structural changes (e.g. a pass being added) can be detected.
+  llvm::hash_code hash();
+
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;
 
@@ -439,9 +443,12 @@
   /// generate reproducers.
   std::unique_ptr<PassCrashReproducerGenerator> crashReproGenerator;
 
-  /// A hash key used to detect when reinitialization is necessary.
+  /// Hash keys used to detect when reinitialization is necessary: the context
+  /// registry hash and the pipeline hash recorded at the last initialization.
   llvm::hash_code initializationKey =
       DenseMapInfo<llvm::hash_code>::getTombstoneKey();
+  llvm::hash_code pipelineInitializationKey =
+      DenseMapInfo<llvm::hash_code>::getTombstoneKey();
 
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Threading.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
@@ -424,6 +425,27 @@
   return success();
 }
 
+/// Hash the structure of this pipeline. Non-adaptor passes contribute their
+/// address; adaptors contribute the hashes of their nested pass managers.
+llvm::hash_code OpPassManager::hash() {
+  // hash_code's default constructor leaves the value uninitialized, so
+  // value-initialize it to get a deterministic starting seed.
+  llvm::hash_code hashCode{};
+  for (Pass &pass : getPasses()) {
+    // If this pass isn't an adaptor, directly hash it.
+    auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
+    if (!adaptor) {
+      hashCode = llvm::hash_combine(hashCode, &pass);
+      continue;
+    }
+    // Otherwise, recursively hash each of the adaptor's pass managers. The
+    // combined value must be stored back, or nested changes are lost.
+    for (OpPassManager &adaptorPM : adaptor->getPassManagers())
+      hashCode = llvm::hash_combine(hashCode, adaptorPM.hash());
+  }
+  return hashCode;
+}
+
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
@@ -825,10 +847,14 @@
 
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();
-  if (newInitKey != initializationKey) {
+  // Also reinitialize when the pipeline itself changed (e.g. a pass was added).
+  llvm::hash_code pipelineKey = hash();
+  if (newInitKey != initializationKey ||
+      pipelineKey != pipelineInitializationKey) {
     if (failed(initialize(context, impl->initializationGeneration + 1)))
       return failure();
     initializationKey = newInitKey;
+    pipelineInitializationKey = pipelineKey;
   }
 
   // Construct a top level analysis manager for the pipeline.
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
 #include "gtest/gtest.h"
 
@@ -144,4 +145,39 @@
                "intend to nest?");
 }
 
+/// Test pass that emits an error and fails if it runs before being initialized.
+struct InitializeCheckingPass
+    : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
+  LogicalResult initialize(MLIRContext *ctx) final {
+    initialized = true;
+    return success();
+  }
+  bool initialized = false; // Set only by initialize().
+
+  void runOnOperation() override {
+    if (!initialized) {
+      getOperation()->emitError() << "Pass isn't initialized!";
+      signalPassFailure();
+    }
+  }
+};
+
+TEST(PassManagerTest, PassInitialization) {
+  MLIRContext context;
+  context.allowUnregisteredDialects();
+
+  // Create an empty module to run the pipeline on.
+  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+
+  // Instantiate and run our pass; the first run must initialize it.
+  auto pm = PassManager::on<ModuleOp>(&context);
+  pm.addPass(std::make_unique<InitializeCheckingPass>());
+  EXPECT_TRUE(succeeded(pm.run(module.get())));
+
+  // The pipeline changed after the first run, so the new pass must be too.
+  pm.addPass(std::make_unique<InitializeCheckingPass>());
+  EXPECT_TRUE(succeeded(pm.run(module.get())));
+}
+
 } // namespace