diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -88,22 +88,33 @@
/// Struct representing a replacement segment for the formatted string. It can
/// be a segment of the formatting template (for `Literal`) or a replacement
-/// parameter (for `PositionalPH` and `SpecialPH`).
+/// parameter (for `PositionalPH`, `PositionalRangePH` and `SpecialPH`).
struct FmtReplacement {
- enum class Type { Empty, Literal, PositionalPH, SpecialPH };
+ enum class Type {
+ Empty,
+ Literal,
+ PositionalPH,
+ PositionalRangePH,
+ SpecialPH
+ };
FmtReplacement() = default;
explicit FmtReplacement(StringRef literal)
: type(Type::Literal), spec(literal) {}
FmtReplacement(StringRef spec, size_t index)
: type(Type::PositionalPH), spec(spec), index(index) {}
+ FmtReplacement(StringRef spec, size_t index, size_t end)
+ : type(Type::PositionalRangePH), spec(spec), index(index), end(end) {}
FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
: type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
Type type = Type::Empty;
StringRef spec;
size_t index = 0;
+ size_t end = kUnset; // Exclusive end of a PositionalRangePH parameter range.
FmtContext::PHKind placeholder = FmtContext::PHKind::None;
+ /// Sentinel for `end`: the range runs through the last supplied parameter.
+ static constexpr size_t kUnset = -1;
};
class FmtObjectBase {
@@ -121,7 +132,7 @@
// std::vector.
struct CreateAdapters {
template
- std::vector operator()(Ts &... items) {
+ std::vector operator()(Ts &...items) { // Collects the address of every argument.
return std::vector{&items...};
}
};
@@ -205,7 +216,8 @@
///
/// There are two categories of placeholders accepted, both led by a '$' sign:
///
-/// 1. Positional placeholder: $[0-9]+
+/// 1.a Positional placeholder: $[0-9]+
+/// 1.b Positional range placeholder: $[0-9]+... (with a literal trailing "...")
/// 2. Special placeholder: $[a-zA-Z_][a-zA-Z0-9_]*
///
/// Replacement parameters for positional placeholders are supplied as the
@@ -214,6 +226,9 @@
/// can use the positional placeholders in any order and repeat any times, for
/// example, "$2 $1 $1 $0" is accepted.
///
+/// Replacement parameters for a positional range placeholder `$N...` are all
+/// parameters from N onwards, comma-separated as if "$N, $N+1, ..." was used.
+///
/// Replacement parameters for special placeholders are supplied using the `ctx`
/// format context.
///
@@ -237,7 +252,7 @@
/// 2. This utility does not support format layout because it is rarely needed
/// in C++ code generation.
template
-inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
+inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals)
-> FmtObject(vals))...))> {
using ParamTuple = decltype(std::make_tuple(
diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -97,7 +97,8 @@
// First try to see if it's a positional placeholder, and then handle special
// placeholders.
- size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1);
+ size_t end =
+ fmt.find_if_not([](char c) { return std::isdigit(c); }, /*From=*/1);
if (end != 1) {
// We have a positional placeholder. Parse the index.
size_t index = 0;
@@ -105,6 +106,12 @@
llvm_unreachable("invalid replacement sequence index");
}
+    // `$N...` is a range placeholder covering parameter N through the last one.
+    if (fmt.substr(end, 3) == "...")
+      return {FmtReplacement{fmt.substr(0, end + 3), index,
+                             FmtReplacement::kUnset},
+              fmt.substr(end + 3)};
+
if (end == StringRef::npos) {
// All the remaining characters are part of the positional placeholder.
return {FmtReplacement{fmt, index}, StringRef()};
@@ -164,6 +171,21 @@
continue;
}
+    if (repl.type == FmtReplacement::Type::PositionalRangePH) {
+      if (repl.index >= adapters.size()) {
+        s << repl.spec << kMarkerForNoSubst;
+        continue;
+      }
+      // Clamp so kUnset means "through the last one" and `end` stays in range.
+      size_t last = repl.end < adapters.size() ? repl.end : adapters.size();
+      for (size_t i = repl.index; i < last; ++i) {
+        if (i != repl.index)
+          s << ", ";
+        adapters[i]->format(s, /*Options=*/"");
+      }
+      continue;
+    }
+
assert(repl.type == FmtReplacement::Type::PositionalPH);
if (repl.index >= adapters.size()) {
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -85,3 +85,8 @@
// CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
(NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
+
+// CHECK: struct test5 : public ::mlir::RewritePattern {
+// CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
+def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10), // `$3...` expands to $v4 through $v10.
+ (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1497,9 +1497,27 @@
cc_library(
name = "SparseTensor",
- srcs = glob(["lib/Dialect/SparseTensor/IR/*.cpp"]),
+ srcs = glob([
+ "lib/Dialect/SparseTensor/IR/*.cpp",
+ ]),
hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
includes = ["include"],
+ deps = [
+ ":IR",
+ ":SideEffectInterfaces",
+ ":SparseTensorAttrDefsIncGen",
+ ":SparseTensorOpsIncGen",
+ ":SparseTensorUtils", # Declared by the cc_library immediately below.
+ ":StandardOps",
+ "//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "SparseTensorUtils",
+ srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
+ hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
+ includes = ["include"],
deps = [
":IR",
":SideEffectInterfaces",
@@ -1535,17 +1553,6 @@
],
)
-cc_library(
- name = "SparseTensorUtils",
- srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
- hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
- includes = ["include"],
- deps = [
- ":IR",
- "//llvm:Support",
- ],
-)
-
td_library(
name = "StdOpsTdFiles",
srcs = [