diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -427,19 +427,16 @@ ``` }]; - let arguments = (ins OptionalAttr:$rootKind, - Confined:$benefit, + let arguments = (ins Confined:$benefit, OptionalAttr:$sym_name); let regions = (region SizedRegion<1>:$body); let assemblyFormat = [{ - ($sym_name^)? `:` `benefit` `(` $benefit `)` - (`,` `root` `(` $rootKind^ `)`)? attr-dict-with-keyword $body + ($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $body }]; let builders = [ - OpBuilder<(ins CArg<"Optional", "llvm::None">:$rootKind, - CArg<"Optional", "1">:$benefit, - CArg<"Optional", "llvm::None">:$name)>, + OpBuilder<(ins CArg<"Optional", "1">:$benefit, + CArg<"Optional", "llvm::None">:$name)>, ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// @@ -451,10 +448,6 @@ /// Returns the rewrite operation of this pattern. RewriteOp getRewriter(); - - /// Return the root operation kind that this pattern matches, or None if - /// there isn't a specific root. - Optional getRootKind(); }]; } @@ -579,19 +572,22 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [ Terminator, HasParent<"pdl::PatternOp">, NoTerminator, NoRegionArguments, - SingleBlock + SingleBlock, AttrSizedOperandSegments ]> { let summary = "Specify the rewrite of a matched pattern"; let description = [{ `pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify - the main rewrite of a `pdl.pattern`, on the specified root operation. The + the main rewrite of a `pdl.pattern`, on the optional root operation. The rewrite is specified either via a string name (`name`) to a native rewrite function, or via the region body. The rewrite region, if specified, must contain a single block. 
If the rewrite is external it functions similarly to `pdl.apply_native_rewrite`, and takes a set of constant parameters and a set of additional positional values defined within the matcher as arguments. If the rewrite is external, the root operation is - passed to the native function as the first argument. + passed to the native function as the leading arguments. The root operation, + if provided, specifies which operation is used as the root of the search + tree. If it is omitted, the pdl_interp lowering will automatically select + the best root of the search tree among all operations in the pattern. Example: @@ -599,23 +595,31 @@ // Specify an external rewrite function: pdl.rewrite %root with "myExternalRewriter"(%value : !pdl.value) - // Specify the rewrite inline using PDL: + // Specify a rewrite inline using PDL with the given root: pdl.rewrite %root { %op = pdl.operation "foo.op"(%arg0, %arg1) pdl.replace %root with %op } + + // Specify a rewrite inline using PDL, automatically selecting root: + pdl.rewrite { + %op1 = pdl.operation "foo.op"(%arg0, %arg1) + %op2 = pdl.operation "bar.op"(%arg0, %arg1) + pdl.replace %root1 with %op1 + pdl.replace %root2 with %op2 + } ``` }]; - let arguments = (ins PDL_Operation:$root, + let arguments = (ins Optional:$root, OptionalAttr:$name, Variadic:$externalArgs, OptionalAttr:$externalConstParams); let regions = (region AnyRegion:$body); let assemblyFormat = [{ - $root (`with` $name^ ($externalConstParams^)? - (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? - ($body^)? + ($root^)? (`with` $name^ ($externalConstParams^)? + (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? + ($body^)? 
attr-dict-with-keyword }]; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -43,27 +43,27 @@ using ValueMapScope = llvm::ScopedHashTableScope; /// Generate interpreter operations for the tree rooted at the given matcher - /// node. - Block *generateMatcher(MatcherNode &node); + /// node, in the specified region. + Block *generateMatcher(MatcherNode &node, Region ®ion); - /// Get or create an access to the provided positional value within the - /// current block. - Value getValueAt(Block *cur, Position *pos); + /// Get or create an access to the provided positional value in the current + /// block. This operation may mutate the provided block pointer if nested + /// regions (i.e., pdl_interp.iterate) are required. + Value getValueAt(Block *¤tBlock, Position *pos); - /// Create an interpreter predicate operation, branching to the provided true - /// and false destinations. - void generatePredicate(Block *currentBlock, Qualifier *question, - Qualifier *answer, Value val, Block *trueDest, - Block *falseDest); + /// Create the interpreter predicate operations. This operation may mutate the + /// provided current block pointer if nested regions (iterates) are required. + void generate(BoolNode *boolNode, Block *¤tBlock, Value val); - /// Create an interpreter switch predicate operation, with a provided default - /// and several case destinations. - void generateSwitch(SwitchNode *switchNode, Block *currentBlock, - Qualifier *question, Value val, Block *defaultDest); + /// Create the interpreter switch / predicate operations, with several case + /// destinations. This operation never mutates the provided current block + /// pointer, because the switch operation does not need Values beyond `val`. 
+ void generate(SwitchNode *switchNode, Block *currentBlock, Value val); - /// Create the interpreter operations to record a successful pattern match. - void generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern); + /// Create the interpreter operations to record a successful pattern match + /// using the contained root operation. This operation may mutate the current + /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. + void generate(SuccessNode *successNode, Block *¤tBlock); /// Generate a rewriter function for the given pattern operation, and returns /// a reference to that function. @@ -156,7 +156,8 @@ // Generate a root matcher node from the provided PDL module. std::unique_ptr root = MatcherNode::generateMatcherTree( module, predicateBuilder, valueToPosition); - Block *firstMatcherBlock = generateMatcher(*root); + Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); + assert(failureBlockStack.empty() && "Failed to empty the stack"); // After generation, merged the first matched block into the entry. matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), @@ -164,9 +165,9 @@ firstMatcherBlock->erase(); } -Block *PatternLowering::generateMatcher(MatcherNode &node) { +Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { // Push a new scope for the values used by this matcher. - Block *block = matcherFunc.addBlock(); + Block *const block = ®ion.emplaceBlock(); ValueMapScope scope(values); // If this is the return node, simply insert the corresponding interpreter @@ -177,66 +178,101 @@ return block; } - // If this node contains a position, get the corresponding value for this - // block. - Position *position = node.getPosition(); - Value val = position ? getValueAt(block, position) : Value(); - // Get the next block in the match sequence. 
+ // This is intentionally executed first, before we get the value for the + // position associated with the node, so that we preserve an "there exist" + // semantics: if getting a value requires an upward traversal (going from a + // value to its consumers), we want to perform the check on all the consumers + // before we pass control to the failure node. std::unique_ptr &failureNode = node.getFailureNode(); - Block *nextBlock; + Block *failureBlock; if (failureNode) { - nextBlock = generateMatcher(*failureNode); - failureBlockStack.push_back(nextBlock); + failureBlock = generateMatcher(*failureNode, region); + failureBlockStack.push_back(failureBlock); } else { assert(!failureBlockStack.empty() && "expected valid failure block"); - nextBlock = failureBlockStack.back(); + failureBlock = failureBlockStack.back(); } + // If this node contains a position, get the corresponding value for this + // block. + Block *currentBlock = block; + Position *position = node.getPosition(); + Value val = position ? getValueAt(currentBlock, position) : Value(); + // If this value corresponds to an operation, record that we are going to use // its location as part of a fused location. bool isOperationValue = val && val.getType().isa(); if (isOperationValue) locOps.insert(val); - // Generate code for a boolean predicate node. - if (auto *boolNode = dyn_cast(&node)) { - auto *child = generateMatcher(*boolNode->getSuccessNode()); - generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, - child, nextBlock); - - // Generate code for a switch node. - } else if (auto *switchNode = dyn_cast(&node)) { - generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); - - // Generate code for a success node. - } else if (auto *successNode = dyn_cast(&node)) { - generateRecordMatch(block, nextBlock, successNode->getPattern()); + // Dispatch to the correct method based on derived node type. 
+ TypeSwitch(&node) + .Case( + [&](auto *derivedNode) { generate(derivedNode, currentBlock, val); }) + .Case([&](SuccessNode *successNode) { + generate(successNode, currentBlock); + }); + + // Pop all the failure blocks that were inserted due to nesting of + // pdl_interp.iterate. + while (failureBlockStack.back() != failureBlock) { + failureBlockStack.pop_back(); + assert(!failureBlockStack.empty()); } + // Pop the new failure block. if (failureNode) failureBlockStack.pop_back(); + if (isOperationValue) locOps.remove(val); + return block; } -Value PatternLowering::getValueAt(Block *cur, Position *pos) { +Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { if (Value val = values.lookup(pos)) return val; // Get the value for the parent position. - Value parentVal = getValueAt(cur, pos->getParent()); + Value parentVal = getValueAt(currentBlock, pos->getParent()); // TODO: Use a location from the position. Location loc = parentVal.getLoc(); - builder.setInsertionPointToEnd(cur); + builder.setInsertionPointToEnd(currentBlock); Value value; switch (pos->getKind()) { - case Predicates::OperationPos: - value = builder.create( - loc, builder.getType(), parentVal); + case Predicates::OperationPos: { + auto *operationPos = cast(pos); + if (operationPos->isUpward()) { + // The first operation retrieves the users. + Type operationTy = builder.getType(); + value = builder.create( + loc, pdl::RangeType::get(operationTy), parentVal, + operationPos->getIndex()); + + // The second operation iterates over them. + assert(!failureBlockStack.empty() && "expected valid failure block"); + auto foreach = pdl_interp::ForEachOp::create(builder, loc, value, + failureBlockStack.back()); + value = foreach.getLoopVariable(); + + // The third operation is the continuation to the next iterate. 
+ Block *continueBlock = &foreach.region().emplaceBlock(); + builder.setInsertionPointToEnd(continueBlock); + builder.create(loc); + failureBlockStack.push_back(continueBlock); + + // Update the current block pointer. + currentBlock = &foreach.region().front(); + builder.setInsertionPointToEnd(currentBlock); + } else { + value = builder.create( + loc, builder.getType(), parentVal); + } break; + } case Predicates::OperandPos: { auto *operandPos = cast(pos); value = builder.create( @@ -285,41 +321,59 @@ llvm_unreachable("Generating unknown Position getter"); break; } + values.insert(pos, value); return value; } -void PatternLowering::generatePredicate(Block *currentBlock, - Qualifier *question, Qualifier *answer, - Value val, Block *trueDest, - Block *falseDest) { - builder.setInsertionPointToEnd(currentBlock); +void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, + Value val) { Location loc = val.getLoc(); + Qualifier *question = boolNode->getQuestion(); + Qualifier *answer = boolNode->getAnswer(); + Region *region = currentBlock->getParent(); + + // Execute the getValue queries first, so that we create success + // matcher in the correct (possibly nested) region. + SmallVector args; + if (auto *equalToQuestion = dyn_cast(question)) + args = {getValueAt(currentBlock, equalToQuestion->getValue())}; + else if (auto *cstQuestion = dyn_cast(question)) + for (Position *position : std::get<1>(cstQuestion->getValue())) + args.push_back(getValueAt(currentBlock, position)); + + // Generate the matcher in the current (potentially nested) region + // and get the failure successor. + Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); + Block *failure = failureBlockStack.back(); + + // Finally, create the predicate. 
+ builder.setInsertionPointToEnd(currentBlock); Predicates::Kind kind = question->getKind(); switch (kind) { case Predicates::IsNotNullQuestion: - builder.create(loc, val, trueDest, falseDest); + builder.create(loc, val, success, failure); break; case Predicates::OperationNameQuestion: { auto *opNameAnswer = cast(answer); builder.create( - loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); + loc, val, opNameAnswer->getValue().getStringRef(), success, failure); break; } case Predicates::TypeQuestion: { auto *ans = cast(answer); if (val.getType().isa()) builder.create( - loc, val, ans->getValue().cast(), trueDest, falseDest); + loc, val, ans->getValue().cast(), success, failure); else builder.create( - loc, val, ans->getValue().cast(), trueDest, falseDest); + loc, val, ans->getValue().cast(), success, failure); break; } case Predicates::AttributeQuestion: { auto *ans = cast(answer); builder.create(loc, val, ans->getValue(), - trueDest, falseDest); + success, failure); break; } case Predicates::OperandCountAtLeastQuestion: @@ -327,31 +381,27 @@ builder.create( loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, - trueDest, falseDest); + success, failure); break; case Predicates::ResultCountAtLeastQuestion: case Predicates::ResultCountQuestion: builder.create( loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, - trueDest, falseDest); + success, failure); break; case Predicates::EqualToQuestion: { - auto *equalToQuestion = cast(question); - builder.create( - loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), - trueDest, falseDest); + bool true_answer = isa(answer); + builder.create(loc, val, args.front(), + true_answer ? success : failure, + true_answer ? 
failure : success); break; } case Predicates::ConstraintQuestion: { - auto *cstQuestion = cast(question); - SmallVector args; - for (Position *position : std::get<1>(cstQuestion->getValue())) - args.push_back(getValueAt(currentBlock, position)); + auto value = cast(question)->getValue(); builder.create( - loc, std::get<0>(cstQuestion->getValue()), args, - std::get<2>(cstQuestion->getValue()).cast(), trueDest, - falseDest); + loc, std::get<0>(value), args, std::get<2>(value).cast(), + success, failure); break; } default: @@ -373,9 +423,12 @@ builder.create(val.getLoc(), val, values, defaultDest, blocks); } -void PatternLowering::generateSwitch(SwitchNode *switchNode, - Block *currentBlock, Qualifier *question, - Value val, Block *defaultDest) { +void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, + Value val) { + Qualifier *question = switchNode->getQuestion(); + Region *region = currentBlock->getParent(); + Block *defaultDest = failureBlockStack.back(); + // If the switch question is not an exact answer, i.e. for the `at_least` // cases, we generate a special block sequence. Predicates::Kind kind = question->getKind(); @@ -407,12 +460,26 @@ // ... // failureBlockStack.push_back(defaultDest); + Location loc = val.getLoc(); for (unsigned idx : sortedChildren) { auto &child = switchNode->getChild(idx); - Block *childBlock = generateMatcher(*child.second); + // TODO - can the matcher nest? 
probably yes, make sure this still works + Block *childBlock = generateMatcher(*child.second, *region); Block *predicateBlock = builder.createBlock(childBlock); - generatePredicate(predicateBlock, question, child.first, val, childBlock, - defaultDest); + builder.setInsertionPointToEnd(predicateBlock); + unsigned ans = cast(child.first)->getValue(); + switch (kind) { + case Predicates::OperandCountAtLeastQuestion: + builder.create( + loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + break; + case Predicates::ResultCountAtLeastQuestion: + builder.create( + loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + break; + default: + llvm_unreachable("Generating invalid AtLeast operation"); + } failureBlockStack.back() = predicateBlock; } Block *firstPredicateBlock = failureBlockStack.pop_back_val(); @@ -426,7 +493,7 @@ // switch. llvm::MapVector children; for (auto &it : switchNode->getChildren()) - children.insert({it.first, generateMatcher(*it.second)}); + children.insert({it.first, generateMatcher(*it.second, *region)}); builder.setInsertionPointToEnd(currentBlock); switch (question->getKind()) { @@ -455,8 +522,10 @@ } } -void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern) { +void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { + pdl::PatternOp pattern = successNode->getPattern(); + Value root = successNode->getRoot(); + // Generate a rewriter for the pattern this success node represents, and track // any values used from the match region. SmallVector usedMatchValues; @@ -478,14 +547,15 @@ // Grab the root kind if present. 
StringAttr rootKindAttr; - if (Optional rootKind = pattern.getRootKind()) - rootKindAttr = builder.getStringAttr(*rootKind); + if (pdl::OperationOp rootOp = root.getDefiningOp()) + if (Optional rootKind = rootOp.name()) + rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); builder.create( pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), - nextBlock); + failureBlockStack.back()); } SymbolRefAttr PatternLowering::generateRewriter( @@ -535,8 +605,10 @@ // method. pdl::RewriteOp rewriter = pattern.getRewriter(); if (StringAttr rewriteName = rewriter.nameAttr()) { + SmallVector args; + if (rewriter.root()) + args.push_back(mapRewriteValue(rewriter.root())); auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); - SmallVector args(1, mapRewriteValue(rewriter.root())); args.append(mappedArgs.begin(), mappedArgs.end()); builder.create( rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -65,8 +65,9 @@ // Answers. AttributeAnswer, - TrueAnswer, + FalseAnswer, OperationNameAnswer, + TrueAnswer, TypeAnswer, UnsignedAnswer, }; @@ -216,24 +217,40 @@ /// An operation position describes an operation node in the IR. Other position /// kinds are formed with respect to an operation position. -struct OperationPosition : public PredicateBase, - Predicates::OperationPos> { +struct OperationPosition + : public PredicateBase, unsigned>, + Predicates::OperationPos> { explicit OperationPosition(const KeyTy &key) : Base(key) { - parent = key.first; + parent = std::get<0>(key); + } + + /// Returns a hash suitable for the given keytype. 
+ static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); } /// Gets the root position. static OperationPosition *getRoot(StorageUniquer &uniquer) { - return Base::get(uniquer, nullptr, 0); + return Base::get(uniquer, nullptr, llvm::None, 0); } + /// Gets an operation position with the given parent. - static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { - return Base::get(uniquer, parent, parent->getOperationDepth() + 1); + /// The optional operand index indicates a downward traversal from operand to + /// the operation. + static OperationPosition *get(StorageUniquer &uniquer, Position *parent, + Optional operand = llvm::None) { + return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1); } + /// Returns if this operation position is upward, accepting an input. + bool isUpward() const { return std::get<1>(key).hasValue(); } + + /// Returns the operand index for an upward operation position. + unsigned getIndex() const { return *std::get<1>(key); } + /// Returns the depth of this position. - unsigned getDepth() const { return key.second; } + unsigned getDepth() const { return std::get<2>(key); } /// Returns if this operation position corresponds to the root. bool isRoot() const { return getDepth() == 0; } @@ -346,6 +363,12 @@ using Base::Base; }; +/// An Answer representing a boolean 'false' value. +struct FalseAnswer + : PredicateBase { + using Base::Base; +}; + /// An Answer representing a `Type` value. The value is stored as either a /// TypeAttr, or an ArrayAttr of TypeAttr. struct TypeAnswer : public PredicateBase(); registerParametricStorageType(); registerParametricStorageType(); + registerSingletonStorageType(); registerSingletonStorageType(); // Register the types of Answers with the uniquer. @@ -485,6 +509,14 @@ return OperationPosition::get(uniquer, p); } + /// Returns the operation position accepting the value at the given position. 
+ OperationPosition *getOperandAcceptingOp(Position *p, unsigned operand) { + assert((isa(p)) && + "expected result position"); + return OperationPosition::get(uniquer, p, operand); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, Identifier::get(name, ctx)); @@ -536,11 +568,16 @@ AttributeAnswer::get(uniquer, attr)}; } - /// Create a predicate comparing two values. + /// Create a predicate checking if two values are equal. Predicate getEqualTo(Position *pos) { return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; } + /// Create a predicate checking if two values are not equal. + Predicate getNotEqualTo(Position *pos) { + return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)}; + } + /// Create a predicate that applies a generic constraint. Predicate getConstraint(StringRef name, ArrayRef pos, Attribute params) { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h @@ -152,7 +152,7 @@ /// matched. This does not terminate the matcher, as there may be multiple /// successful matches. struct SuccessNode : public MatcherNode { - explicit SuccessNode(pdl::PatternOp pattern, + explicit SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode); /// Returns if the given matcher node is an instance of this class, used to @@ -164,10 +164,16 @@ /// Return the high level pattern operation that is matched with this node. pdl::PatternOp getPattern() const { return pattern; } + /// Return the chosen root of the pattern. + Value getRoot() const { return root; } + private: /// The high level pattern operation that was successfully matched with this /// node. pdl::PatternOp pattern; + + /// The chosen root of the pattern. 
+ Value root; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -7,12 +7,18 @@ //===----------------------------------------------------------------------===// #include "PredicateTree.h" +#include "mlir/Conversion/PDLToPDLInterp/RootOrdering.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "pdl-predicate-tree" using namespace mlir; using namespace mlir::pdl_to_pdl_interp; @@ -102,7 +108,8 @@ static void getTreePredicates(std::vector &predList, Value val, PredicateBuilder &builder, DenseMap &inputs, - OperationPosition *pos) { + OperationPosition *pos, + Optional ignoreOperand = llvm::None) { assert(val.getType().isa() && "expected operation"); pdl::OperationOp op = cast(val.getDefiningOp()); OperationPosition *opPos = cast(pos); @@ -158,6 +165,11 @@ bool isVariadic = operandIt.value().getType().isa(); foundVariableLength |= isVariadic; + // Ignore the specified operand, usually because this position was + // visited in an upward traversal via a nondeterministic choice. + if (ignoreOperand && *ignoreOperand == operandIt.index()) + continue; + Position *pos = foundVariableLength ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) @@ -300,15 +312,284 @@ } } +namespace { + +/// An op accepting a value at an optional index. +struct OpIndex { + Value parent; + Optional index; +}; + +/// The parent and operand index of each operation for each root, stored +/// as a nested map [root][operation]. 
+using ParentMaps = DenseMap>; + +} // namespace + +/// Given a pattern, determines the set of roots present in this pattern. +/// These are the operations whose results are not consumed by other operations. +static SmallVector detectRoots(pdl::PatternOp pattern) { + // First, collect all the operations that are used as operands + // to other operations. These are not roots by default. + DenseSet used; + for (auto operationOp : pattern.body().getOps()) + for (Value operand : operationOp.operands()) + TypeSwitch(operand.getDefiningOp()) + .Case([&used](auto resultOp) { + used.insert(resultOp.parent()); + }); + + // Remove the specified root from the use set, so that we can + // always select it as a root, even if it is used by other operations. + if (Value root = pattern.getRewriter().root()) + used.erase(root); + + // Finally, collect all the unused operations. + SmallVector roots; + for (Value operationOp : pattern.body().getOps()) + if (!used.contains(operationOp)) + roots.push_back(operationOp); + + return roots; +} + +/// Given a list of candidate roots, builds the cost graph for connecting them. +/// The graph is formed by traversing the DAG of operations starting from each +/// root and marking the depth of each connector value (operand). Then we join +/// the candidate roots based on the common connector values, taking the one +/// with the minimum depth. Along the way, we compute, for each candidate root, +/// a mapping from each operation (in the DAG underneath this root) to its +/// parent operation and the corresponding operand index. +static void buildCostGraph(ArrayRef roots, RootOrderingGraph &graph, + ParentMaps &parentMaps) { + + // A root of a value and its depth (distance from root to the value). + struct RootDepth { + Value root; + unsigned depth = 0; + }; + + // Map from candidate connector values to their roots and depths. Using a + // small vector with 1 entry because most values belong to a single root. 
+ llvm::MapVector> connectorsRootsDepths; + + // Perform a breadth-first traversal of the op DAG rooted at each root. + for (Value root : roots) { + // The entry of a queue. The entry consists of the following items: + // * the value in the DAG underneath the root; + // * the parent of the value; + // * the operand index of the value in its parent; + // * the depth of the visited value. + struct Entry { + Value value; + Value parent; + Optional index; + unsigned depth; + Entry(Value value, Value parent, Optional index, unsigned depth) + : value(value), parent(parent), index(index), depth(depth) {} + }; + + // The queue of visited values. A value may be present multiple times in + // the queue, for multiple parents. We only accept the first occurrence, + // which is guaranteed to have the lowest depth. + std::queue toVisit; + toVisit.emplace(root, Value(), 0, 0); + + // The map from value to its parent for the current root. + DenseMap &parentMap = parentMaps[root]; + + while (!toVisit.empty()) { + Entry entry = toVisit.front(); + toVisit.pop(); + // Skip if already visited. + if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) + continue; + + // Mark the root and depth of the value. + connectorsRootsDepths[entry.value].push_back({root, entry.depth}); + + // Traverse the operands of an operation and result ops. + // We intentionally do not traverse attributes and types, because those + // are expensive to join on. + TypeSwitch(entry.value.getDefiningOp()) + .Case([&](auto operationOp) { + for (auto p : llvm::enumerate(operationOp.operands())) + toVisit.emplace(p.value(), entry.value, p.index(), + entry.depth + 1); + }) + .Case([&](auto resultOp) { + toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(), + entry.depth); + }); + } + } + + // Now build the cost graph. + // This is simply a minimum over all depths for the target root. 
+ unsigned nextID = 0; + for (const auto &connectorRootsDepths : connectorsRootsDepths) { + Value value = connectorRootsDepths.first; + ArrayRef rootsDepths = connectorRootsDepths.second; + // If there is only one root for this value, this will not trigger + // any edges in the cost graph (a perf optimization). + if (rootsDepths.size() == 1) + continue; + + for (const RootDepth &p : rootsDepths) { + for (const RootDepth &q : rootsDepths) { + if (&p == &q) + continue; + // insert or retrieve the property of edge from p to q + RootOrderingCost &cost = graph[q.root][p.root]; + if (!cost.connector /* new edge */ || cost.cost.first > q.depth) { + if (!cost.connector) + cost.cost.second = nextID++; + cost.cost.first = q.depth; + cost.connector = value; + } + } + } + } + + assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && + "the pattern contains a candidate root disconnected from the others"); +} + /// Given a pattern operation, build the set of matcher predicates necessary to /// match this pattern. -static void buildPredicateList(pdl::PatternOp pattern, - PredicateBuilder &builder, - std::vector &predList, - DenseMap &valueToPosition) { - getTreePredicates(predList, pattern.getRewriter().root(), builder, - valueToPosition, builder.getRoot()); +static Value buildPredicateList(pdl::PatternOp pattern, + PredicateBuilder &builder, + std::vector &predList, + DenseMap &valueToPosition) { + SmallVector roots = detectRoots(pattern); + + // Build the root ordering graph and compute the parent maps. 
+ RootOrderingGraph graph; + ParentMaps parentMaps; + buildCostGraph(roots, graph, parentMaps); + LLVM_DEBUG({ + llvm::dbgs() << "Graph:\n"; + for (auto &target : graph) { + llvm::dbgs() << " * " << target.first << "\n"; + for (auto &source : target.second) { + RootOrderingCost c = source.second; + llvm::dbgs() << " <- " << source.first << ": " << c.cost.first + << ":" << c.cost.second << " via " << c.connector.getLoc() + << "\n"; + } + } + }); + + // Solve the optimal branching problem for each candidate root, or use the + // provided one. + Value bestRoot = pattern.getRewriter().root(); + OptimalBranching::EdgeList bestEdges; + if (!bestRoot) { + unsigned bestCost = 0; + LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); + for (Value root : roots) { + OptimalBranching solver(graph, root); + unsigned cost = solver.solve(); + LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); + if (!bestRoot || bestCost > cost) { + bestCost = cost; + bestRoot = root; + bestEdges = solver.preOrderTraversal(roots); + } + } + } else { + OptimalBranching solver(graph, bestRoot); + solver.solve(); + bestEdges = solver.preOrderTraversal(roots); + } + + LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); + + // The best root is the starting point for the traversal. Get the tree + // predicates for the DAG rooted at bestRoot. + getTreePredicates(predList, bestRoot, builder, valueToPosition, + builder.getRoot()); + + // Traverse the selected optimal branching. For all edges in order, traverse + // up starting from the connector, until the candidate root is reached, and + // call getTreePredicates at every node along the way. + for (const std::pair &edge : bestEdges) { + Value target = edge.first; + Value source = edge.second; + + // Check if we already visited the target root. 
This happens in two cases: + // 1) the initial root (bestRoot); + // 2) a root that is dominated by (contained in the subtree rooted at) an + // already visited root. + if (valueToPosition.count(target)) + continue; + + // Determine the connector. + Value connector = graph[target][source].connector; + assert(connector && "Invalid edge"); + LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); + DenseMap parentMap = parentMaps.lookup(target); + Position *pos = valueToPosition.lookup(connector); + assert(pos && "The value has not been traversed yet"); + bool first = true; + + // Traverse from the connector upwards towards the target root. + for (Value value = connector; value != target;) { + OpIndex opIndex = parentMap.lookup(value); + assert(opIndex.parent && "Missing parent"); + value = opIndex.parent; + + // Visit the node. + TypeSwitch(value.getDefiningOp()) + .Case([&](auto operationOp) { + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + OperationPosition *opPos = + builder.getOperandAcceptingOp(pos, *opIndex.index); + + // Guard against traversing back to where we came from. + if (first) { + Position *parent = pos->getParent(); + predList.emplace_back(opPos, builder.getNotEqualTo(parent)); + first = false; + } + + // Guard against duplicate upward visits. These are not possible, + // because if this value was already visited, it would have been + // cheaper to start the traversal at this value rather than at the + // `connector`, violating the optimality of our spanning tree. + bool inserted = valueToPosition.try_emplace(value, opPos).second; + // `inserted` is only read by the assert; keep NDEBUG builds warning-free. + (void)inserted; + assert(inserted && "Duplicate upward visit"); + + // Obtain the tree predicates at the current value + getTreePredicates(predList, value, builder, valueToPosition, opPos, + opIndex.index); + + // Update the position + pos = opPos; + }) + .Case([&](auto resultOp) { + // Traverse up an individual result.
+ auto *opPos = dyn_cast(pos); + assert(opPos && "Operations and results must be interleaved"); + pos = builder.getResult(opPos, *opIndex.index); + }) + .Case([&](auto resultOp) { + // Traverse up a group of results. + auto *opPos = dyn_cast(pos); + assert(opPos && "Operations and results must be interleaved"); + if (opIndex.index) + pos = builder.getResultGroup(opPos, opIndex.index, + /*isVariadic=*/false); + else + pos = builder.getAllResults(opPos); + }); + } + } + getNonTreePredicates(pattern, predList, builder, valueToPosition); + + return bestRoot; } //===----------------------------------------------------------------------===// @@ -382,9 +663,11 @@ /// This class wraps a set of ordered predicates that are used within a specific /// pattern operation. struct OrderedPredicateList { - OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} + OrderedPredicateList(pdl::PatternOp pattern, Value root) + : pattern(pattern), root(root) {} pdl::PatternOp pattern; + Value root; DenseSet predicates; }; } // end anonymous namespace @@ -421,7 +704,8 @@ std::vector::iterator end) { if (current == end) { // We've hit the end of a pattern, so create a successful result node. - node = std::make_unique(list.pattern, std::move(node)); + node = + std::make_unique(list.pattern, list.root, std::move(node)); // If the pattern doesn't contain this predicate, ignore it. } else if (list.predicates.find(*current) == list.predicates.end()) { @@ -489,22 +773,37 @@ std::unique_ptr MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap &valueToPosition) { - // Collect the set of predicates contained within the pattern operations of - // the module. - SmallVector>, 16> - patternsAndPredicates; + // The set of predicates contained within the pattern operations of the + // module. + struct PatternPredicates { + /// A pattern. + pdl::PatternOp pattern; + + /// A root of the pattern chosen among the candidate roots in pdl.rewrite. 
+ Value root; + + /// The extracted predicates for this pattern and root. + std::vector predicates; + + PatternPredicates(pdl::PatternOp pattern, Value root, + std::vector predicates) + : pattern(pattern), root(root), predicates(std::move(predicates)) {} + }; + + SmallVector patternsAndPredicates; for (pdl::PatternOp pattern : module.getOps()) { std::vector predicateList; - buildPredicateList(pattern, builder, predicateList, valueToPosition); - patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); + Value root = + buildPredicateList(pattern, builder, predicateList, valueToPosition); + patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); } // Associate a pattern result with each unique predicate. DenseSet uniqued; for (auto &patternAndPredList : patternsAndPredicates) { - for (auto &predicate : patternAndPredList.second) { + for (auto &predicate : patternAndPredList.predicates) { auto it = uniqued.insert(predicate); - it.first->patternToAnswer.try_emplace(patternAndPredList.first, + it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, predicate.answer); } } @@ -513,8 +812,9 @@ std::vector lists; lists.reserve(patternsAndPredicates.size()); for (auto &patternAndPredList : patternsAndPredicates) { - OrderedPredicateList list(patternAndPredList.first); - for (auto &predicate : patternAndPredList.second) { + OrderedPredicateList list(patternAndPredList.pattern, + patternAndPredList.root); + for (auto &predicate : patternAndPredList.predicates) { OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); list.predicates.insert(orderedPredicate); @@ -580,11 +880,11 @@ // SuccessNode //===----------------------------------------------------------------------===// -SuccessNode::SuccessNode(pdl::PatternOp pattern, +SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode) : MatcherNode(TypeID::get(), /*position=*/nullptr, /*question=*/nullptr, std::move(failureNode)), - 
pattern(pattern) {} + pattern(pattern), root(root) {} //===----------------------------------------------------------------------===// // SwitchNode diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -41,11 +41,7 @@ Operation *user = use.getOwner(); if (user->getBlock() != matcherBlock) continue; - if (isa(user)) - return true; - // Only the first operand of RewriteOp may be bound to, i.e. the root - // operation of the pattern. - if (isa(user) && use.getOperandNumber() == 0) + if (isa(user)) return true; // A result by itself is not binding, it must also be bound. if (isa(user) && @@ -280,11 +276,8 @@ } void PatternOp::build(OpBuilder &builder, OperationState &state, - Optional rootKind, Optional benefit, - Optional name) { - build(builder, state, - rootKind ? builder.getStringAttr(*rootKind) : StringAttr(), - builder.getI16IntegerAttr(benefit ? *benefit : 0), + Optional benefit, Optional name) { + build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0), name ? builder.getStringAttr(*name) : StringAttr()); state.regions[0]->emplaceBlock(); } @@ -294,13 +287,6 @@ return cast(body().front().getTerminator()); } -/// Return the root operation kind that this pattern matches, or None if -/// there isn't a specific root. 
-Optional PatternOp::getRootKind() { - OperationOp rootOp = cast(getRewriter().root().getDefiningOp()); - return rootOp.name(); -} - //===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -408,3 +408,101 @@ pdl.rewrite %apply with "rewriter" } } + + +// ----- + +// CHECK-LABEL: module @multi_root +module @multi_root { + // Check the lowering of a simple two-root pattern. + // This checks that we correctly generate the upward traversal to the second + // root (pdl_interp.get_users followed by pdl_interp.foreach) and break the + // tie between %root1 and %root2 in favor of %root1. + + // CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[ROOT1]] + // CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]] + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL1]] : !pdl.value at 0 + // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[OPS]] + // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]] + // CHECK-DAG: %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP1]] : !pdl.operation -> ^{{.*}}, ^[[FAIL:.*]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP2]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[VAL1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[VAL2]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation + // CHECK-DAG: pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[FAIL]] + + pdl.pattern @rewrite_multi_root : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + 
%op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1 with "rewriter"(%root2 : !pdl.operation) + } +} + + +// ----- + +// CHECK-LABEL: module @overlapping_roots +module @overlapping_roots { + // Check the lowering of a degenerate two-root pattern, where one root + // is in the subtree rooted at another. + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL]] + // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[OP]] + // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[OP]] + // CHECK-DAG: pdl_interp.is_not_null %[[VAL]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value + + pdl.pattern @rewrite_overlapping_roots : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + %val = pdl.result 0 of %op + %root = pdl.operation(%val : !pdl.value) + pdl.rewrite with "rewriter"(%root : !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @force_overlapped_root +module @force_overlapped_root { + // Check the lowering of a degenerate two-root pattern, where one root + // is in the subtree rooted at another, and we are forced to use this + // root as the root of the search tree. 
+ + // CHECK: func @matcher(%[[OP:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_result 0 of %[[OP]] + // CHECK-DAG: pdl_interp.check_operand_count of %[[OP]] is 2 + // CHECK-DAG: pdl_interp.check_result_count of %[[OP]] is 1 + // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[OP]] + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value + // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[OP]] + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value + // CHECK-DAG: %[[ROOTS:.*]] = pdl_interp.get_users of %[[VAL]] : !pdl.value at 0 + // CHECK-DAG: pdl_interp.foreach %[[ROOT:.*]] : !pdl.operation in %[[ROOTS]] + // CHECK-DAG: pdl_interp.is_not_null %[[ROOT]] : !pdl.operation + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 1 + + pdl.pattern @rewrite_forced_overlapped_root : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + %val = pdl.result 0 of %op + %root = pdl.operation(%val : !pdl.value) + pdl.rewrite %op with "rewriter"(%root : !pdl.operation) + } +} diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -212,7 +212,9 @@ %op = pdl.operation "foo.op" // expected-error@below {{expected rewrite region to be non-empty if external name is not specified}} - "pdl.rewrite"(%op) ({}) : (!pdl.operation) -> () + "pdl.rewrite"(%op) ({}) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- @@ -223,7 +225,9 @@ // expected-error@below {{expected no external arguments when the rewrite is specified inline}} "pdl.rewrite"(%op, %op) ({ ^bb1: - }) : (!pdl.operation, !pdl.operation) -> () + }) { + operand_segment_sizes = dense<1> : vector<2xi32> + }: (!pdl.operation, !pdl.operation) -> () } // ----- @@ -234,7 +238,9 @@ // 
expected-error@below {{expected no external constant parameters when the rewrite is specified inline}} "pdl.rewrite"(%op) ({ ^bb1: - }) {externalConstParams = []} : (!pdl.operation) -> () + }) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32>, + externalConstParams = []} : (!pdl.operation) -> () } // ----- @@ -245,7 +251,10 @@ // expected-error@below {{expected rewrite region to be empty when rewrite is external}} "pdl.rewrite"(%op) ({ ^bb1: - }) {name = "foo"} : (!pdl.operation) -> () + }) { + name = "foo", + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -42,6 +42,36 @@ // ----- +pdl.pattern @rewrite_multi_root_optimal : benefit(2) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation) +} + +// ----- + +pdl.pattern @rewrite_multi_root_forced : benefit(2) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation) +} + +// ----- + // Check that the result type of an operation within a rewrite can be inferred // from a pdl.replace. 
pdl.pattern @infer_type_from_operation_replace : benefit(1) {