diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -64,8 +64,10 @@
 
   /// Print the help information for this pass. This includes the argument,
   /// description, and any pass options. `descIndent` is the indent that the
-  /// descriptions should be aligned.
-  void printHelpStr(size_t indent, size_t descIndent) const;
+  /// descriptions should be aligned. If `printOptions` is set to true, the
+  /// help descriptions for the options of this entry are also printed.
+  void printHelpStr(size_t indent, size_t descIndent,
+                    bool printOptions = true) const;
 
   /// Return the maximum width required when printing the options of this entry.
   size_t getOptionWidth() const;
@@ -246,6 +248,26 @@
   std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
 };
 
+/// This class implements a command-line parser specifically for MLIR pass
+/// names. It registers a cl option with a given argument and description that
+/// accepts a comma-delimited list of pass names.
+class PassNameCLParser {
+public:
+  /// Construct a parser with the given command line description.
+  PassNameCLParser(StringRef arg, StringRef description);
+  ~PassNameCLParser();
+
+  /// Returns true if any pass names were provided to this option.
+  bool hasAnyOccurrences() const;
+
+  /// Returns true if the given pass registry entry was specified in the list
+  /// of pass names given to this parser.
+  bool contains(const PassRegistryEntry *entry) const;
+
+private:
+  std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
+};
+
 } // end namespace mlir
 
 #endif // MLIR_PASS_PASSREGISTRY_H_
diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -32,10 +32,10 @@
   //===--------------------------------------------------------------------===//
   // IR Printing
   //===--------------------------------------------------------------------===//
-  PassPipelineCLParser printBefore{"print-ir-before",
-                                   "Print IR before specified passes"};
-  PassPipelineCLParser printAfter{"print-ir-after",
-                                  "Print IR after specified passes"};
+  PassNameCLParser printBefore{"print-ir-before",
+                               "Print IR before specified passes"};
+  PassNameCLParser printAfter{"print-ir-after",
+                              "Print IR after specified passes"};
   llvm::cl::opt<bool> printBeforeAll{
       "print-ir-before-all", llvm::cl::desc("Print IR before each pass"),
       llvm::cl::init(false)};
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -56,14 +56,18 @@
 
 /// Print the help information for this pass. This includes the argument,
 /// description, and any pass options. `descIndent` is the indent that the
-/// descriptions should be aligned.
-void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
+/// descriptions should be aligned. If `printOptions` is set to true, the help
+/// descriptions for the options of this entry are also printed.
+void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent,
+                                     bool printOptions) const {
   printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
                   /*isTopLevel=*/true);
   // If this entry has options, print the help for those as well.
-  optHandler([=](const PassOptions &options) {
-    options.printHelp(indent, descIndent);
-  });
+  if (printOptions) {
+    optHandler([=](const PassOptions &options) {
+      options.printHelp(indent, descIndent);
+    });
+  }
 }
 
 /// Return the maximum width required when printing the options of this
@@ -529,6 +533,12 @@
   size_t getOptionWidth(const llvm::cl::Option &opt) const override;
   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
              PassArgData &value);
+
+  /// If true, this parser only parses entries that correspond to a concrete
+  /// pass registry entry. It does not add a `pass-pipeline` argument or pass
+  /// pipeline entries, and does not print pass options in its help. This must
+  /// be set before `initialize()` is called, i.e. via a cl::list modifier.
+  bool passNamesOnly = false;
 };
 } // namespace
 
@@ -536,8 +546,10 @@
   llvm::cl::parser<PassArgData>::initialize();
 
   /// Add an entry for the textual pass pipeline option.
-  addLiteralOption(passPipelineArg, PassArgData(),
-                   "A textual description of a pass pipeline to run");
+  if (!passNamesOnly) {
+    addLiteralOption(passPipelineArg, PassArgData(),
+                     "A textual description of a pass pipeline to run");
+  }
 
   /// Add the pass entries.
   for (const auto &kv : *passRegistry) {
@@ -545,9 +557,11 @@
                      kv.second.getPassDescription());
   }
   /// Add the pass pipeline entries.
-  for (const auto &kv : *passPipelineRegistry) {
-    addLiteralOption(kv.second.getPassArgument(), &kv.second,
-                     kv.second.getPassDescription());
+  if (!passNamesOnly) {
+    for (const auto &kv : *passPipelineRegistry) {
+      addLiteralOption(kv.second.getPassArgument(), &kv.second,
+                       kv.second.getPassDescription());
+    }
   }
 }
 
@@ -561,13 +575,9 @@
     llvm::outs() << "  " << opt.HelpStr << '\n';
   }
 
-  // Print the top-level pipeline argument.
-  printOptionHelp(passPipelineArg,
-                  "A textual description of a pass pipeline to run",
-                  /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
-
   // Functor used to print the ordered entries of a registration map.
-  auto printOrderedEntries = [&](StringRef header, auto &map) {
+  auto printOrderedEntries = [&](auto &map, size_t indent, bool includeOptions,
+                                 StringRef header = {}) {
     llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
     for (auto &kv : map)
       orderedEntries.push_back(&kv.second);
@@ -577,17 +587,30 @@
           return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
         });
 
-    llvm::outs().indent(4) << header << ":\n";
+    if (!header.empty())
+      llvm::outs().indent(indent) << header << ":\n";
     for (PassRegistryEntry *entry : orderedEntries)
-      entry->printHelpStr(/*indent=*/6, globalWidth);
+      entry->printHelpStr(/*indent=*/2 + indent, globalWidth, includeOptions);
   };
 
+  if (passNamesOnly) {
+    printOrderedEntries(*passRegistry, /*indent=*/2, /*includeOptions=*/false);
+    return;
+  }
+
+  // Print the top-level pipeline argument.
+  printOptionHelp(passPipelineArg,
+                  "A textual description of a pass pipeline to run",
+                  /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
+
   // Print the available passes.
-  printOrderedEntries("Passes", *passRegistry);
+  printOrderedEntries(*passRegistry, /*indent=*/4, /*includeOptions=*/true,
+                      "Passes");
 
   // Print the available pass pipelines.
   if (!passPipelineRegistry->empty())
-    printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
+    printOrderedEntries(*passPipelineRegistry, /*indent=*/4,
+                        /*includeOptions=*/true, "Pass Pipelines");
 }
 
 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
@@ -621,11 +644,32 @@
 namespace mlir {
 namespace detail {
 struct PassPipelineCLParserImpl {
-  PassPipelineCLParserImpl(StringRef arg, StringRef description)
-      : passList(arg, llvm::cl::desc(description)) {
+  /// A cl::list modifier that sets `PassNameParser::passNamesOnly`. The flag
+  /// must be set via a modifier, not after construction, because cl::list
+  /// invokes `initialize()` on its parser from within its own constructor and
+  /// `PassNameParser::initialize()` reads `passNamesOnly`.
+  struct PassNamesOnlyModifier {
+    template <typename ListT> void apply(ListT &list) const {
+      list.getParser().passNamesOnly = value;
+    }
+    bool value;
+  };
+
+  PassPipelineCLParserImpl(StringRef arg, StringRef description,
+                           bool passNamesOnly)
+      : passList(arg, llvm::cl::desc(description),
+                 PassNamesOnlyModifier{passNamesOnly}) {
     passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
   }
 
+  /// Returns true if the given pass registry entry was registered at the
+  /// top-level of the parser, i.e. not within an explicit textual pipeline.
+  bool contains(const PassRegistryEntry *entry) const {
+    return llvm::any_of(passList, [&](const PassArgData &data) {
+      return data.registryEntry == entry;
+    });
+  }
+
   /// The set of passes and pass pipelines to run.
   llvm::cl::list<PassArgData, bool, PassNameParser> passList;
 };
@@ -634,8 +678,8 @@
 
 /// Construct a pass pipeline parser with the given command line description.
 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
-    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(arg,
-                                                               description)) {}
+    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
+          arg, description, /*passNamesOnly=*/false)) {}
 PassPipelineCLParser::~PassPipelineCLParser() {}
 
 /// Returns true if this parser contains any valid options to add.
@@ -646,9 +690,7 @@
 /// Returns true if the given pass registry entry was registered at the
 /// top-level of the parser, i.e. not within an explicit textual pipeline.
 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
-  return llvm::any_of(impl->passList, [&](const PassArgData &data) {
-    return data.registryEntry == entry;
-  });
+  return impl->contains(entry);
 }
 
 /// Adds the passes defined by this parser entry to the given pass manager.
@@ -671,3 +713,25 @@
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// PassNameCLParser
+
+/// Construct a pass name parser with the given command line description.
+PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
+    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
+          arg, description, /*passNamesOnly=*/true)) {
+  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
+}
+PassNameCLParser::~PassNameCLParser() {}
+
+/// Returns true if any pass names were provided to this option.
+bool PassNameCLParser::hasAnyOccurrences() const {
+  return impl->passList.getNumOccurrences() != 0;
+}
+
+/// Returns true if the given pass registry entry was specified in the list
+/// of pass names given to this parser.
+bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
+  return impl->contains(entry);
+}