diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -19,6 +19,8 @@ namespace mlir { namespace detail { +class OpToOpPassAdaptor; + /// The state for a single execution of a pass. This provides a unified /// interface for accessing and initializing necessary state for pass execution. struct PassExecutionState { @@ -249,9 +251,6 @@ void copyOptionValuesFrom(const Pass *other); private: - /// Forwarding function to execute this pass on the given operation. - LLVM_NODISCARD - LogicalResult run(Operation *op, AnalysisManager am); /// Out of line virtual method to ensure vtables and metadata are emitted to a /// single .o file. @@ -273,11 +272,11 @@ /// The pass options registered to this pass instance. detail::PassOptions passOptions; - /// Allow access to 'clone' and 'run'. + /// Allow access to 'clone'. friend class OpPassManager; - /// Allow access to 'run'. - friend class PassManager; + /// Allow access to 'passState'. + friend detail::OpToOpPassAdaptor; /// Allow access to 'passOptions'. friend class PassInfo; diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -47,6 +47,7 @@ /// other OpPassManagers or the top-level PassManager. class OpPassManager { public: + OpPassManager(OperationName name, bool verifyPasses); OpPassManager(OpPassManager &&rhs); OpPassManager(const OpPassManager &rhs); ~OpPassManager(); @@ -54,22 +55,19 @@ /// Iterator over the passes in this pass manager. using pass_iterator = - llvm::pointee_iterator>::iterator>; + llvm::pointee_iterator>::iterator>; pass_iterator begin(); pass_iterator end(); iterator_range getPasses() { return {begin(), end()}; } - using const_pass_iterator = llvm::pointee_iterator< - std::vector>::const_iterator>; + using const_pass_iterator = + llvm::pointee_iterator>::const_iterator>; const_pass_iterator begin() const; const_pass_iterator end() const; iterator_range getPasses() const { return {begin(), end()}; } - /// Run the held passes over the given operation. - LogicalResult run(Operation *op, AnalysisManager am); - /// Nest a new operation pass manager for the given operation kind under this /// pass manager. OpPassManager &nest(const OperationName &nestedName); @@ -115,8 +113,6 @@ void getDependentDialects(DialectRegistry &dialects) const; private: - OpPassManager(OperationName name, bool verifyPasses); - /// A pointer to an internal implementation instance. std::unique_ptr impl; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Verifier.h" #include "mlir/Support/FileUtilities.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/CrashRecoveryContext.h" @@ -73,33 +74,6 @@ passOptions.print(os); } -/// Forwarding function to execute this pass. -LogicalResult Pass::run(Operation *op, AnalysisManager am) { - passState.emplace(op, am); - - // Instrument before the pass has run. - auto pi = am.getPassInstrumentor(); - if (pi) - pi->runBeforePass(this, op); - - // Invoke the virtual runOnOperation method. - runOnOperation(); - - // Invalidate any non preserved analyses. - am.invalidate(passState->preservedAnalyses); - - // Instrument after the pass has run. - bool passFailed = passState->irAndPassFailed.getInt(); - if (pi) { - if (passFailed) - pi->runAfterPassFailed(this, op); - else - pi->runAfterPass(this, op); - } - - // Return if the pass signaled a failure. - return failure(passFailed); -} //===----------------------------------------------------------------------===// // Verifier Passes @@ -286,24 +260,17 @@ OpPassManager::~OpPassManager() {} OpPassManager::pass_iterator OpPassManager::begin() { - return impl->passes.begin(); + return MutableArrayRef>{impl->passes}.begin(); +} +OpPassManager::pass_iterator OpPassManager::end() { + return MutableArrayRef>{impl->passes}.end(); } -OpPassManager::pass_iterator OpPassManager::end() { return impl->passes.end(); } OpPassManager::const_pass_iterator OpPassManager::begin() const { - return impl->passes.begin(); + return ArrayRef>{impl->passes}.begin(); } OpPassManager::const_pass_iterator OpPassManager::end() const { - return impl->passes.end(); -} - -/// Run all of the passes in this manager over the current operation. -LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) { - // Run each of the held passes. - for (auto &pass : impl->passes) - if (failed(pass->run(op, am))) - return failure(); - return success(); + return ArrayRef>{impl->passes}.end(); } /// Nest a new operation pass manager for the given operation kind under this @@ -367,19 +334,52 @@ // OpToOpPassAdaptor //===----------------------------------------------------------------------===// -/// Utility to run the given operation and analysis manager on a provided op -/// pass manager. -static LogicalResult runPipeline(OpPassManager &pm, Operation *op, - AnalysisManager am) { +LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, + AnalysisManager am) { + pass->passState.emplace(op, am); + + // Instrument before the pass has run. + PassInstrumentor *pi = am.getPassInstrumentor(); + if (pi) + pi->runBeforePass(pass, op); + + // Invoke the virtual runOnOperation method. + pass->runOnOperation(); + + // Invalidate any non preserved analyses. + am.invalidate(pass->passState->preservedAnalyses); + + // Instrument after the pass has run. + bool passFailed = pass->passState->irAndPassFailed.getInt(); + if (pi) { + if (passFailed) + pi->runAfterPassFailed(pass, op); + else + pi->runAfterPass(pass, op); + } + + // Return if the pass signaled a failure. + return failure(passFailed); +} + +/// Run the given operation and analysis manager on a provided op pass manager. +LogicalResult OpToOpPassAdaptor::runPipeline( + iterator_range passes, Operation *op, + AnalysisManager am) { + auto scope_exit = llvm::make_scope_exit([&] { + // Clear out any computed operation analyses. These analyses won't be used + // any more in this pipeline, and this helps reduce the current working set + // of memory. If preserving these analyses becomes important in the future + // we can re-evaluate this. + am.clear(); + }); + // Run the pipeline over the provided operation. - auto result = pm.run(op, am); + for (Pass &pass : passes) + if (failed(run(&pass, op, am))) + return failure(); - // Clear out any computed operation analyses. These analyses won't be used - // any more in this pipeline, and this helps reduce the current working set - // of memory. If preserving these analyses becomes important in the future - // we can re-evaluate this. - am.clear(); - return result; + return success(); } /// Find an operation pass manager that can operate on an operation of the given @@ -457,7 +457,7 @@ // Run the held pipeline over the current operation. if (instrumentor) instrumentor->runBeforePipeline(mgr->getOpName(), parentInfo); - auto result = runPipeline(*mgr, &op, am.slice(&op)); + auto result = runPipeline(mgr->getPasses(), &op, am.slice(&op)); if (instrumentor) instrumentor->runAfterPipeline(mgr->getOpName(), parentInfo); @@ -536,7 +536,8 @@ if (instrumentor) instrumentor->runBeforePipeline(pm->getOpName(), parentInfo); - auto pipelineResult = runPipeline(*pm, it.first, it.second); + auto pipelineResult = + runPipeline(pm->getPasses(), it.first, it.second); if (instrumentor) instrumentor->runAfterPipeline(pm->getOpName(), parentInfo); @@ -709,7 +710,7 @@ llvm::CrashRecoveryContext recoveryContext; recoveryContext.RunSafelyOnThread([&] { for (std::unique_ptr &pass : passes) - if (failed(pass->run(module, am))) + if (failed(OpToOpPassAdaptor::run(pass.get(), module, am))) return; passManagerResult = success(); }); @@ -753,9 +754,10 @@ // If reproducer generation is enabled, run the pass manager with crash // handling enabled. - LogicalResult result = crashReproducerFileName - ? runWithCrashRecovery(module, am) - : OpPassManager::run(module, am); + LogicalResult result = + crashReproducerFileName + ? runWithCrashRecovery(module, am) + : OpToOpPassAdaptor::runPipeline(getPasses(), module, am); // Dump all of the pass statistics if necessary. if (passStatisticsMode) diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -62,12 +62,24 @@ /// Run this pass adaptor asynchronously. void runOnOperationAsyncImpl(); + /// Run the given operation and analysis manager on a single pass. + static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am); + + /// Run the given operation and analysis manager on a provided op pass + /// manager. + static LogicalResult + runPipeline(iterator_range passes, + Operation *op, AnalysisManager am); + /// A set of adaptors to run. SmallVector mgrs; /// A set of executors, cloned from the main executor, that run asynchronously /// on different threads. This is used when threading is enabled. SmallVector, 8> asyncExecutors; + + // For accessing "runPipeline". + friend class mlir::PassManager; }; } // end namespace detail