diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/Support/CommandLine.h"
 
 namespace mlir {
 class AffineExpr;
@@ -25,6 +26,74 @@
 namespace linalg {
 class LinalgDependenceGraph;
 
+//===----------------------------------------------------------------------===//
+// Custom command line option
+//===----------------------------------------------------------------------===//
+
+/// Custom option parser to parse list of list command line options.
+///
+/// When used by a `ListOption` with `llvm::cl::MiscFlags::CommaSeparated`,
+/// each comma separated value is itself a list of `T` elements separated by
+/// '_'. Each element is parsed with the default `llvm::cl::parser<T>`.
+///
+/// Example:
+/// ```
+/// ListOption<std::vector<int64_t>,
+///            InnerListOptionParser<int64_t>> myOpt{*this, "my-opt", ... }
+/// ```
+/// Translates --my-opt=1_2_3,3_4,0_1 to [[1,2,3],[3,4],[0,1]]
+template <typename T>
+class InnerListOptionParser {
+public:
+  explicit InnerListOptionParser(llvm::cl::Option &O) : elemParser(O) {}
+
+  using parser_data_type = std::vector<T>;
+
+  void initialize() {}
+  void getExtraOptionNames(SmallVectorImpl<StringRef> &OptionNames) {}
+  enum llvm::cl::ValueExpected getValueExpectedFlagDefault() const {
+    return llvm::cl::ValueRequired;
+  }
+
+  /// Split `Arg` into list elements and parse them using `elemParser`.
+  /// Returns true if an element cannot be parsed, in which case the error has
+  /// already been reported through `O`.
+  bool parse(llvm::cl::Option &O, StringRef ArgName, StringRef Arg,
+             parser_data_type &Val) {
+    SmallVector<StringRef> elements;
+    Arg.split(elements, '_');
+    for (StringRef elem : elements) {
+      // Use `T` for the element so the parser works for every type supported
+      // by `llvm::cl::parser<T>`, not only `int64_t`.
+      T curr{};
+      if (elemParser.parse(O, ArgName, elem, curr))
+        return O.error("Cannot parse list element '" + elem + "'!");
+      Val.push_back(curr);
+    }
+    return false;
+  }
+
+  /// Add the `vector<>` qualifier to the option info.
+  void printOptionInfo(const llvm::cl::Option &O, size_t GlobalWidth) const {
+    // Use the same value name as `elemParser.getOptionWidth`: the option's
+    // `llvm::cl::value_desc` if set, otherwise the element parser's default.
+    StringRef valueName =
+        O.ValueStr.empty() ? elemParser.getValueName() : O.ValueStr;
+    llvm::outs() << "  --" << O.ArgStr;
+    llvm::outs() << "=<vector<" << valueName << ">>";
+    llvm::cl::Option::printHelpStr(O.HelpStr, GlobalWidth, getOptionWidth(O));
+  }
+
+  /// Add the `vector<>` qualifier to the option width.
+  size_t getOptionWidth(const llvm::cl::Option &O) const {
+    StringRef vectorExt("vector<>");
+    return elemParser.getOptionWidth(O) + vectorExt.size();
+  }
+
+private:
+  llvm::cl::parser<T> elemParser;
+};
+
 //===----------------------------------------------------------------------===//
 // General utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -87,6 +87,12 @@
   static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
     os << value;
   }
+  /// Print vector values as their '_'-separated elements, e.g. `1_0_2`.
+  template <typename DataT, typename ParserT>
+  static void printValue(raw_ostream &os, ParserT &parser,
+                         const std::vector<DataT> &value) {
+    llvm::interleave(value, os, "_");
+  }
   template <typename ParserT>
   static void printValue(raw_ostream &os, ParserT &parser, const bool &value) {
     os << (value ? StringRef("true") : StringRef("false"));
@@ -112,7 +118,7 @@
         public OptionBase {
   public:
     template <typename... Args>
-    Option(PassOptions &parent, StringRef arg, Args &&... args)
+    Option(PassOptions &parent, StringRef arg, Args &&...args)
         : llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>(
               arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
       assert(!this->isPositional() && !this->isSink() &&
@@ -156,7 +162,7 @@
         public OptionBase {
   public:
     template <typename... Args>
-    ListOption(PassOptions &parent, StringRef arg, Args &&... args)
+    ListOption(PassOptions &parent, StringRef arg, Args &&...args)
         : llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>(
               arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
       assert(!this->isPositional() && !this->isSink() &&
@@ -279,4 +285,3 @@
 } // namespace mlir
 
 #endif // MLIR_PASS_PASSOPTIONS_H_
-
diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATVEC
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 transpose-paddings=1:0,0,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=TRANSP
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matvec pad hoist-paddings=1,1,0 transpose-paddings=1_0,0,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=TRANSP
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad hoist-paddings=1,2,1 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATMUL
 
 // MATVEC-DAG: #[[DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -109,17 +109,19 @@
       *this, "hoist-paddings",
       llvm::cl::desc("Operand hoisting depths when test-pad-pattern."),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  ListOption<std::string> transposePaddings{
-      *this, "transpose-paddings",
-      llvm::cl::desc(
-          "Transpose paddings when test-pad-pattern. Specify a "
-          "operand dimension interchange using the following format:\n"
-          "-transpose-paddings=1:0:2,0:1,0:1\n"
-          "It defines the interchange [1, 0, 2] for operand one and "
-          "the interchange [0, 1] (no transpose) for the remaining operands."
-          "All interchange vectors have to be permuations matching the "
-          "operand rank."),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<std::vector<int64_t>, InnerListOptionParser<int64_t>>
+      transposePaddings{
+          *this, "transpose-paddings",
+          llvm::cl::desc(
+              "Transpose paddings when test-pad-pattern. Specify an\n"
+              "\t operand dimension interchange using the following format:\n"
+              "\t\t--transpose-paddings=1_0_2,0_1,0_1\n"
+              "\t It defines the interchange [1, 0, 2] for operand one and\n"
+              "\t the interchange [0, 1] (no transpose) for the remaining "
+              "operands.\n"
+              "\t All interchange vectors have to be permutations matching the "
+              "operand rank."),
+          llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   Option<bool> generalize{*this, "generalize",
                           llvm::cl::desc("Generalize named operations."),
                           llvm::cl::init(false)};
@@ -272,11 +274,9 @@
       SmallVector<int64_t> transposeVector = {};
       if (opOperand.getOperandNumber() >= transposePaddings.size())
         return transposeVector;
-      SmallVector<StringRef> elems;
-      StringRef(transposePaddings[opOperand.getOperandNumber()])
-          .split(elems, ':');
-      for (StringRef elem : elems)
-        transposeVector.push_back(std::stoi(elem.str()));
+      const auto &transposePadding =
+          transposePaddings[opOperand.getOperandNumber()];
+      transposeVector.append(transposePadding.begin(), transposePadding.end());
       return transposeVector;
     };
     paddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);