diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -73,6 +73,9 @@ return {begin(), end()}; } + /// Returns true if the pass manager has no passes. + bool empty() const { return begin() == end(); } + /// Nest a new operation pass manager for the given operation kind under this /// pass manager. OpPassManager &nest(StringAttr nestedName); @@ -110,7 +113,7 @@ /// of pipelines. /// Note: The quality of the string representation depends entirely on the /// the correctness of per-pass overrides of Pass::printAsTextualPipeline. - void printAsTextualPipeline(raw_ostream &os); + void printAsTextualPipeline(raw_ostream &os) const; /// Raw dump of the pass manager to llvm::errs(). void dump(); diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -23,6 +23,8 @@ #include namespace mlir { +class OpPassManager; + namespace detail { namespace pass_options { /// Parse a string containing a list of comma-delimited elements, invoking the @@ -158,7 +160,7 @@ public OptionBase { public: template - Option(PassOptions &parent, StringRef arg, Args &&... args) + Option(PassOptions &parent, StringRef arg, Args &&...args) : llvm::cl::opt( arg, llvm::cl::sub(parent), std::forward(args)...) { assert(!this->isPositional() && !this->isSink() && @@ -319,7 +321,8 @@ /// struct MyPipelineOptions : PassPipelineOptions { /// ListOption someListFlag{*this, "flag-name", llvm::cl::desc("...")}; /// }; -template class PassPipelineOptions : public detail::PassOptions { +template +class PassPipelineOptions : public detail::PassOptions { public: /// Factory that parses the provided options and returns a unique_ptr to the /// struct. @@ -335,7 +338,6 @@ /// any options. struct EmptyPipelineOptions : public PassPipelineOptions { }; - } // namespace mlir //===----------------------------------------------------------------------===// @@ -407,8 +409,92 @@ public: parser(Option &opt) : detail::VectorParserBase, T>(opt) {} }; -} // end namespace cl -} // end namespace llvm -#endif // MLIR_PASS_PASSOPTIONS_H_ +//===----------------------------------------------------------------------===// +// OpPassManager: OptionValue +template <> +struct OptionValue final : GenericOptionValue { + using WrapperType = mlir::OpPassManager; + + OptionValue(); + OptionValue(const mlir::OpPassManager &value); + OptionValue &operator=(const mlir::OpPassManager &rhs); + ~OptionValue(); + + /// Returns if the current option has a value. + bool hasValue() const { return value.get(); } + + /// Returns the current value of the option. + mlir::OpPassManager &getValue() const { + assert(hasValue() && "invalid option value"); + return *value; + } + + /// Set the value of the option. + void setValue(const mlir::OpPassManager &newValue); + void setValue(StringRef pipelineStr); + + /// Compare the option with the provided value. + bool compare(const mlir::OpPassManager &rhs) const; + bool compare(const GenericOptionValue &rhs) const override { + const auto &rhsOV = + static_cast &>(rhs); + if (!rhsOV.hasValue()) + return false; + return compare(rhsOV.getValue()); + } + +private: + void anchor() override; + + /// The underlying pass manager. We use a unique_ptr to avoid the need for the + /// full type definition. + std::unique_ptr value; +}; + +//===----------------------------------------------------------------------===// +// OpPassManager: Parser + +extern template class basic_parser; + +template <> +class parser : public basic_parser { +public: + /// A utility struct used when parsing a pass manager that prevents the need + /// for a default constructor on OpPassManager. + struct ParsedPassManager { + ParsedPassManager(); + ParsedPassManager(ParsedPassManager &&); + ~ParsedPassManager(); + operator const mlir::OpPassManager &() const { + assert(value && "parsed value was invalid"); + return *value; + } + + std::unique_ptr value; + }; + using parser_data_type = ParsedPassManager; + using OptVal = OptionValue; + + parser(Option &opt) : basic_parser(opt) {} + + bool parse(Option &, StringRef, StringRef arg, ParsedPassManager &value); + + /// Print an instance of the underling option value to the given stream. + static void print(raw_ostream &os, const mlir::OpPassManager &value); + + // Overload in subclass to provide a better default value. + StringRef getValueName() const override { return "pass-manager"; } + + void printOptionDiff(const Option &opt, mlir::OpPassManager &pm, + const OptVal &defaultValue, size_t globalWidth) const; + + // An out-of-line virtual method to provide a 'home' for this class. + void anchor() override; +}; + +} // namespace cl +} // namespace llvm + +#endif // MLIR_PASS_PASSOPTIONS_H_ diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -83,7 +83,7 @@ let options = [ Option<"defaultPipelineStr", "default-pipeline", "std::string", /*default=*/"", "The default optimizer pipeline used for callables">, - ListOption<"opPipelineStrs", "op-pipelines", "std::string", + ListOption<"opPipelineList", "op-pipelines", "OpPassManager", "Callable operation specific optimizer pipelines (in the form " "of `dialect.op(pipeline)`)">, Option<"maxInliningIterations", "max-iterations", "unsigned", diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -54,12 +54,14 @@ void Pass::printAsTextualPipeline(raw_ostream &os) { // Special case for adaptors to use the 'op_name(sub_passes)' format. if (auto *adaptor = dyn_cast(this)) { - llvm::interleaveComma(adaptor->getPassManagers(), os, - [&](OpPassManager &pm) { - os << pm.getOpName() << "("; - pm.printAsTextualPipeline(os); - os << ")"; - }); + llvm::interleave( + adaptor->getPassManagers(), + [&](OpPassManager &pm) { + os << pm.getOpName() << "("; + pm.printAsTextualPipeline(os); + os << ")"; + }, + [&] { os << ","; }); return; } // Otherwise, print the pass argument followed by its options. If the pass @@ -295,14 +297,17 @@ /// Prints out the given passes as the textual representation of a pipeline. static void printAsTextualPipeline(ArrayRef> passes, raw_ostream &os) { - llvm::interleaveComma(passes, os, [&](const std::unique_ptr &pass) { - pass->printAsTextualPipeline(os); - }); + llvm::interleave( + passes, + [&](const std::unique_ptr &pass) { + pass->printAsTextualPipeline(os); + }, + [&] { os << ","; }); } /// Prints out the passes of the pass manager as the textual representation /// of pipelines. -void OpPassManager::printAsTextualPipeline(raw_ostream &os) { +void OpPassManager::printAsTextualPipeline(raw_ostream &os) const { ::printAsTextualPipeline(impl->passes, os); } diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -332,6 +332,104 @@ return max; } +//===----------------------------------------------------------------------===// +// MLIR Options +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// OpPassManager: OptionValue + +llvm::cl::OptionValue::OptionValue() = default; +llvm::cl::OptionValue::OptionValue( + const mlir::OpPassManager &value) { + setValue(value); +} +llvm::cl::OptionValue & +llvm::cl::OptionValue::operator=( + const mlir::OpPassManager &rhs) { + setValue(rhs); + return *this; +} + +llvm::cl::OptionValue::~OptionValue() = default; + +void llvm::cl::OptionValue::setValue( + const OpPassManager &newValue) { + if (hasValue()) + *value = newValue; + else + value = std::make_unique(newValue); +} +void llvm::cl::OptionValue::setValue(StringRef pipelineStr) { + FailureOr pipeline = parsePassPipeline(pipelineStr); + assert(succeeded(pipeline) && "invalid pass pipeline"); + setValue(*pipeline); +} + +bool llvm::cl::OptionValue::compare( + const mlir::OpPassManager &rhs) const { + std::string lhsStr, rhsStr; + { + raw_string_ostream lhsStream(lhsStr); + value->printAsTextualPipeline(lhsStream); + + raw_string_ostream rhsStream(rhsStr); + rhs.printAsTextualPipeline(rhsStream); + } + + // Use the textual format for pipeline comparisons. + return lhsStr == rhsStr; +} + +void llvm::cl::OptionValue::anchor() {} + +//===----------------------------------------------------------------------===// +// OpPassManager: Parser + +namespace llvm { +namespace cl { +template class basic_parser; +} // namespace cl +} // namespace llvm + +bool llvm::cl::parser::parse(Option &, StringRef, StringRef arg, + ParsedPassManager &value) { + FailureOr pipeline = parsePassPipeline(arg); + if (failed(pipeline)) + return true; + value.value = std::make_unique(std::move(*pipeline)); + return false; +} + +void llvm::cl::parser::print(raw_ostream &os, + const OpPassManager &value) { + value.printAsTextualPipeline(os); +} + +void llvm::cl::parser::printOptionDiff( + const Option &opt, OpPassManager &pm, const OptVal &defaultValue, + size_t globalWidth) const { + printOptionName(opt, globalWidth); + outs() << "= "; + pm.printAsTextualPipeline(outs()); + + if (defaultValue.hasValue()) { + outs().indent(2) << " (default: "; + defaultValue.getValue().printAsTextualPipeline(outs()); + outs() << ")"; + } + outs() << "\n"; +} + +void llvm::cl::parser::anchor() {} + +llvm::cl::parser::ParsedPassManager::ParsedPassManager() = + default; +llvm::cl::parser::ParsedPassManager::ParsedPassManager( + ParsedPassManager &&) = default; +llvm::cl::parser::ParsedPassManager::~ParsedPassManager() = + default; + //===----------------------------------------------------------------------===// // TextualPassPipeline Parser //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -585,14 +585,8 @@ return; // Update the option for the op specific optimization pipelines. - for (auto &it : opPipelines) { - std::string pipeline; - llvm::raw_string_ostream pipelineOS(pipeline); - pipelineOS << it.getKey() << "("; - it.second.printAsTextualPipeline(pipelineOS); - pipelineOS << ")"; - opPipelineStrs.addValue(pipeline); - } + for (auto &it : opPipelines) + opPipelineList.addValue(it.second); this->opPipelines.emplace_back(std::move(opPipelines)); } @@ -751,15 +745,9 @@ // Initialize the op specific pass pipelines. llvm::StringMap pipelines; - for (StringRef pipeline : opPipelineStrs) { - // Skip empty pipelines. - if (pipeline.empty()) - continue; - FailureOr pm = parsePassPipeline(pipeline); - if (failed(pm)) - return failure(); - pipelines.try_emplace(pm->getOpName(), std::move(*pm)); - } + for (OpPassManager pipeline : opPipelineList) + if (!pipeline.empty()) + pipelines.try_emplace(pipeline.getOpName(), pipeline); opPipelines.assign({std::move(pipelines)}); return success(); diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h --- a/mlir/lib/Transforms/PassDetail.h +++ b/mlir/lib/Transforms/PassDetail.h @@ -10,6 +10,7 @@ #define TRANSFORMS_PASSDETAIL_H_ #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" namespace mlir { diff --git a/mlir/test/Pass/crash-recovery.mlir b/mlir/test/Pass/crash-recovery.mlir --- a/mlir/test/Pass/crash-recovery.mlir +++ b/mlir/test/Pass/crash-recovery.mlir @@ -20,7 +20,7 @@ module @foo {} } -// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass, test-pass-crash)' +// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass,test-pass-crash)' // REPRO: module @inner_mod1 // REPRO: module @foo { diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir --- a/mlir/test/Pass/pipeline-options-parsing.mlir +++ b/mlir/test/Pass/pipeline-options-parsing.mlir @@ -14,4 +14,4 @@ // CHECK_1: test-options-pass{list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d} // CHECK_2: test-options-pass{list=1 string= string-list=a,b} -// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }), func.func(test-options-pass{list=1,2,3,4 string= })) +// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }),func.func(test-options-pass{list=1,2,3,4 string= })) diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -2,6 +2,7 @@ // RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s // RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC // RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY +// RUN: mlir-opt %s -inline='op-pipelines=func.func(canonicalize,cse)' | FileCheck %s --check-prefix INLINE_SIMPLIFY // Inline a function that takes an argument. func @func_with_arg(%c : i32) -> i32 {