diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -246,6 +246,26 @@
   std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
 };
 
+/// This class implements a command-line parser specifically for MLIR pass
+/// names. It registers a cl option with a given argument and description that
+/// accepts a comma delimited list of pass names.
+class PassNameCLParser {
+public:
+  /// Construct a parser with the given command line description.
+  PassNameCLParser(StringRef arg, StringRef description);
+  ~PassNameCLParser();
+
+  /// Returns true if this option was specified on the command line.
+  bool hasAnyOccurrences() const;
+
+  /// Returns true if the given pass registry entry was named in the list of
+  /// passes provided to this option.
+  bool contains(const PassRegistryEntry *entry) const;
+
+private:
+  std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
+};
+
 } // end namespace mlir
 
 #endif // MLIR_PASS_PASSREGISTRY_H_
diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -32,10 +32,10 @@
   //===--------------------------------------------------------------------===//
   // IR Printing
   //===--------------------------------------------------------------------===//
-  PassPipelineCLParser printBefore{"print-ir-before",
-                                   "Print IR before specified passes"};
-  PassPipelineCLParser printAfter{"print-ir-after",
-                                  "Print IR after specified passes"};
+  PassNameCLParser printBefore{"print-ir-before",
+                               "Print IR before specified passes"};
+  PassNameCLParser printAfter{"print-ir-after",
+                              "Print IR after specified passes"};
   llvm::cl::opt<bool> printBeforeAll{
       "print-ir-before-all", llvm::cl::desc("Print IR before each pass"),
       llvm::cl::init(false)};
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -543,6 +543,12 @@
   size_t getOptionWidth(const llvm::cl::Option &opt) const override;
   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
              PassArgData &value);
+
+  /// If true, this parser only parses entries that correspond to a concrete
+  /// pass registry entry, and does not add a `pass-pipeline` argument, does not
+  /// include the options for pass entries, and does not include pass pipeline
+  /// entries.
+  bool passNamesOnly = false;
 };
 } // namespace
 
@@ -550,8 +556,10 @@
   llvm::cl::parser<PassArgData>::initialize();
 
   /// Add an entry for the textual pass pipeline option.
-  addLiteralOption(passPipelineArg, PassArgData(),
-                   "A textual description of a pass pipeline to run");
+  if (!passNamesOnly) {
+    addLiteralOption(passPipelineArg, PassArgData(),
+                     "A textual description of a pass pipeline to run");
+  }
 
   /// Add the pass entries.
   for (const auto &kv : *passRegistry) {
@@ -559,14 +567,24 @@
                      kv.second.getPassDescription());
   }
   /// Add the pass pipeline entries.
-  for (const auto &kv : *passPipelineRegistry) {
-    addLiteralOption(kv.second.getPassArgument(), &kv.second,
-                     kv.second.getPassDescription());
+  if (!passNamesOnly) {
+    for (const auto &kv : *passPipelineRegistry) {
+      addLiteralOption(kv.second.getPassArgument(), &kv.second,
+                       kv.second.getPassDescription());
+    }
   }
 }
 
 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
                                      size_t globalWidth) const {
+  // If this parser is just parsing pass names, print a simplified option
+  // string.
+  if (passNamesOnly) {
+    llvm::outs() << "  --" << opt.ArgStr << "=<pass-arg>";
+    opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
+    return;
+  }
+
   // Print the information for the top-level option.
   if (opt.hasArgStr()) {
     llvm::outs() << "  --" << opt.ArgStr;
@@ -635,11 +653,21 @@
 namespace mlir {
 namespace detail {
 struct PassPipelineCLParserImpl {
-  PassPipelineCLParserImpl(StringRef arg, StringRef description)
+  PassPipelineCLParserImpl(StringRef arg, StringRef description,
+                           bool passNamesOnly)
       : passList(arg, llvm::cl::desc(description)) {
+    passList.getParser().passNamesOnly = passNamesOnly;
     passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
   }
 
+  /// Returns true if the given pass registry entry was registered at the
+  /// top-level of the parser, i.e. not within an explicit textual pipeline.
+  bool contains(const PassRegistryEntry *entry) const {
+    return llvm::any_of(passList, [&](const PassArgData &data) {
+      return data.registryEntry == entry;
+    });
+  }
+
   /// The set of passes and pass pipelines to run.
   llvm::cl::list<PassArgData, bool, PassNameParser> passList;
 };
@@ -648,8 +676,8 @@
 
 /// Construct a pass pipeline parser with the given command line description.
 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
-    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(arg,
-                                                               description)) {}
+    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
+          arg, description, /*passNamesOnly=*/false)) {}
 PassPipelineCLParser::~PassPipelineCLParser() {}
 
 /// Returns true if this parser contains any valid options to add.
@@ -660,9 +688,7 @@
 /// Returns true if the given pass registry entry was registered at the
 /// top-level of the parser, i.e. not within an explicit textual pipeline.
 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
-  return llvm::any_of(impl->passList, [&](const PassArgData &data) {
-    return data.registryEntry == entry;
-  });
+  return impl->contains(entry);
 }
 
 /// Adds the passes defined by this parser entry to the given pass manager.
@@ -685,3 +711,25 @@
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// PassNameCLParser
+
+/// Construct a pass name parser with the given command line description.
+PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
+    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
+          arg, description, /*passNamesOnly=*/true)) {
+  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
+}
+PassNameCLParser::~PassNameCLParser() {}
+
+/// Returns true if this option was specified on the command line.
+bool PassNameCLParser::hasAnyOccurrences() const {
+  return impl->passList.getNumOccurrences() != 0;
+}
+
+/// Returns true if the given pass registry entry was named in the list of
+/// passes provided to this option.
+bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
+  return impl->contains(entry);
+}