diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -131,6 +131,23 @@ already ne loaded must express this by overriding the `getDependentDialects()` method and declare this list of Dialects explicitly. +### Initialization + +In certain situations, a Pass may contain state that is constructed dynamically, +but is potentially expensive to recompute in successive runs of the Pass. One +such example is when using [`PDL`-based](Dialects/PDLOps.md) +[patterns](PatternRewriter.md), which are compiled into a bytecode during +runtime. In these situations, a pass may override the following hook to +initialize this heavy state: + +* `void initialize(MLIRContext *context)` + +This hook is executed once per run of a full pass pipeline, meaning that it does +not have access to the state available during a `runOnOperation` call. More +concretely, all necessary accesses to an `MLIRContext` should be driven via the +provided `context` parameter, and methods that utilize "per-run" state such as +`getContext`/`getOperation`/`getAnalysis`/etc. must not be used. + ## Analysis Management An important concept, along with transformation passes, are analyses. These are diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -166,6 +166,12 @@ /// The polymorphic API that runs the pass over the currently held operation. virtual void runOnOperation() = 0; + /// Initialize any complex state necessary for running this pass. This hook + /// should not rely on any state accessible during the execution of a pass. + /// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be + /// invoked within this hook. + virtual void initialize(MLIRContext *context) {} + /// Schedule an arbitrary pass pipeline on the provided operation. /// This can be invoke any time in a pass to dynamic schedule more passes.
/// The provided operation must be the current one or one nested below. diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -35,6 +35,7 @@ namespace detail { struct OpPassManagerImpl; +class OpToOpPassAdaptor; struct PassExecutionState; } // end namespace detail @@ -126,9 +127,17 @@ Nesting getNesting(); private: + /// Initialize all of the passes within this pass manager with the given + /// initialization generation. The initialization generation is used to detect + /// if a pass manager has already been initialized. + void initialize(MLIRContext *context, unsigned newInitGeneration); + /// A pointer to an internal implementation instance. std::unique_ptr<detail::OpPassManagerImpl> impl; + /// Allow access to initialize. + friend detail::OpToOpPassAdaptor; + /// Allow access to the constructor. friend class PassManager; friend class Pass; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -81,9 +81,10 @@ namespace detail { struct OpPassManagerImpl { OpPassManagerImpl(Identifier identifier, OpPassManager::Nesting nesting) - : name(identifier.str()), identifier(identifier), nesting(nesting) {} + : name(identifier.str()), identifier(identifier), + initializationGeneration(0), nesting(nesting) {} OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting) - : name(name), nesting(nesting) {} + : name(name), initializationGeneration(0), nesting(nesting) {} /// Merge the passes of this pass manager into the one provided. void mergeInto(OpPassManagerImpl &rhs); @@ -105,6 +106,7 @@ /// pass. void splitAdaptorPasses(); + /// Return the operation name of this pass manager as an identifier. Identifier getOpName(MLIRContext &context) { if (!identifier) identifier = Identifier::get(name, &context); @@ -121,6 +123,10 @@ /// The set of passes to run as part of this pass manager.
std::vector<std::unique_ptr<Pass>> passes; + /// The current initialization generation of this pass manager. This is used + /// to indicate when a pass manager should be reinitialized. + unsigned initializationGeneration; + /// Control the implicit nesting of passes that mismatch the name set for this /// OpPassManager. OpPassManager::Nesting nesting; @@ -320,16 +326,36 @@ registerDialectsForPipeline(*this, dialects); } +void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; } + OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; } -void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; } +void OpPassManager::initialize(MLIRContext *context, + unsigned newInitGeneration) { + if (impl->initializationGeneration == newInitGeneration) + return; + impl->initializationGeneration = newInitGeneration; + for (Pass &pass : getPasses()) { + // If this pass isn't an adaptor, directly initialize it. + auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass); + if (!adaptor) { + pass.initialize(context); + continue; + } + + // Otherwise, initialize each of the adaptors pass managers. + for (OpPassManager &adaptorPM : adaptor->getPassManagers()) + adaptorPM.initialize(context, newInitGeneration); + } +} //===----------------------------------------------------------------------===// // OpToOpPassAdaptor //===----------------------------------------------------------------------===// LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, - AnalysisManager am, bool verifyPasses) { + AnalysisManager am, bool verifyPasses, + unsigned parentInitGeneration) { if (!op->getName().getAbstractOperation()) return op->emitOpError() << "trying to schedule a pass on an unregistered operation"; @@ -352,9 +378,12 @@ "nested under the current operation the pass is processing"; assert(pipeline.getOpName() == root->getName().getStringRef()); + // Initialize the user provided pipeline and execute the pipeline.
+ pipeline.initialize(root->getContext(), parentInitGeneration); AnalysisManager nestedAm = root == op ? am : am.nest(root); return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm, - verifyPasses, pi, &parentInfo); + verifyPasses, parentInitGeneration, + pi, &parentInfo); }; pass->passState.emplace(op, am, dynamic_pipeline_callback); @@ -391,7 +420,8 @@ /// Run the given operation and analysis manager on a provided op pass manager. LogicalResult OpToOpPassAdaptor::runPipeline( iterator_range<OpPassManager::pass_iterator> passes, Operation *op, - AnalysisManager am, bool verifyPasses, PassInstrumentor *instrumentor, + AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration, + PassInstrumentor *instrumentor, const PassInstrumentation::PipelineParentInfo *parentInfo) { assert((!instrumentor || parentInfo) && "expected parent info if instrumentor is provided"); @@ -407,7 +437,7 @@ if (instrumentor) instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo); for (Pass &pass : passes) - if (failed(run(&pass, op, am, verifyPasses))) + if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration))) return failure(); if (instrumentor) instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo); @@ -502,8 +532,10 @@ continue; // Run the held pipeline over the current operation.
+ unsigned initGeneration = mgr->impl->initializationGeneration; if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op), - verifyPasses, instrumentor, &parentInfo))) + verifyPasses, initGeneration, instrumentor, + &parentInfo))) return signalPassFailure(); } } @@ -578,9 +610,10 @@ pms, it.first->getName().getIdentifier(), getContext()); assert(pm && "expected valid pass manager for operation"); + unsigned initGeneration = pm->impl->initializationGeneration; LogicalResult pipelineResult = runPipeline(pm->getPasses(), it.first, it.second, verifyPasses, - instrumentor, &parentInfo); + initGeneration, instrumentor, &parentInfo); // Drop this thread from being tracked by the diagnostic handler. // After this task has finished, the thread may be used outside of @@ -753,7 +786,8 @@ llvm::CrashRecoveryContext recoveryContext; recoveryContext.RunSafelyOnThread([&] { for (std::unique_ptr<Pass> &pass : passes) - if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses))) + if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses, + impl->initializationGeneration))) return; passManagerResult = success(); }); @@ -801,6 +835,9 @@ getDependentDialects(dependentDialects); dependentDialects.loadAll(context); + // Initialize all of the passes within the pass manager with a new generation. + initialize(context, impl->initializationGeneration + 1); + // Construct a top level analysis manager for the pipeline. ModuleAnalysisManager am(op, instrumentor.get()); @@ -812,7 +849,8 @@ LogicalResult result = crashReproducerFileName ? runWithCrashRecovery(op, am) - : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses); + : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses, + impl->initializationGeneration); // Notify the context that the run is done.
context->exitMultiThreadedExecution(); diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -55,14 +55,19 @@ void runOnOperationAsyncImpl(bool verifyPasses); /// Run the given operation and analysis manager on a single pass. + /// `parentInitGeneration` is the initialization generation of the parent pass + /// manager, and is used to initialize any dynamic pass pipelines run by the + /// given pass. static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am, - bool verifyPasses); + bool verifyPasses, unsigned parentInitGeneration); /// Run the given operation and analysis manager on a provided op pass - /// manager. + /// manager. `parentInitGeneration` is the initialization generation of the + /// parent pass manager, and is used to initialize any dynamic pass pipelines + /// run by the given passes. static LogicalResult runPipeline( iterator_range<OpPassManager::pass_iterator> passes, Operation *op, - AnalysisManager am, bool verifyPasses, + AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration, PassInstrumentor *instrumentor = nullptr, const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -21,19 +21,19 @@ namespace { /// Canonicalize operations in nested regions. struct Canonicalizer : public CanonicalizerBase<Canonicalizer> { - void runOnOperation() override { - OwningRewritePatternList patterns; - - // TODO: Instead of adding all known patterns from the whole system lazily - // add and cache the canonicalization patterns for ops we see in practice - // when building the worklist. For now, we just grab everything. - auto *context = &getContext(); + /// Initialize the canonicalizer by building the set of patterns used during + /// execution.
+ void initialize(MLIRContext *context) override { + OwningRewritePatternList owningPatterns; for (auto *op : context->getRegisteredOperations()) - op->getCanonicalizationPatterns(patterns, context); - - Operation *op = getOperation(); - applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); + op->getCanonicalizationPatterns(owningPatterns, context); + patterns = std::move(owningPatterns); } + void runOnOperation() override { + applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); + } + + FrozenRewritePatternList patterns; }; } // end anonymous namespace