diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -98,7 +98,7 @@ /// Returns true if the given pass is hidden from IR printing. static bool isHiddenPass(Pass *pass) { - return isAdaptorPass(pass) || isa(pass); + return isa(pass) || isa(pass); } static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, @@ -172,7 +172,7 @@ } void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { - if (isAdaptorPass(pass)) + if (isa(pass)) return; if (config->shouldPrintAfterOnlyOnChange()) beforePassFingerPrints.erase(pass); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -51,7 +51,7 @@ /// an adaptor pass, print with the op_name(sub_pass,...) format. void Pass::printAsTextualPipeline(raw_ostream &os) { // Special case for adaptors to use the 'op_name(sub_passes)' format. - if (auto *adaptor = getAdaptorPassBase(this)) { + if (auto *adaptor = dyn_cast(this)) { llvm::interleaveComma(adaptor->getPassManagers(), os, [&](OpPassManager &pm) { os << pm.getOpName() << "("; @@ -152,15 +152,15 @@ void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() { // Bail out early if there are no adaptor passes. if (llvm::none_of(passes, [](std::unique_ptr &pass) { - return isAdaptorPass(pass.get()); + return isa(pass.get()); })) return; // Walk the pass list and merge adjacent adaptors. - OpToOpPassAdaptorBase *lastAdaptor = nullptr; + OpToOpPassAdaptor *lastAdaptor = nullptr; for (auto it = passes.begin(), e = passes.end(); it != e; ++it) { // Check to see if this pass is an adaptor. - if (auto *currentAdaptor = getAdaptorPassBase(it->get())) { + if (auto *currentAdaptor = dyn_cast(it->get())) { // If it is the first adaptor in a possible chain, remember it and // continue. if (!lastAdaptor) { @@ -243,16 +243,7 @@ /// pass manager.
OpPassManager &OpPassManager::nest(const OperationName &nestedName) { OpPassManager nested(nestedName, impl->disableThreads, impl->verifyPasses); - - /// Create an adaptor for this pass. If multi-threading is disabled, then - /// create a synchronous adaptor. - if (impl->disableThreads || !llvm::llvm_is_multithreaded()) { - auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); - addPass(std::unique_ptr(adaptor)); - return adaptor->getPassManagers().front(); - } - - auto *adaptor = new OpToOpPassAdaptorParallel(std::move(nested)); + auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); } @@ -330,12 +321,12 @@ return it == mgrs.end() ? nullptr : &*it; } -OpToOpPassAdaptorBase::OpToOpPassAdaptorBase(OpPassManager &&mgr) { +OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) { mgrs.emplace_back(std::move(mgr)); } /// Merge the current pass adaptor into given 'rhs'. -void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) { +void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) { for (auto &pm : mgrs) { // If an existing pass manager exists, then merge the given pass manager // into it. @@ -357,7 +348,7 @@ } /// Returns the adaptor pass name. -std::string OpToOpPassAdaptorBase::getName() { +std::string OpToOpPassAdaptor::getAdaptorName() { std::string name = "Pipeline Collection : ["; llvm::raw_string_ostream os(name); llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) { @@ -367,11 +358,16 @@ return os.str(); } -OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) - : OpToOpPassAdaptorBase(std::move(mgr)) {} - /// Run the held pipeline over all nested operations. void OpToOpPassAdaptor::runOnOperation() { + if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded()) + runOnOperationImpl(); + else + runOnOperationAsyncImpl(); +} + +/// Run this pass adaptor synchronously.
+void OpToOpPassAdaptor::runOnOperationImpl() { auto am = getAnalysisManager(); PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), this}; @@ -397,9 +393,6 @@ } } -OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(OpPassManager &&mgr) - : OpToOpPassAdaptorBase(std::move(mgr)) {} - /// Utility functor that checks if the two ranges of pass managers have a size /// mismatch. static bool hasSizeMismatch(ArrayRef lhs, @@ -409,8 +402,8 @@ [&](size_t i) { return lhs[i].size() != rhs[i].size(); }); } -// Run the held pipeline asynchronously across the functions within the module. -void OpToOpPassAdaptorParallel::runOnOperation() { +/// Run this pass adaptor asynchronously. +void OpToOpPassAdaptor::runOnOperationAsyncImpl() { AnalysisManager am = getAnalysisManager(); // Create the async executors if they haven't been created, or if the main @@ -491,16 +484,6 @@ signalPassFailure(); } -/// Utility function to convert the given class to the base adaptor it is an -/// adaptor pass, returns nullptr otherwise. -OpToOpPassAdaptorBase *mlir::detail::getAdaptorPassBase(Pass *pass) { - if (auto *adaptor = dyn_cast(pass)) - return adaptor; - if (auto *adaptor = dyn_cast(pass)) - return adaptor; - return nullptr; -} - //===----------------------------------------------------------------------===// // PassCrashReproducer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -27,69 +27,44 @@ // OpToOpPassAdaptor //===----------------------------------------------------------------------===// -/// A base class for Op-to-Op adaptor passes. -class OpToOpPassAdaptorBase { -public: - OpToOpPassAdaptorBase(OpPassManager &&mgr); - OpToOpPassAdaptorBase(const OpToOpPassAdaptorBase &rhs) = default; - - /// Merge the current pass adaptor into given 'rhs'.
- void mergeInto(OpToOpPassAdaptorBase &rhs); - - /// Returns the pass managers held by this adaptor. - MutableArrayRef getPassManagers() { return mgrs; } - - /// Returns the adaptor pass name. - std::string getName(); - -protected: - // A set of adaptors to run. - SmallVector mgrs; -}; - -/// An adaptor pass used to run operation passes over nested operations -/// synchronously on a single thread. +/// An adaptor pass used to run operation passes over nested operations. class OpToOpPassAdaptor - : public PassWrapper>, - public OpToOpPassAdaptorBase { + : public PassWrapper> { public: OpToOpPassAdaptor(OpPassManager &&mgr); + OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default; /// Run the held pipeline over all operations. void runOnOperation() override; -}; -/// An adaptor pass used to run operation passes over nested operations -/// asynchronously across multiple threads. -class OpToOpPassAdaptorParallel - : public PassWrapper>, - public OpToOpPassAdaptorBase { -public: - OpToOpPassAdaptorParallel(OpPassManager &&mgr); + /// Merge the current pass adaptor into given 'rhs'. + void mergeInto(OpToOpPassAdaptor &rhs); - /// Run the held pipeline over all operations. - void runOnOperation() override; + /// Returns the pass managers held by this adaptor. + MutableArrayRef getPassManagers() { return mgrs; } /// Return the async pass managers held by this parallel adaptor. MutableArrayRef> getParallelPassManagers() { return asyncExecutors; } + /// Returns the adaptor pass name. + std::string getAdaptorName(); + private: - // A set of executors, cloned from the main executor, that run asynchronously - // on different threads. - SmallVector, 8> asyncExecutors; -}; + /// Run this pass adaptor synchronously. + void runOnOperationImpl(); + + /// Run this pass adaptor asynchronously. + void runOnOperationAsyncImpl(); -/// Utility function to convert the given class to the base adaptor it is an -/// adaptor pass, returns nullptr otherwise.
-OpToOpPassAdaptorBase *getAdaptorPassBase(Pass *pass); + /// The nested pass managers run by this adaptor. + SmallVector mgrs; -/// Utility function to return if a pass refers to an adaptor pass. Adaptor -/// passes are those that internally execute a pipeline. -inline bool isAdaptorPass(Pass *pass) { - return isa(pass) || isa(pass); -} + /// A set of executors, cloned from the main executor, that run asynchronously + /// on different threads. This is used when threading is enabled. + SmallVector, 8> asyncExecutors; +}; } // end namespace detail } // end namespace mlir diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp --- a/mlir/lib/Pass/PassStatistics.cpp +++ b/mlir/lib/Pass/PassStatistics.cpp @@ -60,7 +60,7 @@ static void printResultsAsList(raw_ostream &os, OpPassManager &pm) { llvm::StringMap> mergedStats; std::function addStats = [&](Pass *pass) { - auto *adaptor = getAdaptorPassBase(pass); + auto *adaptor = dyn_cast(pass); // If this is not an adaptor, add the stats to the list if there are any. if (!adaptor) { @@ -105,13 +105,12 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) { std::function printPass = [&](unsigned indent, Pass *pass) { - // Handle the case of an adaptor pass. - if (auto *adaptor = getAdaptorPassBase(pass)) { + if (auto *adaptor = dyn_cast(pass)) { // If this adaptor has more than one internal pipeline, print an entry for // it. auto mgrs = adaptor->getPassManagers(); if (mgrs.size() > 1) { - printPassEntry(os, indent, adaptor->getName()); + printPassEntry(os, indent, adaptor->getAdaptorName()); indent += 2; } @@ -195,8 +194,8 @@ Pass &pass = std::get<0>(passPair), &otherPass = std::get<1>(passPair); // If this is an adaptor, then recursively merge the pass managers.
- if (auto *adaptorPass = getAdaptorPassBase(&pass)) { - auto *otherAdaptorPass = getAdaptorPassBase(&otherPass); + if (auto *adaptorPass = dyn_cast(&pass)) { + auto *otherAdaptorPass = cast(&otherPass); for (auto mgrs : llvm::zip(adaptorPass->getPassManagers(), otherAdaptorPass->getPassManagers())) std::get<0>(mgrs).mergeStatisticsInto(std::get<1>(mgrs)); @@ -217,18 +216,16 @@ /// consumption(e.g. dumping). static void prepareStatistics(OpPassManager &pm) { for (Pass &pass : pm.getPasses()) { - OpToOpPassAdaptorBase *adaptor = getAdaptorPassBase(&pass); + OpToOpPassAdaptor *adaptor = dyn_cast(&pass); if (!adaptor) continue; MutableArrayRef nestedPms = adaptor->getPassManagers(); - // If this is a parallel adaptor, merge the statistics from the async - // pass managers into the main nested pass managers. - if (auto *parallelAdaptor = dyn_cast(&pass)) { - for (auto &asyncPM : parallelAdaptor->getParallelPassManagers()) { - for (unsigned i = 0, e = asyncPM.size(); i != e; ++i) - asyncPM[i].mergeStatisticsInto(nestedPms[i]); - } + // Merge the statistics from the async pass managers into the main nested + // pass managers. + for (auto &asyncPM : adaptor->getParallelPassManagers()) { + for (unsigned i = 0, e = asyncPM.size(); i != e; ++i) + asyncPM[i].mergeStatisticsInto(nestedPms[i]); } // Prepare the statistics of each of the nested passes. diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -276,17 +276,17 @@ /// Start a new timer for the given pass. void PassTiming::startPassTimer(Pass *pass) { - auto kind = isAdaptorPass(pass) ? TimerKind::PipelineCollection - : TimerKind::PassOrAnalysis; + auto kind = isa(pass) ?
TimerKind::PipelineCollection + : TimerKind::PassOrAnalysis; Timer *timer = getTimer(pass, kind, [pass]() -> std::string { - if (auto *adaptor = getAdaptorPassBase(pass)) - return adaptor->getName(); + if (auto *adaptor = dyn_cast(pass)) + return adaptor->getAdaptorName(); return std::string(pass->getName()); }); // We don't actually want to time the adaptor passes, they gather their total // from their held passes. - if (!isAdaptorPass(pass)) + if (!isa(pass)) timer->start(); } @@ -301,9 +301,9 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) { Timer *timer = popLastActiveTimer(); - // If this is an OpToOpPassAdaptorParallel, then we need to merge in the - // timing data for the pipelines running on other threads. - if (isa(pass)) { + // If this is a pass adaptor, then we need to merge in the timing data for the + // pipelines running on other threads. + if (isa(pass)) { auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass}); if (toMerge != pipelinesToMerge.end()) { for (auto &it : toMerge->second) @@ -313,10 +313,7 @@ return; } - // Adaptor passes aren't timed directly, so we don't need to stop their - // timers. - if (!isAdaptorPass(pass)) - timer->stop(); + timer->stop(); } /// Stop a timer.