diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -88,22 +88,34 @@
 
 /// Struct representing a replacement segment for the formatted string. It can
 /// be a segment of the formatting template (for `Literal`) or a replacement
-/// parameter (for `PositionalPH` and `SpecialPH`).
+/// parameter (for `PositionalPH`, `PositionalRangePH` and `SpecialPH`).
 struct FmtReplacement {
-  enum class Type { Empty, Literal, PositionalPH, SpecialPH };
+  enum class Type {
+    Empty,
+    Literal,
+    PositionalPH,
+    PositionalRangePH,
+    SpecialPH
+  };
 
   FmtReplacement() = default;
   explicit FmtReplacement(StringRef literal)
       : type(Type::Literal), spec(literal) {}
   FmtReplacement(StringRef spec, size_t index)
       : type(Type::PositionalPH), spec(spec), index(index) {}
+  FmtReplacement(StringRef spec, size_t index, size_t end)
+      : type(Type::PositionalRangePH), spec(spec), index(index), end(end) {}
   FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
       : type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
 
   Type type = Type::Empty;
   StringRef spec;
   size_t index = 0;
+  size_t end = kUnset;
   FmtContext::PHKind placeholder = FmtContext::PHKind::None;
+
+  /// Sentinel for `end`: the range extends through the last parameter.
+  static constexpr size_t kUnset = -1;
 };
 
 class FmtObjectBase {
@@ -121,7 +133,7 @@
   // std::vector<Base*>.
   struct CreateAdapters {
     template <typename... Ts>
-    std::vector<llvm::detail::format_adapter *> operator()(Ts &... items) {
+    std::vector<llvm::detail::format_adapter *> operator()(Ts &...items) {
       return std::vector<llvm::detail::format_adapter *>{&items...};
     }
   };
@@ -205,7 +217,8 @@
 ///
 /// There are two categories of placeholders accepted, both led by a '$' sign:
 ///
-/// 1. Positional placeholder: $[0-9]+
+/// 1.a Positional placeholder: $[0-9]+
+/// 1.b Positional range placeholder: $[0-9]+...
 /// 2. Special placeholder: $[a-zA-Z_][a-zA-Z0-9_]*
 ///
 /// Replacement parameters for positional placeholders are supplied as the
@@ -214,6 +227,11 @@
 /// can use the positional placeholders in any order and repeat any times, for
 /// example, "$2 $1 $1 $0" is accepted.
 ///
+/// Replacement parameters for positional range placeholders are supplied as
+/// if the positional placeholders from the given index through the last
+/// parameter had been written out separated by commas. For example, with
+/// three parameters, "$1..." expands like "$1, $2".
+///
 /// Replacement parameters for special placeholders are supplied using the `ctx`
 /// format context.
 ///
@@ -237,7 +255,7 @@
 /// 2. This utility does not support format layout because it is rarely needed
 /// in C++ code generation.
 template <typename... Ts>
-inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
+inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals)
     -> FmtObject<decltype(std::make_tuple(
         llvm::detail::build_format_adapter(std::forward<Ts>(vals))...))> {
   using ParamTuple = decltype(std::make_tuple(
diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -97,7 +97,8 @@
 
   // First try to see if it's a positional placeholder, and then handle special
   // placeholders.
-  size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1);
+  size_t end =
+      fmt.find_if_not([](char c) { return std::isdigit(c); }, /*From=*/1);
   if (end != 1) {
     // We have a positional placeholder. Parse the index.
     size_t index = 0;
@@ -105,6 +106,13 @@
       llvm_unreachable("invalid replacement sequence index");
     }
 
+    // Only open-ended ranges (`$N...`) are parsed; the upper bound is kUnset.
+    if (fmt.substr(end, 3) == "...") {
+      return {
+          FmtReplacement{fmt.substr(0, end + 3), index, FmtReplacement::kUnset},
+          fmt.substr(end + 3)};
+    }
+
     if (end == StringRef::npos) {
       // All the remaining characters are part of the positional placeholder.
       return {FmtReplacement{fmt, index}, StringRef()};
@@ -164,6 +172,23 @@
       continue;
     }
 
+    if (repl.type == FmtReplacement::Type::PositionalRangePH) {
+      if (repl.index >= adapters.size()) {
+        s << repl.spec << kMarkerForNoSubst;
+        continue;
+      }
+      // Open-ended ranges run through the last parameter; clamp explicit ends.
+      size_t end = repl.end == FmtReplacement::kUnset
+                       ? adapters.size()
+                       : std::min(repl.end, adapters.size());
+      for (size_t i = repl.index; i < end; ++i) {
+        if (i != repl.index)
+          s << ", ";
+        adapters[i]->format(s, /*Options=*/"");
+      }
+      continue;
+    }
+
     assert(repl.type == FmtReplacement::Type::PositionalPH);
 
     if (repl.index >= adapters.size()) {
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -85,3 +85,8 @@
 // CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
 def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
                 (NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
+
+// CHECK: struct test5 : public ::mlir::RewritePattern {
+// CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
+def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+                (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1497,9 +1497,27 @@
 
 cc_library(
     name = "SparseTensor",
-    srcs = glob(["lib/Dialect/SparseTensor/IR/*.cpp"]),
+    srcs = glob([
+        "lib/Dialect/SparseTensor/IR/*.cpp",
+    ]),
     hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
     includes = ["include"],
+    deps = [
+        ":IR",
+        ":SideEffectInterfaces",
+        ":SparseTensorAttrDefsIncGen",
+        ":SparseTensorOpsIncGen",
+        ":SparseTensorUtils",
+        ":StandardOps",
+        "//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "SparseTensorUtils",
+    srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
+    includes = ["include"],
     deps = [
         ":IR",
         ":SideEffectInterfaces",
@@ -1535,17 +1553,6 @@
     ],
 )
 
-cc_library(
-    name = "SparseTensorUtils",
-    srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
-    hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
-    includes = ["include"],
-    deps = [
-        ":IR",
-        "//llvm:Support",
-    ],
-)
-
 td_library(
     name = "StdOpsTdFiles",
     srcs = [