diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h --- a/llvm/include/llvm/ADT/Sequence.h +++ b/llvm/include/llvm/ADT/Sequence.h @@ -42,6 +42,10 @@ value_sequence_iterator(const value_sequence_iterator &) = default; value_sequence_iterator(value_sequence_iterator &&Arg) : Value(std::move(Arg.Value)) {} + value_sequence_iterator &operator=(const value_sequence_iterator &Arg) { + Value = Arg.Value; + return *this; + } template ()))> value_sequence_iterator(U &&Value) : Value(std::forward(Value)) {} diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -98,7 +98,7 @@ /// A derived analysis model used to hold a specific analysis object. template struct AnalysisModel : public AnalysisConcept { template - explicit AnalysisModel(Args &&... args) + explicit AnalysisModel(Args &&...args) : analysis(std::forward(args)...) {} /// A hook used to query analyses for invalidation. @@ -198,7 +198,10 @@ /// An analysis map that contains a map for the current operation, and a set of /// maps for any child operations. struct NestedAnalysisMap { - NestedAnalysisMap(Operation *op) : analyses(op) {} + NestedAnalysisMap(Operation *op, PassInstrumentor *instrumentor) + : analyses(op), parentOrInstrumentor(instrumentor) {} + NestedAnalysisMap(Operation *op, NestedAnalysisMap *parent) + : analyses(op), parentOrInstrumentor(parent) {} /// Get the operation for this analysis map. Operation *getOperation() const { return analyses.getOperation(); } @@ -206,11 +209,34 @@ /// Invalidate any non preserved analyses. void invalidate(const PreservedAnalyses &pa); + /// Returns the parent analysis map for this analysis map, or null if this is + /// the top-level map. + const NestedAnalysisMap *getParent() const { + return parentOrInstrumentor.dyn_cast(); + } + + /// Returns a pass instrumentation object for the current operation.
This + /// value may be null. + PassInstrumentor *getPassInstrumentor() const { + if (auto *parent = getParent()) + return parent->getPassInstrumentor(); + return parentOrInstrumentor.get(); + } + /// The cached analyses for nested operations. DenseMap> childAnalyses; - /// The analyses for the owning module. + /// The analyses for the owning operation. detail::AnalysisMap analyses; + + /// This value has three possible states: + /// NestedAnalysisMap*: A pointer to the parent analysis map. + /// PassInstrumentor*: This analysis map is the top-level map, and this + /// pointer is the optional pass instrumentor for the + /// current compilation. + /// nullptr: This analysis map is the top-level map, and there is no pass + /// instrumentor. + PointerUnion parentOrInstrumentor; }; } // namespace detail @@ -236,11 +262,11 @@ template Optional> getCachedParentAnalysis(Operation *parentOp) const { - ParentPointerT curParent = parent; - while (auto *parentAM = curParent.dyn_cast()) { - if (parentAM->impl->getOperation() == parentOp) - return parentAM->getCachedAnalysis(); - curParent = parentAM->parent; + const detail::NestedAnalysisMap *curParent = impl; + while (auto *parentAM = curParent->getParent()) { + if (parentAM->getOperation() == parentOp) + return parentAM->analyses.getCachedAnalysis(); + curParent = parentAM; } return None; } @@ -286,7 +312,8 @@ return it->second->analyses.getCachedAnalysis(); } - /// Get an analysis manager for the given child operation. + /// Get an analysis manager for the given operation, which must be a proper + /// descendant of the current operation represented by this analysis manager. AnalysisManager nest(Operation *op); /// Invalidate any non preserved analyses, @@ -300,19 +327,15 @@ /// Returns a pass instrumentation object for the current operation. This /// value may be null.
- PassInstrumentor *getPassInstrumentor() const; + PassInstrumentor *getPassInstrumentor() const { + return impl->getPassInstrumentor(); + } private: - AnalysisManager(const AnalysisManager *parent, - detail::NestedAnalysisMap *impl) - : parent(parent), impl(impl) {} - AnalysisManager(const ModuleAnalysisManager *parent, - detail::NestedAnalysisMap *impl) - : parent(parent), impl(impl) {} + AnalysisManager(detail::NestedAnalysisMap *impl) : impl(impl) {} - /// A reference to the parent analysis manager, or the top-level module - /// analysis manager. - ParentPointerT parent; + /// Get an analysis manager for the given immediately nested child operation. + AnalysisManager nestImmediate(Operation *op); /// A reference to the impl analysis map within the parent analysis manager. detail::NestedAnalysisMap *impl; @@ -328,23 +351,16 @@ class ModuleAnalysisManager { public: ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor) - : analyses(op), passInstrumentor(passInstrumentor) {} + : analyses(op, passInstrumentor) {} ModuleAnalysisManager(const ModuleAnalysisManager &) = delete; ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete; - /// Returns a pass instrumentation object for the current module. This value - /// may be null. - PassInstrumentor *getPassInstrumentor() const { return passInstrumentor; } - /// Returns an analysis manager for the current top-level module. - operator AnalysisManager() { return AnalysisManager(this, &analyses); } + operator AnalysisManager() { return AnalysisManager(&analyses); } private: /// The analyses for the owning module. detail::NestedAnalysisMap analyses; - - /// An optional instrumentation object.
- PassInstrumentor *passInstrumentor; }; } // end namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -95,7 +95,7 @@ typename OptionParser = detail::PassOptions::OptionParser> struct Option : public detail::PassOptions::Option { template - Option(Pass &parent, StringRef arg, Args &&... args) + Option(Pass &parent, StringRef arg, Args &&...args) : detail::PassOptions::Option( parent.passOptions, arg, std::forward(args)...) {} using detail::PassOptions::Option::operator=; @@ -107,14 +107,17 @@ struct ListOption : public detail::PassOptions::ListOption { template - ListOption(Pass &parent, StringRef arg, Args &&... args) + ListOption(Pass &parent, StringRef arg, Args &&...args) : detail::PassOptions::ListOption( parent.passOptions, arg, std::forward(args)...) {} using detail::PassOptions::ListOption::operator=; }; /// Attempt to initialize the options of this pass from the given string. - LogicalResult initializeOptions(StringRef options); + /// Derived classes may override this method to hook into the point at which + /// options are initialized, but should generally always invoke this base + /// class variant. + virtual LogicalResult initializeOptions(StringRef options); /// Prints out the pass in the textual representation of pipelines. If this is /// an adaptor pass, print with the op_name(sub_pass,...) format. @@ -265,7 +268,6 @@ void copyOptionValuesFrom(const Pass *other); private: - /// Out of line virtual method to ensure vtables and metadata are emitted to a /// single .o file.
virtual void anchor(); diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -48,8 +48,8 @@ class OpPassManager { public: enum class Nesting { Implicit, Explicit }; - OpPassManager(Identifier name, Nesting nesting); - OpPassManager(StringRef name, Nesting nesting); + OpPassManager(Identifier name, Nesting nesting = Nesting::Explicit); + OpPassManager(StringRef name, Nesting nesting = Nesting::Explicit); OpPassManager(OpPassManager &&rhs); OpPassManager(const OpPassManager &rhs); ~OpPassManager(); diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -107,6 +107,19 @@ /// Creates a pass which inlines calls and callable operations as defined by /// the CallGraph. std::unique_ptr createInlinerPass(); +/// Creates an instance of the inliner pass that uses the provided pass managers +/// when optimizing callable operations whose names match a key of the map. +/// Callable operations with a name not within the provided map will use the +/// default inliner pipeline during optimization. +std::unique_ptr +createInlinerPass(llvm::StringMap opPipelines); +/// Creates an instance of the inliner pass that uses the provided pass managers +/// when optimizing callable operations whose names match a key of the map. +/// Callable operations with a name not within the provided map will use the +/// provided default pipeline builder. +std::unique_ptr +createInlinerPass(llvm::StringMap opPipelines, + std::function defaultPipelineBuilder); /// Creates a pass which performs sparse conditional constant propagation over /// nested operations.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -285,9 +285,12 @@ let summary = "Inline function calls"; let constructor = "mlir::createInlinerPass()"; let options = [ - Option<"disableCanonicalization", "disable-simplify", "bool", - /*default=*/"false", - "Disable running simplifications during inlining">, + Option<"defaultPipelineStr", "default-pipeline", "std::string", + /*default=*/"", "The default optimizer pipeline used for callables">, + ListOption<"opPipelineStrs", "op-pipelines", "std::string", + "Callable operation specific optimizer pipelines (in the form " + "of `dialect.op(pipeline)`)", + "llvm::cl::MiscFlags::CommaSeparated">, Option<"maxInliningIterations", "max-iterations", "unsigned", /*default=*/"4", "Maximum number of iterations when inlining within an SCC">, diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -340,22 +340,25 @@ // Initialize the pass state with a callback for the pass to dynamically // execute a pipeline on the currently visited operation. - auto dynamic_pipeline_callback = - [op, &am, verifyPasses](OpPassManager &pipeline, - Operation *root) -> LogicalResult { + PassInstrumentor *pi = am.getPassInstrumentor(); + PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), + pass}; + auto dynamic_pipeline_callback = [&](OpPassManager &pipeline, + Operation *root) -> LogicalResult { if (!op->isAncestor(root)) return root->emitOpError() << "Trying to schedule a dynamic pipeline on an " "operation that isn't " "nested under the current operation the pass is processing"; + assert(pipeline.getOpName() == root->getName().getStringRef()); - AnalysisManager nestedAm = am.nest(root); + AnalysisManager nestedAm = root == op ?
am : am.nest(root); return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm, - verifyPasses); + verifyPasses, pi, &parentInfo); }; pass->passState.emplace(op, am, dynamic_pipeline_callback); + // Instrument before the pass has run. - PassInstrumentor *pi = am.getPassInstrumentor(); if (pi) pi->runBeforePass(pass, op); @@ -388,7 +391,10 @@ /// Run the given operation and analysis manager on a provided op pass manager. LogicalResult OpToOpPassAdaptor::runPipeline( iterator_range passes, Operation *op, - AnalysisManager am, bool verifyPasses) { + AnalysisManager am, bool verifyPasses, PassInstrumentor *instrumentor, + const PassInstrumentation::PipelineParentInfo *parentInfo) { + assert((!instrumentor || parentInfo) && + "expected parent info if instrumentor is provided"); auto scope_exit = llvm::make_scope_exit([&] { // Clear out any computed operation analyses. These analyses won't be used // any more in this pipeline, and this helps reduce the current working set @@ -398,10 +404,20 @@ }); // Run the pipeline over the provided operation. + if (instrumentor) + instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo); - for (Pass &pass : passes) - if (failed(run(&pass, op, am, verifyPasses))) - return failure(); - - return success(); + LogicalResult result = success(); + for (Pass &pass : passes) { + if (failed(run(&pass, op, am, verifyPasses))) { + result = failure(); + break; + } + } + + // Notify the instrumentor even if a pass failed, so that every + // runBeforePipeline callback is paired with a runAfterPipeline callback. + if (instrumentor) + instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo); + return result; } @@ -491,17 +507,10 @@ *op.getContext()); if (!mgr) continue; - Identifier opName = mgr->getOpName(*getOperation()->getContext()); // Run the held pipeline over the current operation.
- if (instrumentor) - instrumentor->runBeforePipeline(opName, parentInfo); - LogicalResult result = - runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses); - if (instrumentor) - instrumentor->runAfterPipeline(opName, parentInfo); - - if (failed(result)) + if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op), + verifyPasses, instrumentor, &parentInfo))) return signalPassFailure(); } } @@ -576,13 +585,9 @@ pms, it.first->getName().getIdentifier(), getContext()); assert(pm && "expected valid pass manager for operation"); - Identifier opName = pm->getOpName(*getOperation()->getContext()); - if (instrumentor) - instrumentor->runBeforePipeline(opName, parentInfo); - auto pipelineResult = - runPipeline(pm->getPasses(), it.first, it.second, verifyPasses); - if (instrumentor) - instrumentor->runAfterPipeline(opName, parentInfo); + LogicalResult pipelineResult = + runPipeline(pm->getPasses(), it.first, it.second, verifyPasses, + instrumentor, &parentInfo); // Drop this thread from being tracked by the diagnostic handler. // After this task has finished, the thread may be used outside of @@ -848,22 +853,41 @@ // AnalysisManager //===----------------------------------------------------------------------===// -/// Returns a pass instrumentation object for the current operation. -PassInstrumentor *AnalysisManager::getPassInstrumentor() const { - ParentPointerT curParent = parent; - while (auto *parentAM = curParent.dyn_cast()) - curParent = parentAM->parent; - return curParent.get()->getPassInstrumentor(); +/// Get an analysis manager for the given operation, which must be a proper +/// descendant of the current operation represented by this analysis manager. +AnalysisManager AnalysisManager::nest(Operation *op) { + Operation *currentOp = impl->getOperation(); + assert(currentOp->isProperAncestor(op) && + "expected valid descendant operation"); + + // Check for the base case where the provided operation is immediately nested.
+ if (currentOp == op->getParentOp()) + return nestImmediate(op); + + // Otherwise, we need to collect all ancestors up to the current operation. + SmallVector opAncestors; + do { + opAncestors.push_back(op); + op = op->getParentOp(); + } while (op != currentOp); + + AnalysisManager result = *this; + for (Operation *op : llvm::reverse(opAncestors)) + result = result.nestImmediate(op); + return result; } -/// Get an analysis manager for the given child operation. -AnalysisManager AnalysisManager::nest(Operation *op) { +/// Get an analysis manager for the given immediately nested child operation. +AnalysisManager AnalysisManager::nestImmediate(Operation *op) { + assert(impl->getOperation() == op->getParentOp() && + "expected immediate child operation"); + auto it = impl->childAnalyses.find(op); if (it == impl->childAnalyses.end()) it = impl->childAnalyses - .try_emplace(op, std::make_unique(op)) + .try_emplace(op, std::make_unique(op, impl)) .first; - return {this, it->second.get()}; + return {it->second.get()}; } /// Invalidate any non preserved analyses. diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -60,9 +60,11 @@ /// Run the given operation and analysis manager on a provided op pass /// manager. - static LogicalResult - runPipeline(iterator_range passes, - Operation *op, AnalysisManager am, bool verifyPasses); + static LogicalResult runPipeline( + iterator_range passes, Operation *op, + AnalysisManager am, bool verifyPasses, + PassInstrumentor *instrumentor = nullptr, + const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr); /// A set of adaptors to run. SmallVector mgrs; diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -291,11 +291,15 @@ /// given to enable accurate error reporting.
LogicalResult TextualPipeline::initialize(StringRef text, raw_ostream &errorStream) { + if (text.empty()) + return success(); + // Build a source manager to use for error reporting. llvm::SourceMgr pipelineMgr; - pipelineMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer( - text, "MLIR Textual PassPipeline Parser"), - llvm::SMLoc()); + pipelineMgr.AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser", + /*RequiresNullTerminator=*/false), + llvm::SMLoc()); auto errorHandler = [&](const char *rawLoc, Twine msg) { pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc), llvm::SourceMgr::DK_Error, msg); @@ -327,7 +331,7 @@ pipeline.emplace_back(/*name=*/text.substr(0, pos).trim()); // If we have a single terminating name, we're done. - if (pos == text.npos) + if (pos == StringRef::npos) break; text = text.substr(pos); @@ -338,9 +342,19 @@ text = text.substr(1); // Skip over everything until the closing '}' and store as options. - size_t close = text.find('}'); + size_t close = StringRef::npos; + for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) { + if (text[i] == '{') { + ++braceCount; + continue; + } + if (text[i] == '}' && --braceCount == 0) { + close = i; + break; + } + } - // TODO: Handle skipping over quoted sub-strings. + // Check to see if a closing options brace was found. if (close == StringRef::npos) { return errorHandler( /*rawLoc=*/text.data() - 1, diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -302,16 +302,13 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) { Timer *timer = popLastActiveTimer(); - // If this is a pass adaptor, then we need to merge in the timing data for the - // pipelines running on other threads.
- if (isa(pass)) { - auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass}); - if (toMerge != pipelinesToMerge.end()) { - for (auto &it : toMerge->second) - timer->mergeChild(std::move(it)); - pipelinesToMerge.erase(toMerge); - } - return; + // Check to see if we need to merge in the timing data for the pipelines + // running on other threads. + auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass}); + if (toMerge != pipelinesToMerge.end()) { + for (auto &it : toMerge->second) + timer->mergeChild(std::move(it)); + pipelinesToMerge.erase(toMerge); } timer->stop(); diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -15,9 +15,8 @@ #include "PassDetail.h" #include "mlir/Analysis/CallGraph.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SCCIterator.h" @@ -28,6 +27,11 @@ using namespace mlir; +/// This function implements the default inliner optimization pipeline. +static void defaultInlinerOptPipeline(OpPassManager &pm) { + pm.addPass(createCanonicalizerPass()); +} + //===----------------------------------------------------------------------===// // Symbol Use Tracking //===----------------------------------------------------------------------===// @@ -279,9 +283,9 @@ /// Run a given transformation over the SCCs of the callgraph in a bottom up /// traversal. -static void -runTransformOnCGSCCs(const CallGraph &cg, - function_ref sccTransformer) { +static LogicalResult runTransformOnCGSCCs( + const CallGraph &cg, + function_ref sccTransformer) { llvm::scc_iterator cgi = llvm::scc_begin(&cg); CallGraphSCC currentSCC(cgi); while (!cgi.isAtEnd()) { @@ -289,8 +293,10 @@ // SCC without invalidating our iterator.
currentSCC.reset(*cgi); ++cgi; - sccTransformer(currentSCC); + if (failed(sccTransformer(currentSCC))) + return failure(); } + return success(); } namespace { @@ -499,85 +505,94 @@ return success(inlinedAnyCalls); } -/// Canonicalize the nodes within the given SCC with the given set of -/// canonicalization patterns. -static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, - CallGraphSCC &currentSCC, MLIRContext *context, - const FrozenRewritePatternList &canonPatterns) { - // Collect the sets of nodes to canonicalize. - SmallVector nodesToCanonicalize; - for (auto *node : currentSCC) { - // Don't canonicalize the external node, it has no valid callable region. - if (node->isExternal()) - continue; - - // Don't canonicalize nodes with children. Nodes with children - // require special handling as we may remove the node during - // canonicalization. In the future, we should be able to handle this - // case with proper node deletion tracking. - if (node->hasChildren()) - continue; - - // We also won't apply canonicalizations for nodes that are not - // isolated. This avoids potentially mutating the regions of nodes defined - // above, this is also a stipulation of the 'applyPatternsAndFoldGreedily' - // driver. - auto *region = node->getCallableRegion(); - if (!region->getParentOp()->isKnownIsolatedFromAbove()) - continue; - nodesToCanonicalize.push_back(node); - } - if (nodesToCanonicalize.empty()) - return; - - // Canonicalize each of the nodes within the SCC in parallel. - // NOTE: This is simple now, because we don't enable canonicalizing nodes - // within children. When we remove this restriction, this logic will need to - // be reworked. - if (context->isMultithreadingEnabled()) { - ParallelDiagnosticHandler canonicalizationHandler(context); - llvm::parallelForEachN( - /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) { - // Set the order for this thread so that diagnostics will be properly - // ordered.
- canonicalizationHandler.setOrderIDForThread(index); - - // Apply the canonicalization patterns to this region. - auto *node = nodesToCanonicalize[index]; - applyPatternsAndFoldGreedily(*node->getCallableRegion(), - canonPatterns); - - // Make sure to reset the order ID for the diagnostic handler, as this - // thread may be used in a different context. - canonicalizationHandler.eraseOrderIDForThread(); - }); - } else { - for (CallGraphNode *node : nodesToCanonicalize) - applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns); - } - - // Recompute the uses held by each of the nodes. - for (CallGraphNode *node : nodesToCanonicalize) - useList.recomputeUses(node, cg); -} - //===----------------------------------------------------------------------===// // InlinerPass //===----------------------------------------------------------------------===// namespace { -struct InlinerPass : public InlinerBase { +class InlinerPass : public InlinerBase { +public: + InlinerPass(); + InlinerPass(const InlinerPass &) = default; + InlinerPass(std::function defaultPipeline); + InlinerPass(std::function defaultPipeline, + llvm::StringMap opPipelines); void runOnOperation() override; - /// Attempt to inline calls within the given scc, and run canonicalizations - /// with the given patterns, until a fixed point is reached. This allows for - /// the inlining of newly devirtualized calls. - void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC, - MLIRContext *context, - const FrozenRewritePatternList &canonPatterns); +private: + /// Attempt to inline calls within the given scc, and run simplifications, + /// until a fixed point is reached. This allows for the inlining of newly + /// devirtualized calls. Returns failure if there was a fatal error during + /// inlining.
+ LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList, + CallGraphSCC &currentSCC, MLIRContext *context); + + /// Optimize the nodes within the given SCC with one of the held optimization + /// pass pipelines. Returns failure if an error occurred during the + /// optimization of the SCC, success otherwise. + LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList, + CallGraphSCC &currentSCC, MLIRContext *context); + + /// Optimize the nodes within the given SCC in parallel. Returns failure if an + /// error occurred during the optimization of the SCC, success otherwise. + LogicalResult optimizeSCCAsync(MutableArrayRef nodesToVisit, + MLIRContext *context); + + /// Optimize the given callable node with one of the pass managers provided + /// with `pipelines`, or the default pipeline. Returns failure if an error + /// occurred during the optimization of the callable, success otherwise. + LogicalResult optimizeCallable(CallGraphNode *node, + llvm::StringMap &pipelines); + + /// Initialize the options of this pass from the given string, then rebuild + /// the default pipeline builder and the op specific pipelines from the + /// parsed `default-pipeline` and `op-pipelines` values. Returns failure if + /// the options or any op specific pipeline fail to parse. + LogicalResult initializeOptions(StringRef options) override; + + /// An optional function that constructs a default optimization pipeline for + /// a given operation. + std::function defaultPipeline; + /// A map of operation names to pass pipelines to use when optimizing + /// callable operations of these types. This provides a specialized pipeline + /// instead of the default. The vector size is the number of threads used + /// during optimization.
+ SmallVector, 8> opPipelines; }; } // end anonymous namespace +InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {} +InlinerPass::InlinerPass(std::function defaultPipeline) + : defaultPipeline(defaultPipeline) { + opPipelines.push_back({}); + + // Initialize the pass options with the provided arguments. + if (defaultPipeline) { + OpPassManager fakePM("__mlir_fake_pm_op"); + defaultPipeline(fakePM); + llvm::raw_string_ostream strStream(defaultPipelineStr); + fakePM.printAsTextualPipeline(strStream); + } +} + +InlinerPass::InlinerPass(std::function defaultPipeline, + llvm::StringMap opPipelines) + : InlinerPass(std::move(defaultPipeline)) { + if (opPipelines.empty()) + return; + // Update the option for the op specific optimization pipelines. + for (auto &it : opPipelines) { + std::string pipeline; + llvm::raw_string_ostream pipelineOS(pipeline); + pipelineOS << it.getKey() << "("; + it.second.printAsTextualPipeline(pipelineOS); + pipelineOS << ")"; + opPipelineStrs.addValue(pipelineOS.str()); + } + // Replace, not append: the optimizer reads the seeded map at index 0. + this->opPipelines[0] = std::move(opPipelines); +} + void InlinerPass::runOnOperation() { CallGraph &cg = getAnalysis(); auto *context = &getContext(); @@ -591,42 +606,190 @@ return signalPassFailure(); } - // Collect a set of canonicalization patterns to use when simplifying - // callable regions within an SCC. - OwningRewritePatternList canonPatterns; - for (auto *op : context->getRegisteredOperations()) - op->getCanonicalizationPatterns(canonPatterns, context); - FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns)); - // Run the inline transform in post-order over the SCCs in the callgraph.
SymbolTableCollection symbolTable; Inliner inliner(context, cg, symbolTable); CGUseList useList(getOperation(), cg, symbolTable); - runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) { - inlineSCC(inliner, useList, scc, context, frozenCanonPatterns); + LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) { + return inlineSCC(inliner, useList, scc, context); }); + if (failed(result)) + return signalPassFailure(); // After inlining, make sure to erase any callables proven to be dead. inliner.eraseDeadCallables(); } -void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList, - CallGraphSCC &currentSCC, MLIRContext *context, - const FrozenRewritePatternList &canonPatterns) { - // If we successfully inlined any calls, run some simplifications on the - // nodes of the scc. Continue attempting to inline until we reach a fixed - // point, or a maximum iteration count. We canonicalize here as it may - // devirtualize new calls, as well as give us a better cost model. +LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList, + CallGraphSCC &currentSCC, + MLIRContext *context) { + // Continuously simplify and inline until we either reach a fixed point, or + // hit the maximum iteration count. Simplifying early helps to refine the cost + // model, and in future iterations may devirtualize new calls. unsigned iterationCount = 0; - while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) { - // If we aren't allowing simplifications or the max iteration count was - // reached, then bail out early.
- if (disableCanonicalization || ++iterationCount >= maxInliningIterations) + do { + if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context))) + return failure(); + if (failed(inlineCallsInSCC(inliner, useList, currentSCC))) break; - canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns); + } while (++iterationCount < maxInliningIterations); + return success(); +} + +LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList, + CallGraphSCC &currentSCC, + MLIRContext *context) { + // Collect the sets of nodes to simplify. + SmallVector nodesToVisit; + for (auto *node : currentSCC) { + if (node->isExternal()) + continue; + + // Don't simplify nodes with children. Nodes with children require special + // handling as we may remove the node during simplification. In the future, + // we should be able to handle this case with proper node deletion tracking. + if (node->hasChildren()) + continue; + + // We also won't apply simplifications to nodes that can't have passes + // scheduled on them. + auto *region = node->getCallableRegion(); + if (!region->getParentOp()->isKnownIsolatedFromAbove()) + continue; + nodesToVisit.push_back(node); + } + if (nodesToVisit.empty()) + return success(); + + // Optimize each of the nodes within the SCC in parallel. + // NOTE: This is simple now, because we don't enable optimizing nodes within + // children. When we remove this restriction, this logic will need to be + // reworked. + if (context->isMultithreadingEnabled()) { + if (failed(optimizeSCCAsync(nodesToVisit, context))) + return failure(); + + // Otherwise, we are optimizing within a single thread. + } else { + for (CallGraphNode *node : nodesToVisit) { + if (failed(optimizeCallable(node, opPipelines[0]))) + return failure(); + } + } + + // Recompute the uses held by each of the nodes.
+ for (CallGraphNode *node : nodesToVisit) + useList.recomputeUses(node, cg); + return success(); +} + +LogicalResult +InlinerPass::optimizeSCCAsync(MutableArrayRef nodesToVisit, + MLIRContext *context) { + // Ensure that there are enough pipeline maps for the optimizer to run in + // parallel. + size_t numThreads = llvm::hardware_concurrency().compute_thread_count(); + if (opPipelines.size() != numThreads) { + // Reserve before resizing so that we can use a reference to the first + // element. + opPipelines.reserve(numThreads); + opPipelines.resize(numThreads, opPipelines.front()); + } + + // Ensure an analysis manager has been constructed for each of the nodes. + // This prevents thread races when running the nested pipelines. + for (CallGraphNode *node : nodesToVisit) + getAnalysisManager().nest(node->getCallableRegion()->getParentOp()); + + // An index for the current node to optimize. + std::atomic nodeIt(0); + + // Optimize the nodes of the SCC in parallel. + ParallelDiagnosticHandler optimizerHandler(context); + return llvm::parallelTransformReduce( + llvm::seq(0, numThreads), success(), + [](LogicalResult lhs, LogicalResult rhs) { + return success(succeeded(lhs) && succeeded(rhs)); + }, + [&](size_t index) { + LogicalResult result = success(); + for (auto e = nodesToVisit.size(); nodeIt < e && succeeded(result);) { + // Get the next available operation index. + unsigned nextID = nodeIt++; + if (nextID >= e) + break; + + // Set the order for this thread so that diagnostics will be + // properly ordered, and reset after optimization has finished.
+ optimizerHandler.setOrderIDForThread(nextID); + result = optimizeCallable(nodesToVisit[nextID], opPipelines[index]); + optimizerHandler.eraseOrderIDForThread(); + } + return result; + }); +} + +LogicalResult +InlinerPass::optimizeCallable(CallGraphNode *node, + llvm::StringMap &pipelines) { + Operation *callable = node->getCallableRegion()->getParentOp(); + StringRef opName = callable->getName().getStringRef(); + auto pipelineIt = pipelines.find(opName); + if (pipelineIt == pipelines.end()) { + // If a pipeline didn't exist, use the default if possible. + if (!defaultPipeline) + return success(); + + OpPassManager defaultPM(opName); + defaultPipeline(defaultPM); + pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first; } + return runPipeline(pipelineIt->second, callable); +} + +LogicalResult InlinerPass::initializeOptions(StringRef options) { + if (failed(Pass::initializeOptions(options))) + return failure(); + + // Initialize the default pipeline builder to use the option string. + if (!defaultPipelineStr.empty()) { + std::string defaultPipelineCopy = defaultPipelineStr; + defaultPipeline = [=](OpPassManager &pm) { + parsePassPipeline(defaultPipelineCopy, pm); + }; + } else if (defaultPipelineStr.getNumOccurrences()) { + defaultPipeline = nullptr; + } + + // Initialize the op specific pass pipelines. + llvm::StringMap pipelines; + for (StringRef pipeline : opPipelineStrs) { + // Pipelines are expected to be of the form `dialect.op(pipeline)`.
+ size_t pipelineStart = pipeline.find_first_of('('); + if (pipelineStart == StringRef::npos || !pipeline.consume_back(")")) + return failure(); + StringRef opName = pipeline.take_front(pipelineStart); + OpPassManager pm(opName); + if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm))) + return failure(); + pipelines.try_emplace(opName, std::move(pm)); + } + opPipelines.assign({std::move(pipelines)}); + + return success(); } std::unique_ptr mlir::createInlinerPass() { return std::make_unique(); } +std::unique_ptr +mlir::createInlinerPass(llvm::StringMap opPipelines) { + return std::make_unique(defaultInlinerOptPipeline, + std::move(opPipelines)); +} +std::unique_ptr +mlir::createInlinerPass(llvm::StringMap opPipelines, + std::function defaultPipelineBuilder) { + return std::make_unique(std::move(defaultPipelineBuilder), + std::move(opPipelines)); +} diff --git a/mlir/test/Dialect/Affine/inlining.mlir b/mlir/test/Dialect/Affine/inlining.mlir --- a/mlir/test/Dialect/Affine/inlining.mlir +++ b/mlir/test/Dialect/Affine/inlining.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -inline="disable-simplify" | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -inline="default-pipeline=''" | FileCheck %s // Basic test that functions within affine operations are inlined.
func @func_with_affine_ops(%N: index) { diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline{disable-simplify})' | FileCheck %s +// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline{default-pipeline=''})' | FileCheck %s spv.module Logical GLSL450 { spv.func @callee() "None" { diff --git a/mlir/test/Pass/dynamic-pipeline-nested.mlir b/mlir/test/Pass/dynamic-pipeline-nested.mlir --- a/mlir/test/Pass/dynamic-pipeline-nested.mlir +++ b/mlir/test/Pass/dynamic-pipeline-nested.mlir @@ -20,9 +20,9 @@ // CHECK: Dump Before CSE // NOTNESTED-NEXT: @inner_mod1 // NESTED-NEXT: @foo - func private @foo() + module @foo {} // Only in the nested case we have a second run of the pass here. // NESTED: Dump Before CSE // NESTED-NEXT: @baz - func private @baz() + module @baz {} } diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -inline="disable-simplify" | FileCheck %s -// RUN: mlir-opt %s -inline="disable-simplify" -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC +// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s +// RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC // RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY // Inline a function that takes an argument. 
diff --git a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp --- a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp +++ b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp @@ -35,15 +35,17 @@ TestDynamicPipelinePass(const TestDynamicPipelinePass &) {} void runOnOperation() override { + Operation *currentOp = getOperation(); /* Cached once: reused for logging, the SymbolOpInterface check, building pm, the parent lookup, and the nested walk below. */ + llvm::errs() << "Dynamic execute '" << pipeline << "' on " - << getOperation()->getName() << "\n"; + << currentOp->getName() << "\n"; if (pipeline.empty()) { llvm::errs() << "Empty pipeline\n"; return; } - auto symbolOp = dyn_cast(getOperation()); + auto symbolOp = dyn_cast(currentOp); if (!symbolOp) { - getOperation()->emitWarning() + currentOp->emitWarning() << "Ignoring because not implementing SymbolOpInterface\n"; return; } @@ -54,24 +56,24 @@ return; } if (!pm) { - pm = std::make_unique( - getOperation()->getName().getIdentifier(), - OpPassManager::Nesting::Implicit); + pm = std::make_unique(currentOp->getName().getIdentifier(), /* pm is anchored on currentOp's operation name. */ + OpPassManager::Nesting::Implicit); parsePassPipeline(pipeline, *pm, llvm::errs()); } // Check that running on the parent operation always immediately fails. if (runOnParent) { - if (getOperation()->getParentOp()) - if (!failed(runPipeline(*pm, getOperation()->getParentOp()))) + if (currentOp->getParentOp()) + if (!failed(runPipeline(*pm, currentOp->getParentOp()))) signalPassFailure(); return; } if (runOnNestedOp) { llvm::errs() << "Run on nested op\n"; - getOperation()->walk([&](Operation *op) { - if (op == getOperation() || !op->isKnownIsolatedFromAbove()) + currentOp->walk([&](Operation *op) { + if (op == currentOp || !op->isKnownIsolatedFromAbove() || + op->getName() != currentOp->getName()) /* Skip ops whose name differs from currentOp's, since pm was built for currentOp's operation name. */ return; llvm::errs() << "Run on " << *op << "\n"; // Run on the current operation @@ -80,7 +82,7 @@ }); } else { // Run on the current operation - if (failed(runPipeline(*pm, getOperation()))) + if (failed(runPipeline(*pm, currentOp))) signalPassFailure(); } }