diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -59,12 +59,13 @@ ## Rule Definition The core construct for defining a rewrite rule is defined in -[`OpBase.td`][OpBase] as +[`PatternBase.td`][PatternBase] as ```tablegen class Pattern< dag sourcePattern, list<dag> resultPatterns, list<dag> additionalConstraints = [], + list<dag> supplementalPatterns = [], dag benefitsAdded = (addBenefit 0)>; ``` @@ -678,6 +679,29 @@ * Apply constraints on multiple bound symbols (`$input` and `TwoResultOp`'s first result must have the same element type). +### Supplying additional result patterns + +Sometimes we need to generate additional code after the result patterns, e.g. +copying the attributes of the source op to the result ops. Such patterns can be +specified via the `supplementalPatterns` parameter. Like auxiliary patterns, +they are not used to replace any results of the source pattern. + +For example, we can write + +```tablegen +def CopyAttrFoo: NativeCodeCallVoid< + "$1.getOwner()->setAttr($_builder.getStringAttr(\"foo\"), " + "$0.getOwner()->getAttr(\"foo\"))">; + +def : Pattern<(ThreeResultOp:$src ...), + [(TwoResultOp:$dest1 ...), (OneResultOp:$dest2 ...)], + [], + [(CopyAttrFoo $src, $dest1), (CopyAttrFoo $src, $dest2)]>; +``` + +This will copy the attribute `foo` of `ThreeResultOp` in the source pattern to +`TwoResultOp` and `OneResultOp` in the result patterns respectively. 
+ ### Adjusting benefits The benefit of a `Pattern` is an integer value indicating the benefit of diff --git a/mlir/include/mlir/IR/PatternBase.td b/mlir/include/mlir/IR/PatternBase.td --- a/mlir/include/mlir/IR/PatternBase.td +++ b/mlir/include/mlir/IR/PatternBase.td @@ -90,6 +90,7 @@ // * `FiveResultOp`#3: `TwoResultOp2`#1 // * `FiveResultOp`#4: `TwoResultOp2`#1 class Pattern<dag source, list<dag> results, list<dag> preds = [], + list<dag> supplemental_results = [], dag benefitAdded = (addBenefit 0)> { dag sourcePattern = source; // Result patterns. Each result pattern is expected to replace one result @@ -103,6 +104,11 @@ // matched in source pattern and places further constraints on them as a // whole. list<dag> constraints = preds; + // Optional patterns that are executed after the result patterns. Similar to + // auxiliary patterns, they are not used for replacement. These patterns can + // be used to generate additional code after the result patterns, e.g. to + // copy the attributes from the source op to the result ops. + list<dag> supplementalPatterns = supplemental_results; // The delta value added to the default benefit value. The default value is // the number of ops in the source pattern. The rule with the highest final // benefit value will be applied first if there are multiple rules matches. @@ -112,8 +118,9 @@ // Form of a pattern which produces a single result. class Pat<dag pattern, dag result, list<dag> preds = [], + list<dag> supplemental_results = [], dag benefitAdded = (addBenefit 0)> : - Pattern<pattern, [result], preds, benefitAdded>; + Pattern<pattern, [result], preds, supplemental_results, benefitAdded>; // Native code call wrapper. This allows invoking an arbitrary C++ expression // to create an op operand/attribute or replace an op result. diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -482,6 +482,14 @@ // Returns the constraints. std::vector<AppliedConstraint> getConstraints() const; + // Returns the number of supplemental patterns, which are executed after the + // result patterns and, like auxiliary patterns, replace no source results.
+ int getNumSupplementalPatterns() const; + + // Returns the DAG tree root node of the `index`-th supplemental result + // pattern. + DagNode getSupplementalPattern(unsigned index) const; + // Returns the benefit score of the pattern. int getBenefit() const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -675,6 +675,16 @@ return ret; } +int Pattern::getNumSupplementalPatterns() const { + auto *results = def.getValueAsListInit("supplementalPatterns"); + return results->size(); +} + +DagNode Pattern::getSupplementalPattern(unsigned index) const { + auto *results = def.getValueAsListInit("supplementalPatterns"); + return DagNode(cast<llvm::DagInit>(results->getElement(index))); +} + int Pattern::getBenefit() const { // The initial benefit value is a heuristic with number of ops in the source // pattern. diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1096,6 +1096,20 @@ os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; } + // Process supplemental patterns. Like auxiliary patterns they never replace + // results of the matched op, so each one is given a distinct, strictly + // negative result index (-N, ..., -1) instead of an index into those results. + int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); + for (int i = 0, offset = -numSupplementalPatterns; + i < numSupplementalPatterns; ++i) { + DagNode resultTree = pattern.getSupplementalPattern(i); + auto val = handleResultPattern(resultTree, offset++, 0); + // A native code call without return values is emitted as a statement. + if (resultTree.isNativeCodeCall() && + resultTree.getNumReturnsOfNativeCode() == 0) + os << val << ";\n"; + } + LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); }