diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -166,6 +166,12 @@
   Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
                             function_ref<std::unique_ptr<Dialect>()> ctor);
 
+  /// Returns a hash of the context's registry that may be used as a rough
+  /// indicator of whether the registry state has changed. The hash is derived
+  /// from the number of loaded dialects and registered attributes, operations,
+  /// and types, so it only reflects changes to those counts.
+  llvm::hash_code getRegistryHash();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -375,6 +375,14 @@
   /// An optional factory to use when generating a crash reproducer if valid.
   ReproducerStreamFactory crashReproducerStreamFactory;
 
+  /// A hash key used to detect when reinitialization is necessary. It holds
+  /// the MLIRContext::getRegistryHash() value seen by the last initialization
+  /// performed in `run`, and starts out as the hash_code tombstone key.
+  /// NOTE(review): the pass pipeline itself is not part of this key, so passes
+  /// added after a previous `run` are not initialized by `run` unless the
+  /// registry hash also changes. Confirm this is intended.
+  llvm::hash_code initializationKey;
+
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -492,6 +492,16 @@
   return dialect.get();
 }
 
+llvm::hash_code MLIRContext::getRegistryHash() {
+  llvm::hash_code hash(0);
+  // Factor in number of loaded dialects, attributes, operations, types.
+  hash = llvm::hash_combine(hash, impl->loadedDialects.size());
+  hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
+  hash = llvm::hash_combine(hash, impl->registeredOperations.size());
+  hash = llvm::hash_combine(hash, impl->registeredTypes.size());
+  return hash;
+}
+
 bool MLIRContext::allowsUnregisteredDialects() {
   return impl->allowUnregisteredDialects;
 }
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -846,6 +846,7 @@
 PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
                          StringRef operationName)
     : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
+      initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
       passTiming(false), localReproducer(false), verifyPasses(true) {}
 
 PassManager::~PassManager() {}
@@ -868,7 +869,13 @@
   dependentDialects.loadAll(context);
 
   // Initialize all of the passes within the pass manager with a new generation.
-  initialize(context, impl->initializationGeneration + 1);
+  // Reinitialization is skipped when the context registry hash is unchanged
+  // since the last initialization performed here.
+  llvm::hash_code newInitKey = context->getRegistryHash();
+  if (newInitKey != initializationKey) {
+    initialize(context, impl->initializationGeneration + 1);
+    initializationKey = newInitKey;
+  }
 
   // Construct a top level analysis manager for the pipeline.
   ModuleAnalysisManager am(op, instrumentor.get());