diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -53,7 +53,7 @@
   virtual ~Pass() = default;

   /// Returns the unique identifier that corresponds to this pass.
-  TypeID getTypeID() const { return passID; }
+  TypeID getTypeID() const { return passIDAndIsInitialized.getPointer(); }

   /// Returns the pass info for the specified pass class or null if unknown.
   static const PassInfo *lookupPassInfo(TypeID passID);
@@ -148,8 +148,10 @@

 protected:
   explicit Pass(TypeID passID, Optional<StringRef> opName = llvm::None)
-      : passID(passID), opName(opName) {}
-  Pass(const Pass &other) : Pass(other.passID, other.opName) {}
+      : passIDAndIsInitialized(passID), opName(opName) {}
+  Pass(const Pass &other)
+      : passIDAndIsInitialized(other.passIDAndIsInitialized),
+        opName(other.opName) {}

   /// Returns the current pass state.
   detail::PassExecutionState &getPassState() {
@@ -163,6 +165,14 @@
   /// The polymorphic API that runs the pass over the currently held operation.
   virtual void runOnOperation() = 0;

+  /// Initialize any complex state necessary for running this pass. This hook
+  /// should not rely on any state accessible during the execution of a pass.
+  /// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be
+  /// invoked within this hook. The hook is invoked at most once per pass
+  /// instance and is not re-run on later executions of the pass manager, so
+  /// any state built here must remain valid across those executions.
+  virtual void initialize(MLIRContext *context) {}
+
   /// Schedule an arbitrary pass pipeline on the provided operation.
   /// This can be invoke any time in a pass to dynamic schedule more passes.
   /// The provided operation must be the current one or one nested below.
@@ -265,13 +275,21 @@
   void copyOptionValuesFrom(const Pass *other);

 private:
+  /// Initialize the state of this pass if it has not already been initialized.
+  void initializeIfNecessary(MLIRContext *context) {
+    if (passIDAndIsInitialized.getInt())
+      return;
+    initialize(context);
+    passIDAndIsInitialized.setInt(true);
+  }

   /// Out of line virtual method to ensure vtables and metadata are emitted to a
   /// single .o file.
   virtual void anchor();

-  /// Represents a unique identifier for the pass.
-  TypeID passID;
+  /// A unique identifier for the type of this pass.
+  /// A boolean indicating if this pass has been initialized yet.
+  llvm::PointerIntPair<TypeID, 1, bool> passIDAndIsInitialized;

   /// The name of the operation that this pass operates on, or None if this is a
   /// generic OperationPass.
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -125,6 +125,10 @@
   /// Return the current nesting mode.
   Nesting getNesting();

+  /// Initialize all of the passes within this pass manager. This should not be
+  /// called directly.
+  void initialize(MLIRContext *context);
+
 private:
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -320,9 +320,14 @@
   registerDialectsForPipeline(*this, dialects);
 }

+void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
+
 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }

-void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
+void OpPassManager::initialize(MLIRContext *context) {
+  for (Pass &pass : getPasses())
+    pass.initializeIfNecessary(context);
+}

 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
@@ -349,6 +354,8 @@
                 "operation that isn't "
                 "nested under the current operation the pass is processing";

+    // Initialize the user provided pipeline and execute the pipeline.
+    pipeline.initialize(root->getContext());
     AnalysisManager nestedAm = am.nest(root);
     return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
                                           verifyPasses);
@@ -433,6 +440,15 @@
     pm.getDependentDialects(dialects);
 }

+/// Initialize the nested passes within this adaptor.
+void OpToOpPassAdaptor::initialize(MLIRContext *context) {
+  for (OpPassManager &pm : mgrs)
+    pm.initialize(context);
+  for (SmallVectorImpl<OpPassManager> &asyncPM : asyncExecutors)
+    for (OpPassManager &pm : asyncPM)
+      pm.initialize(context);
+}
+
 /// Merge the current pass adaptor into given 'rhs'.
 void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
   for (auto &pm : mgrs) {
@@ -799,6 +815,9 @@
   getDependentDialects(dependentDialects);
   dependentDialects.loadAll(context);

+  // Initialize all of the passes within the pass manager.
+  initialize(context);
+
   // Construct a top level analysis manager for the pipeline.
   ModuleAnalysisManager am(op, instrumentor.get());

diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -39,6 +39,9 @@
   /// adaptor.
   void getDependentDialects(DialectRegistry &dialects) const override;

+  /// Initialize the nested passes within this adaptor.
+  void initialize(MLIRContext *context) override;
+
   /// Return the async pass managers held by this parallel adaptor.
   MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
     return asyncExecutors;
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -21,19 +21,19 @@
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
-  void runOnOperation() override {
-    OwningRewritePatternList patterns;
-
-    // TODO: Instead of adding all known patterns from the whole system lazily
-    // add and cache the canonicalization patterns for ops we see in practice
-    // when building the worklist.  For now, we just grab everything.
-    auto *context = &getContext();
+  /// Initialize the canonicalizer by building the set of patterns used during
+  /// execution.
+  void initialize(MLIRContext *context) override {
+    OwningRewritePatternList owningPatterns;
     for (auto *op : context->getRegisteredOperations())
-      op->getCanonicalizationPatterns(patterns, context);
-
-    Operation *op = getOperation();
-    applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+      op->getCanonicalizationPatterns(owningPatterns, context);
+    patterns = std::move(owningPatterns);
   }
+  void runOnOperation() override {
+    applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);
+  }
+
+  FrozenRewritePatternList patterns;
 };
 } // end anonymous namespace
