diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -178,6 +178,9 @@
 
   /// Allow access to 'clone' and 'run'.
   friend class OpPassManager;
+
+  /// Allow access to 'passOptions'.
+  friend class PassInfo;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -183,6 +183,10 @@
 
   PassOptions() = default;
 
+  /// This constructor is purposely empty to avoid copying the internal options
+  /// map.
+  PassOptions(const PassOptions &) {}
+
   /// Copy the option values from 'other' into 'this', where 'other' has the
   /// same options as 'this'.
   void copyOptionValuesFrom(const PassOptions &other);
@@ -196,6 +200,13 @@
   /// 'parseFromString'.
   void print(raw_ostream &os);
 
+  /// Print the help string for the options held by this struct. `descIndent` is
+  /// the indent that the descriptions should be aligned.
+  void printHelp(size_t indent, size_t descIndent) const;
+
+  /// Return the maximum width required when printing the help string.
+  size_t getOptionWidth() const;
+
 private:
   /// A list of all of the opaque options.
   std::vector<OptionBase *> options;
diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -21,6 +21,10 @@
 class OpPassManager;
 class Pass;
 
+namespace detail {
+class PassOptions;
+} // end namespace detail
+
 /// A registry function that adds passes to the given pass manager. This should
 /// also parse options and return success() if parsing succeeded.
 using PassRegistryFunction =
@@ -55,28 +59,45 @@
   /// Returns a description for the pass, this never returns null.
   StringRef getPassDescription() const { return description; }
 
+  /// Print the help information for this pass. This includes the argument,
+  /// description, and any pass options. `descIndent` is the indent that the
+  /// descriptions should be aligned.
+  void printHelpStr(size_t indent, size_t descIndent) const;
+
+  /// Return the maximum width required when printing the options of this entry.
+  size_t getOptionWidth() const;
+
 protected:
-  PassRegistryEntry(StringRef arg, StringRef description,
-                    const PassRegistryFunction &builder)
-      : arg(arg), description(description), builder(builder) {}
+  PassRegistryEntry(
+      StringRef arg, StringRef description, const PassRegistryFunction &builder,
+      std::function<void(function_ref<void(const detail::PassOptions &)>)>
+          optHandler)
+      : arg(arg), description(description), builder(builder),
+        optHandler(optHandler) {}
 
 private:
-  // The argument with which to invoke the pass via mlir-opt.
+  /// The argument with which to invoke the pass via mlir-opt.
   StringRef arg;
 
-  // Description of the pass.
+  /// Description of the pass.
   StringRef description;
 
-  // Function to register this entry to a pass manager pipeline.
+  /// Function to register this entry to a pass manager pipeline.
   PassRegistryFunction builder;
+
+  /// Function to invoke a handler for a pass options instance.
+  std::function<void(function_ref<void(const detail::PassOptions &)>)>
+      optHandler;
 };
 
 /// A structure to represent the information of a registered pass pipeline.
 class PassPipelineInfo : public PassRegistryEntry {
 public:
-  PassPipelineInfo(StringRef arg, StringRef description,
-                   const PassRegistryFunction &builder)
-      : PassRegistryEntry(arg, description, builder) {}
+  PassPipelineInfo(
+      StringRef arg, StringRef description, const PassRegistryFunction &builder,
+      std::function<void(function_ref<void(const detail::PassOptions &)>)>
+          optHandler)
+      : PassRegistryEntry(arg, description, builder, optHandler) {}
 };
 
 /// A structure to represent the information for a derived pass class.
@@ -94,8 +115,10 @@
 
 /// Register a specific dialect pipeline registry function with the system,
 /// typically used through the PassPipelineRegistration template.
-void registerPassPipeline(StringRef arg, StringRef description,
-                          const PassRegistryFunction &function);
+void registerPassPipeline(
+    StringRef arg, StringRef description, const PassRegistryFunction &function,
+    std::function<void(function_ref<void(const detail::PassOptions &)>)>
+        optHandler);
 
 /// Register a specific dialect pass allocator function with the system,
 /// typically used through the PassRegistration template.
@@ -113,7 +136,6 @@
 ///   static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
 ///
 template <typename ConcretePass> struct PassRegistration {
-
   PassRegistration(StringRef arg, StringRef description,
                    const PassAllocatorFunction &constructor) {
     registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
@@ -142,14 +164,18 @@
   PassPipelineRegistration(
       StringRef arg, StringRef description,
       std::function<void(OpPassManager &, const Options &options)> builder) {
-    registerPassPipeline(arg, description,
-                         [builder](OpPassManager &pm, StringRef optionsStr) {
-                           Options options;
-                           if (failed(options.parseFromString(optionsStr)))
-                             return failure();
-                           builder(pm, options);
-                           return success();
-                         });
+    registerPassPipeline(
+        arg, description,
+        [builder](OpPassManager &pm, StringRef optionsStr) {
+          Options options;
+          if (failed(options.parseFromString(optionsStr)))
+            return failure();
+          builder(pm, options);
+          return success();
+        },
+        [](function_ref<void(const detail::PassOptions &)> optHandler) {
+          optHandler(Options());
+        });
   }
 };
 
@@ -158,13 +184,15 @@
 template <> struct PassPipelineRegistration<EmptyPipelineOptions> {
   PassPipelineRegistration(StringRef arg, StringRef description,
                            std::function<void(OpPassManager &)> builder) {
-    registerPassPipeline(arg, description,
-                         [builder](OpPassManager &pm, StringRef optionsStr) {
-                           if (!optionsStr.empty())
-                             return failure();
-                           builder(pm);
-                           return success();
-                         });
+    registerPassPipeline(
+        arg, description,
+        [builder](OpPassManager &pm, StringRef optionsStr) {
+          if (!optionsStr.empty())
+            return failure();
+          builder(pm);
+          return success();
+        },
+        [](function_ref<void(const detail::PassOptions &)>) {});
   }
 };
 
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -35,13 +35,48 @@
   };
 }
 
+/// Print `--arg - desc`, aligning `desc` at `descIndent` when it fits.
+static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
+                            size_t descIndent, bool isTopLevel) {
+  size_t numSpaces = descIndent - std::min(descIndent, indent + arg.size() + 5);
+  llvm::outs().indent(indent) << "--" << arg;
+  llvm::outs().indent(numSpaces) << " - " << desc << '\n';
+}
+
+//===----------------------------------------------------------------------===//
+// PassRegistry
+//===----------------------------------------------------------------------===//
+
+/// Print the help information for this pass. This includes the argument,
+/// description, and any pass options. `descIndent` is the indent that the
+/// descriptions should be aligned.
+void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
+  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
+                  /*isTopLevel=*/true);
+  // If this entry has options, print the help for those as well.
+  optHandler([=](const PassOptions &options) {
+    options.printHelp(indent, descIndent);
+  });
+}
+
+/// Return the maximum width required when printing the options of this
+/// entry.
+size_t PassRegistryEntry::getOptionWidth() const {
+  size_t maxLen = 0;
+  optHandler([&](const PassOptions &options) {
+    maxLen = options.getOptionWidth() + 2;
+  });
+  return maxLen;
+}
+
 //===----------------------------------------------------------------------===//
 // PassPipelineInfo
 //===----------------------------------------------------------------------===//
 
-void mlir::registerPassPipeline(StringRef arg, StringRef description,
-                                const PassRegistryFunction &function) {
-  PassPipelineInfo pipelineInfo(arg, description, function);
+void mlir::registerPassPipeline(
+    StringRef arg, StringRef description, const PassRegistryFunction &function,
+    std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
+  PassPipelineInfo pipelineInfo(arg, description, function, optHandler);
   bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
   assert(inserted && "Pass pipeline registered multiple times");
   (void)inserted;
@@ -53,7 +88,12 @@
 
 PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
                    const PassAllocatorFunction &allocator)
-    : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {}
+    : PassRegistryEntry(
+          arg, description, buildDefaultRegistryFn(allocator),
+          // Use a temporary pass to provide an options instance.
+          [=](function_ref<void(const PassOptions &)> optHandler) {
+            optHandler(allocator()->passOptions);
+          }) {}
 
 void mlir::registerPass(StringRef arg, StringRef description,
                         const PassID *passID,
@@ -151,6 +191,34 @@
   os << '}';
 }
 
+/// Print the help string for the options held by this struct. `descIndent` is
+/// the indent within the stream that the descriptions should be aligned.
+void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
+  // Sort the options to make the ordering deterministic.
+  SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end());
+  llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(),
+                       [](OptionBase *const *lhs, OptionBase *const *rhs) {
+                         return (*lhs)->getArgStr().compare(
+                             (*rhs)->getArgStr());
+                       });
+  for (OptionBase *option : orderedOptions) {
+    // TODO(riverriddle) printOptionInfo assumes a specific indent and will
+    // print options with values with incorrect indentation. We should add
+    // support to llvm::cl::Option for passing in a base indent to use when
+    // printing.
+    llvm::outs().indent(indent);
+    option->getOption()->printOptionInfo(descIndent - indent);
+  }
+}
+
+/// Return the maximum width required when printing the help string.
+size_t detail::PassOptions::getOptionWidth() const {
+  size_t max = 0;
+  for (auto *option : options)
+    max = std::max(max, option->getOption()->getOptionWidth());
+  return max;
+}
+
 //===----------------------------------------------------------------------===//
 // TextualPassPipeline Parser
 //===----------------------------------------------------------------------===//
@@ -443,6 +511,7 @@
   void initialize();
   void printOptionInfo(const llvm::cl::Option &opt,
                        size_t globalWidth) const override;
+  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
              PassArgData &value);
 };
@@ -467,15 +536,54 @@
   }
 }
 
-void PassNameParser::printOptionInfo(const llvm::cl::Option &O,
-                                     size_t GlobalWidth) const {
-  PassNameParser *TP = const_cast<PassNameParser *>(this);
-  llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
-                       [](const PassNameParser::OptionInfo *VT1,
-                          const PassNameParser::OptionInfo *VT2) {
-                         return VT1->Name.compare(VT2->Name);
-                       });
-  llvm::cl::parser<PassArgData>::printOptionInfo(O, GlobalWidth);
+void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
+                                     size_t globalWidth) const {
+  // Print the information for the top-level option.
+  if (opt.hasArgStr()) {
+    llvm::outs() << "  --" << opt.ArgStr;
+    opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
+  } else {
+    llvm::outs() << "  " << opt.HelpStr << '\n';
+  }
+
+  // Print the top-level pipeline argument.
+  printOptionHelp(passPipelineArg,
+                  "A textual description of a pass pipeline to run",
+                  /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
+
+  // Functor used to print the ordered entries of a registration map.
+  auto printOrderedEntries = [&](StringRef header, auto &map) {
+    llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
+    for (auto &kv : map)
+      orderedEntries.push_back(&kv.second);
+    llvm::array_pod_sort(
+        orderedEntries.begin(), orderedEntries.end(),
+        [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
+          return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
+        });
+
+    llvm::outs().indent(4) << header << ":\n";
+    for (PassRegistryEntry *entry : orderedEntries)
+      entry->printHelpStr(/*indent=*/6, globalWidth);
+  };
+
+  // Print the available passes.
+  printOrderedEntries("Passes", *passRegistry);
+
+  // Print the available pass pipelines.
+  if (!passPipelineRegistry->empty())
+    printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
+}
+
+size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
+  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
+
+  // Check for any wider pass or pipeline options.
+  for (auto &entry : *passRegistry)
+    maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
+  for (auto &entry : *passPipelineRegistry)
+    maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
+  return maxWidth;
 }
 
 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,