diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -11,8 +11,8 @@
This manual explains in detail all of the available mechanisms for defining
rewrite rules in such a declarative manner. It aims to be a specification
instead of a tutorial. Please refer to
-[Quickstart tutorial to adding MLIR graph
-rewrite](Tutorials/QuickstartRewrites.md) for the latter.
+[Quickstart tutorial to adding MLIR graph rewrite](Tutorials/QuickstartRewrites.md)
+for the latter.
Given that declarative rewrite rules depend on op definition specification, this
manual assumes knowledge of the [ODS](OpDefinitions.md) doc.
@@ -51,8 +51,8 @@
* Matching multi-result ops in nested patterns.
* Matching and generating variadic operand/result ops in nested patterns.
* Packing and unpacking variadic operands/results during generation.
-* [`NativeCodeCall`](#nativecodecall-transforming-the-generated-op)
- returning more than one results.
+* [`NativeCodeCall`](#nativecodecall-transforming-the-generated-op) returning
+ more than one result.
## Rule Definition
@@ -93,9 +93,9 @@
[directives](#rewrite-directives). `argN` is for matching (if used in source
pattern) or generating (if used in result pattern) the `N`-th argument for
`operator`. If the `operator` is some MLIR operation, it means the `N`-th
-argument as specified in the `arguments` list of the op's definition.
-Therefore, we say op argument specification in pattern is **position-based**:
-the position where they appear matters.
+argument as specified in the `arguments` list of the op's definition. Therefore,
+we say op argument specification in pattern is **position-based**: the position
+where they appear matters.
`argN` can be a `dag` object itself, thus we can have nested `dag` tree to model
the def-use relationship between ops.
@@ -245,15 +245,15 @@
Otherwise, a custom `build()` method that matches the argument list is required.
Right now all ODS-generated `build()` methods require specifying the result
-type(s), unless the op has known traits like `SameOperandsAndResultType` that
-we can use to auto-generate a `build()` method with result type deduction.
-When generating an op to replace the result of the matched root op, we can use
-the matched root op's result type when calling the ODS-generated builder.
-Otherwise (e.g., generating an [auxiliary op](#supporting-auxiliary-ops) or
-generating an op with a nested result pattern), DRR will not be able to deduce
-the result type(s). The pattern author will need to define a custom builder
-that has result type deduction ability via `OpBuilder` in ODS. For example,
-in the following pattern
+type(s), unless the op has known traits like `SameOperandsAndResultType` that we
+can use to auto-generate a `build()` method with result type deduction. When
+generating an op to replace the result of the matched root op, we can use the
+matched root op's result type when calling the ODS-generated builder. Otherwise
+(e.g., generating an [auxiliary op](#supporting-auxiliary-ops) or generating an
+op with a nested result pattern), DRR will not be able to deduce the result
+type(s). The pattern author will need to define a custom builder that has result
+type deduction ability via `OpBuilder` in ODS. For example, in the following
+pattern
```tablegen
def : Pat<(AOp $input, $attr), (COp (AOp $input, $attr) $attr)>;
@@ -295,8 +295,8 @@
In the result pattern, we can bind to the result(s) of a newly built op by
attaching symbols to the op. (But we **cannot** bind to op arguments given that
-they are referencing previously bound symbols.) This is useful for reusing
-newly created results where suitable. For example,
+they are referencing previously bound symbols.) This is useful for reusing newly
+created results where suitable. For example,
```tablegen
def DOp : Op<"d_op"> {
@@ -373,18 +373,18 @@
definition of the C++ helper function.
In the above example, we are using a string to specialize the `NativeCodeCall`
-template. The string can be an arbitrary C++ expression that evaluates into
-some C++ object expected at the `NativeCodeCall` site (here it would be
-expecting an array attribute). Typically the string should be a function call.
+template. The string can be an arbitrary C++ expression that evaluates into some
+C++ object expected at the `NativeCodeCall` site (here it would be expecting an
+array attribute). Typically the string should be a function call.
Note that currently `NativeCodeCall` must return no more than one value or
attribute. This might change in the future.
##### `NativeCodeCall` placeholders
-In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N`. The former
-is called _special placeholder_, while the latter is called _positional
-placeholder_.
+In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`.
+The first is called a _special placeholder_, the second a _positional
+placeholder_, and the third a _positional range placeholder_.
`NativeCodeCall` right now only supports three special placeholders:
`$_builder`, `$_loc`, and `$_self`:
@@ -423,6 +423,11 @@
NativeCodeCall<"someFn($1, $2, $0)">` and use it like `(SomeCall $in0, $in1,
$in2)`, then this will be translated into C++ call `someFn($in1, $in2, $in0)`.
+Positional range placeholders will be substituted by multiple `dag` object
+parameters at the `NativeCodeCall` use site: `$N...` expands to the `N`-th
+parameter and all parameters after it, separated by commas. For example, if we
+define `SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall
+$in0, $in1, $in2)`, then this will be translated into C++ call `someFn($in1,
+$in2)`.
+
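The expansion rule described above can be modelled outside of MLIR with a few lines of standalone C++. This is only a sketch under stated assumptions: `expandPlaceholders` is a hypothetical helper written for illustration, it is not the `tgfmt` implementation, it ignores special placeholders such as `$_builder`, and it simply leaves an out-of-range placeholder in the output as written.

```cpp
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of MLIR): expands `$N` and `$N...` in `fmt`
// using `params`. `$N` becomes params[N]; `$N...` becomes params[N] and every
// parameter after it, joined by ", ".
std::string expandPlaceholders(const std::string &fmt,
                               const std::vector<std::string> &params) {
  std::string out;
  std::size_t i = 0;
  while (i < fmt.size()) {
    // Copy anything that is not the start of a positional placeholder.
    if (fmt[i] != '$' || i + 1 >= fmt.size() ||
        !std::isdigit(static_cast<unsigned char>(fmt[i + 1]))) {
      out += fmt[i++];
      continue;
    }
    // Parse the decimal index that follows '$'.
    std::size_t start = i++;
    std::size_t index = 0;
    while (i < fmt.size() && std::isdigit(static_cast<unsigned char>(fmt[i])))
      index = index * 10 + static_cast<std::size_t>(fmt[i++] - '0');
    // A trailing "..." turns the placeholder into an open-ended range.
    bool isRange = fmt.compare(i, 3, "...") == 0;
    if (isRange)
      i += 3;
    if (index >= params.size()) {
      // Simplification: keep the placeholder text when nothing matches it.
      out.append(fmt, start, i - start);
      continue;
    }
    // Emit one parameter, or all parameters from `index` onwards.
    std::size_t last = isRange ? params.size() : index + 1;
    for (std::size_t p = index; p < last; ++p) {
      if (p != index)
        out += ", ";
      out += params[p];
    }
  }
  return out;
}
```

With the parameters `$in0`, `$in1`, `$in2`, calling this helper on `someFn($1...)` gives `someFn($in1, $in2)`, and calling it on `someFn($1, $2, $0)` gives `someFn($in1, $in2, $in0)`, which matches the two translations described in the text.
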
##### Customizing entire op building
`NativeCodeCall` is not only limited to transforming arguments for building an
@@ -490,8 +495,8 @@
Multi-result ops bring extra complexity to declarative rewrite rules. We use
TableGen `dag` objects to represent ops in patterns; there is no native way to
-indicate that an op generates multiple results. The approach adopted is based
-on **naming convention**: a `__N` suffix is added to a symbol to indicate the
+indicate that an op generates multiple results. The approach adopted is based on
+**naming convention**: a `__N` suffix is added to a symbol to indicate the
`N`-th result.
#### `__N` suffix
@@ -541,12 +546,12 @@
To replace an `N`-result op, the result patterns must generate at least `N`
declared values (see [Declared vs. actual value](#declared-vs-actual-value) for
-definition). If there are more than `N` declared values generated, only the
-last `N` declared values will be used to replace the matched op. Note that
-because of the existence of multi-result op, one result pattern **may** generate
-multiple declared values. So it means we do not necessarily need `N` result
-patterns to replace an `N`-result op. For example, to replace an op with three
-results, you can have
+definition). If there are more than `N` declared values generated, only the last
+`N` declared values will be used to replace the matched op. Note that because of
+the existence of multi-result op, one result pattern **may** generate multiple
+declared values. So it means we do not necessarily need `N` result patterns to
+replace an `N`-result op. For example, to replace an op with three results, you
+can have
```tablegen
// ThreeResultOp/TwoResultOp/OneResultOp generates three/two/one result(s),
@@ -590,8 +595,8 @@
* _Actual operand/result/value_: an operand/result/value of an op instance at
runtime
-The above terms are needed because ops can have multiple results, and some of the
-results can also be variadic. For example,
+The above terms are needed because ops can have multiple results, and some of
+the results can also be variadic. For example,
```tablegen
def MultiVariadicOp : Op<"multi_variadic_op"> {
@@ -611,8 +616,8 @@
We say the above op has 3 declared operands and 3 declared results. But at
runtime, an instance can have 3 values corresponding to `$input2` and 2 values
-correspond to `$output2`; we say it has 5 actual operands and 4 actual
-results. A variadic operand/result is a considered as a declared value that can
+corresponding to `$output2`; we say it has 5 actual operands and 4 actual
+results. A variadic operand/result is considered a declared value that can
correspond to multiple actual values.
[TODO]
@@ -651,10 +656,10 @@
### Adjusting benefits
-The benefit of a `Pattern` is an integer value indicating the benefit of matching
-the pattern. It determines the priorities of patterns inside the pattern rewrite
-driver. A pattern with a higher benefit is applied before one with a lower
-benefit.
+The benefit of a `Pattern` is an integer value indicating the benefit of
+matching the pattern. It determines the priorities of patterns inside the
+pattern rewrite driver. A pattern with a higher benefit is applied before one
+with a lower benefit.
In DRR, a rule is set to have a benefit of the number of ops in the source
pattern. This is based on the heuristics and assumptions that:
@@ -662,7 +667,6 @@
* Larger matches are more beneficial than smaller ones.
* If a smaller one is applied first the larger one may not apply anymore.
-
The fourth parameter to `Pattern` (and `Pat`) allows to manually tweak a
pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value.
@@ -696,8 +700,8 @@
(LocDst1Op (LocDst2Op ..., (location $src2)), (location "outer"))>;
```
-In the above pattern, the generated `LocDst2Op` will use the matched location
-of `LocSrc2Op` while the root `LocDst1Op` node will used the named location
+In the above pattern, the generated `LocDst2Op` will use the matched location of
+`LocSrc2Op` while the root `LocDst1Op` node will use the named location
`outer`.
### `replaceWithValue`
@@ -724,8 +728,8 @@
### Run `mlir-tblgen` to see the generated content
-TableGen syntax sometimes can be obscure; reading the generated content can be
-a very helpful way to understand and debug issues. To build `mlir-tblgen`, run
+TableGen syntax sometimes can be obscure; reading the generated content can be a
+very helpful way to understand and debug issues. To build `mlir-tblgen`, run
`cmake --build . --target mlir-tblgen` in your build directory and find the
`mlir-tblgen` binary in the `bin/` subdirectory. All the supported generators
can be found via `mlir-tblgen --help`.
diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -88,22 +88,33 @@
/// Struct representing a replacement segment for the formatted string. It can
/// be a segment of the formatting template (for `Literal`) or a replacement
-/// parameter (for `PositionalPH` and `SpecialPH`).
+/// parameter (for `PositionalPH`, `PositionalRangePH` and `SpecialPH`).
struct FmtReplacement {
- enum class Type { Empty, Literal, PositionalPH, SpecialPH };
+ enum class Type {
+ Empty,
+ Literal,
+ PositionalPH,
+ PositionalRangePH,
+ SpecialPH
+ };
FmtReplacement() = default;
explicit FmtReplacement(StringRef literal)
: type(Type::Literal), spec(literal) {}
FmtReplacement(StringRef spec, size_t index)
: type(Type::PositionalPH), spec(spec), index(index) {}
+ FmtReplacement(StringRef spec, size_t index, size_t end)
+ : type(Type::PositionalRangePH), spec(spec), index(index), end(end) {}
FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
: type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
Type type = Type::Empty;
StringRef spec;
size_t index = 0;
+ size_t end = kUnset;
FmtContext::PHKind placeholder = FmtContext::PHKind::None;
+
+ static constexpr size_t kUnset = -1;
};
class FmtObjectBase {
@@ -121,7 +132,7 @@
// std::vector.
struct CreateAdapters {
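template <typename... Ts>
- std::vector<llvm::detail::format_adapter *> operator()(Ts &... items) {
+ std::vector<llvm::detail::format_adapter *> operator()(Ts &...items) {
return std::vector<llvm::detail::format_adapter *>{&items...};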
}
};
@@ -205,7 +216,8 @@
///
/// There are two categories of placeholders accepted, both led by a '$' sign:
///
-/// 1. Positional placeholder: $[0-9]+
+/// 1.a Positional placeholder: $[0-9]+
+/// 1.b Positional range placeholder: $[0-9]+...
/// 2. Special placeholder: $[a-zA-Z_][a-zA-Z0-9_]*
///
/// Replacement parameters for positional placeholders are supplied as the
@@ -214,6 +226,9 @@
/// can use the positional placeholders in any order and repeat any times, for
/// example, "$2 $1 $1 $0" is accepted.
///
+/// Replacement parameters for positional range placeholders are supplied as if
+/// positional placeholders were specified with commas separating them.
+///
/// Replacement parameters for special placeholders are supplied using the `ctx`
/// format context.
///
@@ -237,7 +252,7 @@
/// 2. This utility does not support format layout because it is rarely needed
/// in C++ code generation.
template <typename... Ts>
-inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
+inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals)
    -> FmtObject<decltype(std::make_tuple(
        llvm::detail::build_format_adapter(std::forward<Ts>(vals))...))> {
using ParamTuple = decltype(std::make_tuple(
diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -97,7 +97,7 @@
// First try to see if it's a positional placeholder, and then handle special
// placeholders.
- size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1);
+ size_t end = fmt.find_if_not(std::isdigit, /*From=*/1);
if (end != 1) {
// We have a positional placeholder. Parse the index.
size_t index = 0;
@@ -105,6 +105,14 @@
llvm_unreachable("invalid replacement sequence index");
}
+ // Check if this is part of a range specification.
+ if (fmt.substr(end, 3) == "...") {
+ // Currently only ranges without upper bound are supported.
+ return {
+ FmtReplacement{fmt.substr(0, end + 3), index, FmtReplacement::kUnset},
+ fmt.substr(end + 3)};
+ }
+
if (end == StringRef::npos) {
// All the remaining characters are part of the positional placeholder.
return {FmtReplacement{fmt, index}, StringRef()};
@@ -164,6 +172,20 @@
continue;
}
+ if (repl.type == FmtReplacement::Type::PositionalRangePH) {
+ if (repl.index >= adapters.size()) {
+ s << repl.spec << kMarkerForNoSubst;
+ continue;
+ }
+ auto range = llvm::makeArrayRef(adapters);
+ range = range.drop_front(repl.index);
+ if (repl.end != FmtReplacement::kUnset)
+ range = range.drop_back(adapters.size() - repl.end);
+ llvm::interleaveComma(range, s,
+ [&](auto &x) { x->format(s, /*Options=*/""); });
+ continue;
+ }
+
assert(repl.type == FmtReplacement::Type::PositionalPH);
if (repl.index >= adapters.size()) {
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -85,3 +85,8 @@
// CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
(NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
+
+// CHECK: struct test5 : public ::mlir::RewritePattern {
+// CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
+def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+ (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1497,9 +1497,27 @@
cc_library(
name = "SparseTensor",
- srcs = glob(["lib/Dialect/SparseTensor/IR/*.cpp"]),
+ srcs = glob([
+ "lib/Dialect/SparseTensor/IR/*.cpp",
+ ]),
hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
includes = ["include"],
+ deps = [
+ ":IR",
+ ":SideEffectInterfaces",
+ ":SparseTensorAttrDefsIncGen",
+ ":SparseTensorOpsIncGen",
+ ":SparseTensorUtils",
+ ":StandardOps",
+ "//llvm:Support",
+ ],
+)
+
+cc_library(
+ name = "SparseTensorUtils",
+ srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
+ hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
+ includes = ["include"],
deps = [
":IR",
":SideEffectInterfaces",
@@ -1535,17 +1553,6 @@
],
)
-cc_library(
- name = "SparseTensorUtils",
- srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
- hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
- includes = ["include"],
- deps = [
- ":IR",
- "//llvm:Support",
- ],
-)
-
td_library(
name = "StdOpsTdFiles",
srcs = [