diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -451,10 +451,6 @@ /// Returns the rewrite operation of this pattern. RewriteOp getRewriter(); - - /// Return the root operation kind that this pattern matches, or None if - /// there isn't a specific root. - Optional getRootKind(); }]; } @@ -579,19 +575,19 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [ Terminator, HasParent<"pdl::PatternOp">, NoTerminator, NoRegionArguments, - SingleBlock + SingleBlock, AttrSizedOperandSegments ]> { let summary = "Specify the rewrite of a matched pattern"; let description = [{ `pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify - the main rewrite of a `pdl.pattern`, on the specified root operation. The + the main rewrite of a `pdl.pattern`, on the specified root operations. The rewrite is specified either via a string name (`name`) to a native rewrite function, or via the region body. The rewrite region, if specified, must contain a single block. If the rewrite is external it functions similarly to `pdl.apply_native_rewrite`, and takes a set of constant parameters and a set of additional positional values defined within the - matcher as arguments. If the rewrite is external, the root operation is - passed to the native function as the first argument. + matcher as arguments. If the rewrite is external, the root operations are + passed to the native function as the leading arguments. Example: @@ -599,21 +595,29 @@ // Specify an external rewrite function: pdl.rewrite %root with "myExternalRewriter"(%value : !pdl.value) - // Specify the rewrite inline using PDL: + // Specify a single-root rewrite inline using PDL: pdl.rewrite %root { %op = pdl.operation "foo.op"(%arg0, %arg1) pdl.replace %root with %op } + + // Specify a multi-root rewrite inline using PDL: + pdl.rewrite %root1, %root2 { + %op1 = pdl.operation "foo.op"(%arg0, %arg1) + %op2 = pdl.operation "bar.op"(%arg0, %arg1) + pdl.replace %root1 with %op1 + pdl.replace %root2 with %op2 + } ``` }]; - let arguments = (ins PDL_Operation:$root, + let arguments = (ins Variadic:$roots, OptionalAttr:$name, Variadic:$externalArgs, OptionalAttr:$externalConstParams); let regions = (region AnyRegion:$body); let assemblyFormat = [{ - $root (`with` $name^ ($externalConstParams^)? + $roots (`with` $name^ ($externalConstParams^)? (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? ($body^)? attr-dict-with-keyword diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -61,9 +61,10 @@ void generateSwitch(SwitchNode *switchNode, Block *currentBlock, Qualifier *question, Value val, Block *defaultDest); - /// Create the interpreter operations to record a successful pattern match. + /// Create the interpreter operations to record a successful pattern match + /// using the specified root operation. void generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern); + pdl::PatternOp pattern, Value root); /// Generate a rewriter function for the given pattern operation, and returns /// a reference to that function. @@ -211,7 +212,8 @@ // Generate code for a success node. } else if (auto *successNode = dyn_cast(&node)) { - generateRecordMatch(block, nextBlock, successNode->getPattern()); + generateRecordMatch(block, nextBlock, successNode->getPattern(), + successNode->getRoot()); } if (failureNode) @@ -233,10 +235,23 @@ builder.setInsertionPointToEnd(cur); Value value; switch (pos->getKind()) { - case Predicates::OperationPos: - value = builder.create( - loc, builder.getType(), parentVal); + case Predicates::OperationPos: { + auto *operationPos = cast(pos); + if (operationPos->isUpward()) { + // Create two PDLInterp operations - one for retrieving the accepting ops + // and one for making an iterative choice among them. + Type operationTy = builder.getType(); + value = builder.create( + loc, pdl::RangeType::get(operationTy), parentVal, + operationPos->getIndex()); + value = builder.create( + loc, builder.getType(), value); + } else { + value = builder.create( + loc, builder.getType(), parentVal); + } break; + } case Predicates::OperandPos: { auto *operandPos = cast(pos); value = builder.create( @@ -338,9 +353,10 @@ break; case Predicates::EqualToQuestion: { auto *equalToQuestion = cast(question); + bool nominal = isa(answer); builder.create( loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), - trueDest, falseDest); + nominal ? trueDest : falseDest, nominal ? falseDest : trueDest); break; } case Predicates::ConstraintQuestion: { @@ -456,7 +472,7 @@ } void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern) { + pdl::PatternOp pattern, Value root) { // Generate a rewriter for the pattern this success node represents, and track // any values used from the match region. SmallVector usedMatchValues; @@ -477,8 +493,10 @@ generatedOpsAttr = builder.getStrArrayAttr(generatedOps); // Grab the root kind if present. + pdl::OperationOp rootOp = root.getDefiningOp(); + assert(rootOp); StringAttr rootKindAttr; - if (Optional rootKind = pattern.getRootKind()) + if (Optional rootKind = rootOp.name()) rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); @@ -535,8 +553,9 @@ // method. pdl::RewriteOp rewriter = pattern.getRewriter(); if (StringAttr rewriteName = rewriter.nameAttr()) { + auto mappedRoots = llvm::map_range(rewriter.roots(), mapRewriteValue); auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); - SmallVector args(1, mapRewriteValue(rewriter.root())); + auto args = llvm::to_vector<4>(mappedRoots); args.append(mappedArgs.begin(), mappedArgs.end()); builder.create( rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -65,8 +65,9 @@ // Answers. AttributeAnswer, - TrueAnswer, + FalseAnswer, OperationNameAnswer, + TrueAnswer, TypeAnswer, UnsignedAnswer, }; @@ -216,24 +217,40 @@ /// An operation position describes an operation node in the IR. Other position /// kinds are formed with respect to an operation position. -struct OperationPosition : public PredicateBase, - Predicates::OperationPos> { +struct OperationPosition + : public PredicateBase, unsigned>, + Predicates::OperationPos> { explicit OperationPosition(const KeyTy &key) : Base(key) { - parent = key.first; + parent = std::get<0>(key); + } + + /// Returns a hash suitable for the given keytype. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); } /// Gets the root position. static OperationPosition *getRoot(StorageUniquer &uniquer) { - return Base::get(uniquer, nullptr, 0); + return Base::get(uniquer, nullptr, llvm::None, 0); } + /// Gets an operation position with the given parent. - static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { - return Base::get(uniquer, parent, parent->getOperationDepth() + 1); + /// The optional operand index indicates a downward traversal from operand to + /// the operation. + static OperationPosition *get(StorageUniquer &uniquer, Position *parent, + Optional operand = llvm::None) { + return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1); } + /// Returns if this operation position is upward, accepting an input. + bool isUpward() const { return std::get<1>(key).hasValue(); } + + /// Returns the operand index for an upward operation position. + unsigned getIndex() const { return *std::get<1>(key); } + /// Returns the depth of this position. - unsigned getDepth() const { return key.second; } + unsigned getDepth() const { return std::get<2>(key); } /// Returns if this operation position corresponds to the root. bool isRoot() const { return getDepth() == 0; } @@ -346,6 +363,12 @@ using Base::Base; }; +/// An Answer representing a boolean 'false' value. +struct FalseAnswer + : PredicateBase { + using Base::Base; +}; + /// An Answer representing a `Type` value. The value is stored as either a /// TypeAttr, or an ArrayAttr of TypeAttr. struct TypeAnswer : public PredicateBase(); registerParametricStorageType(); registerParametricStorageType(); + registerSingletonStorageType(); registerSingletonStorageType(); // Register the types of Answers with the uniquer. @@ -485,6 +509,14 @@ return OperationPosition::get(uniquer, p); } + /// Returns the operation position accepting the value at the given position. + OperationPosition *getOperandAcceptingOp(Position *p, unsigned operand) { + assert((isa(p)) && + "expected result position"); + return OperationPosition::get(uniquer, p, operand); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, Identifier::get(name, ctx)); @@ -536,11 +568,16 @@ AttributeAnswer::get(uniquer, attr)}; } - /// Create a predicate comparing two values. + /// Create a predicate checking if two values are equal. Predicate getEqualTo(Position *pos) { return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; } + /// Create a predicate checking if two values are not equal. + Predicate getNotEqualTo(Position *pos) { + return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)}; + } + /// Create a predicate that applies a generic constraint. Predicate getConstraint(StringRef name, ArrayRef pos, Attribute params) { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h @@ -152,7 +152,7 @@ /// matched. This does not terminate the matcher, as there may be multiple /// successful matches. struct SuccessNode : public MatcherNode { - explicit SuccessNode(pdl::PatternOp pattern, + explicit SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode); /// Returns if the given matcher node is an instance of this class, used to @@ -164,10 +164,16 @@ /// Return the high level pattern operation that is matched with this node. pdl::PatternOp getPattern() const { return pattern; } + /// Return the chosen root of the pattern. + Value getRoot() const { return root; } + private: /// The high level pattern operation that was successfully matched with this /// node. pdl::PatternOp pattern; + + /// The chosen root of the pattern. + Value root; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -7,12 +7,18 @@ //===----------------------------------------------------------------------===// #include "PredicateTree.h" +#include "mlir/Conversion/PDLToPDLInterp/RootOrdering.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "pdl-predicate-tree" using namespace mlir; using namespace mlir::pdl_to_pdl_interp; @@ -102,7 +108,8 @@ static void getTreePredicates(std::vector &predList, Value val, PredicateBuilder &builder, DenseMap &inputs, - OperationPosition *pos) { + OperationPosition *pos, + Optional ignoreOperand = llvm::None) { assert(val.getType().isa() && "expected operation"); pdl::OperationOp op = cast(val.getDefiningOp()); OperationPosition *opPos = cast(pos); @@ -158,6 +165,11 @@ bool isVariadic = operandIt.value().getType().isa(); foundVariableLength |= isVariadic; + // Ignore the specified operand, usually because this position was + // visited in an upward traversal via a nondeterministic choice. + if (ignoreOperand && *ignoreOperand == operandIt.index()) + continue; + Position *pos = foundVariableLength ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) @@ -300,15 +312,249 @@ } } +namespace { + +/// An op accepting a value at an optional index. +struct OpIndex { + Value parent; + Optional index; +}; + +/// The parent and operand index of each operation for each root, stored +/// as a nested map [root][operation]. +using ParentMaps = DenseMap>; + +} // namespace + +/// Given a list of candidate roots, builds the cost graph for connecting them. +/// The graph is formed by traversing the DAG of operations starting from each +/// root and marking the depth of each connector value (operand). Then we join +/// the candidate roots based on the common connector values, taking the one +/// with the minimum depth. Along the way, we compute, for each candidate root, +/// a mapping from each operation (in the DAG underneath this root) to its +/// parent operation and the corresponding operand index. +static void buildCostGraph(ArrayRef roots, RootOrderingGraph &graph, + ParentMaps &parentMaps) { + + // A root of a value and its depth (distance from root to the value). + struct RootDepth { + Value root; + unsigned depth = 0; + }; + + // Map from candidate connector values to their roots and depths. Using a + // small vector with 1 entry because most values belong to a single root. + llvm::MapVector> connectorsRootsDepths; + + // Perform a breadth-first traversal of the op DAG rooted at each root. + for (Value root : roots) { + // The entry of a queue. The entry consists of the following items: + // * the value in the DAG underneath the root; + // * the parent of the value; + // * the operand index of the value in its parent; + // * the depth of the visited value. + struct Entry { + Value value; + Value parent; + Optional index; + unsigned depth; + Entry(Value value, Value parent, Optional index, unsigned depth) + : value(value), parent(parent), index(index), depth(depth) {} + }; + + // The queue of visited values. A value may be present multiple times in + // the queue, for multiple parents. We only accept the first occurrence, + // which is guaranteed to have the lowest depth. + std::queue toVisit; + toVisit.emplace(root, Value(), 0, 0); + + // The map from value to its parent for the current root. + DenseMap &parentMap = parentMaps[root]; + + while (!toVisit.empty()) { + Entry entry = toVisit.front(); + toVisit.pop(); + // Skip if already visited. + if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) + continue; + + // Mark the root and depth of the value. + connectorsRootsDepths[entry.value].push_back({root, entry.depth}); + + // Traverse the operands of an operation and result ops. + // We intentionally do not traverse attributes and types, because those + // are expensive to join on. + TypeSwitch(entry.value.getDefiningOp()) + .Case([&](auto operationOp) { + for (auto p : llvm::enumerate(operationOp.operands())) + toVisit.emplace(p.value(), entry.value, p.index(), + entry.depth + 1); + }) + .Case([&](auto resultOp) { + toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(), + entry.depth); + }); + } + } + + // Now build the cost graph. + // This is simply a minimum over all depths for the target root. + unsigned nextID = 0; + for (const auto &connectorRootsDepths : connectorsRootsDepths) { + Value value = connectorRootsDepths.first; + ArrayRef rootsDepths = connectorRootsDepths.second; + // If there is only one root for this value, this will not trigger + // any edges in the cost graph (a perf optimization). + if (rootsDepths.size() == 1) + continue; + + for (const RootDepth &p : rootsDepths) { + for (const RootDepth &q : rootsDepths) { + if (&p == &q) + continue; + // insert or retrieve the property of edge from p to q + RootOrderingCost &cost = graph[q.root][p.root]; + if (!cost.connector /* new edge */ || cost.cost.first > q.depth) { + if (!cost.connector) + cost.cost.second = nextID++; + cost.cost.first = q.depth; + cost.connector = value; + } + } + } + } + + assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && + "the pattern contains a candidate root disconnected from the others"); +} + /// Given a pattern operation, build the set of matcher predicates necessary to /// match this pattern. -static void buildPredicateList(pdl::PatternOp pattern, - PredicateBuilder &builder, - std::vector &predList, - DenseMap &valueToPosition) { - getTreePredicates(predList, pattern.getRewriter().root(), builder, - valueToPosition, builder.getRoot()); +static Value buildPredicateList(pdl::PatternOp pattern, + PredicateBuilder &builder, + std::vector &predList, + DenseMap &valueToPosition) { + // Build the root ordering graph and compute the parent maps. + auto roots = llvm::to_vector<4>(pattern.getRewriter().roots()); + RootOrderingGraph graph; + ParentMaps parentMaps; + buildCostGraph(roots, graph, parentMaps); + LLVM_DEBUG({ + llvm::dbgs() << "Graph:\n"; + for (auto &target : graph) { + llvm::dbgs() << " * " << target.first << "\n"; + for (auto &source : target.second) { + RootOrderingCost c = source.second; + llvm::dbgs() << " <- " << source.first << ": " << c.cost.first + << ":" << c.cost.second << " via " << c.connector.getLoc() + << "\n"; + } + } + }); + + // Solve the optimal branching problem for each candidate root. + unsigned bestCost = 0; + Value bestRoot; + OptimalBranching::EdgeList bestEdges; + LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); + for (Value root : roots) { + OptimalBranching solver(graph, root); + unsigned cost = solver.solve(); + LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); + if (!bestRoot || bestCost > cost) { + bestCost = cost; + bestRoot = root; + bestEdges = solver.preOrderTraversal(roots); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); + + // The best root is the starting point for the traversal. Get the tree + // predicates for the DAG rooted at bestRoot. + getTreePredicates(predList, bestRoot, builder, valueToPosition, + builder.getRoot()); + + // Traverse the selected optimal branching. For all edges in order, traverse + // up starting from the connector, until the candidate root is reached, and + // call getTreePredicates at every node along the way. + for (const std::pair &edge : bestEdges) { + Value target = edge.first; + Value source = edge.second; + + // Check if we already visited the target root. This happens in two cases: + // 1) the initial root (bestRoot); + // 2) a root that is dominated by (contained in the subtree rooted at) an + // already visited root. + if (valueToPosition.count(target)) + continue; + + // Determine the connector. + Value connector = graph[target][source].connector; + assert(connector && "Invalid edge"); + LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); + DenseMap parentMap = parentMaps.lookup(target); + Position *pos = valueToPosition.lookup(connector); + assert(pos && "The value has not been traversed yet"); + bool first = true; + + // Traverse from the connector upwards towards the target root. + for (Value value = connector; value != target;) { + OpIndex opIndex = parentMap.lookup(value); + assert(opIndex.parent && "Missing parent"); + value = opIndex.parent; + + // Visit the node. + TypeSwitch(value.getDefiningOp()) + .Case([&](auto operationOp) { + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + OperationPosition *opPos = + builder.getOperandAcceptingOp(pos, *opIndex.index); + + // Guard against traversing back to where we came from. + if (first) { + Position *parent = pos->getParent(); + predList.emplace_back(opPos, builder.getNotEqualTo(parent)); + first = false; + } + + // Guard against duplicate upward visits. These are not possible, + // because if this value was already visited, it would have been + // cheaper to start the traversal at this value rather than at the + // `connector`, violating the optimality of our spanning tree. + bool inserted = valueToPosition.try_emplace(value, opPos).second; + assert(inserted && "Duplicate upward visit"); + + // Obtain the tree predicates at the current value + getTreePredicates(predList, value, builder, valueToPosition, opPos, + opIndex.index); + + // Update the position + pos = opPos; + }) + .Case([&](auto resultOp) { + // Traverse up an individual result. + auto *opPos = dyn_cast(pos); + assert(opPos && "Operations and results must be interleaved"); + pos = builder.getResult(opPos, *opIndex.index); + }) + .Case([&](auto resultOp) { + // Traverse up a group of results. + auto *opPos = dyn_cast(pos); + assert(opPos && "Operations and results must be interleaved"); + if (opIndex.index) + pos = builder.getResultGroup(opPos, opIndex.index, + /*isVariadic=*/false); + else + pos = builder.getAllResults(opPos); + }); + } + } + getNonTreePredicates(pattern, predList, builder, valueToPosition); + + return bestRoot; } //===----------------------------------------------------------------------===// @@ -382,9 +628,11 @@ /// This class wraps a set of ordered predicates that are used within a specific /// pattern operation. struct OrderedPredicateList { - OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} + OrderedPredicateList(pdl::PatternOp pattern, Value root) + : pattern(pattern), root(root) {} pdl::PatternOp pattern; + Value root; DenseSet predicates; }; } // end anonymous namespace @@ -421,7 +669,8 @@ std::vector::iterator end) { if (current == end) { // We've hit the end of a pattern, so create a successful result node. - node = std::make_unique(list.pattern, std::move(node)); + node = + std::make_unique(list.pattern, list.root, std::move(node)); // If the pattern doesn't contain this predicate, ignore it. } else if (list.predicates.find(*current) == list.predicates.end()) { @@ -489,22 +738,37 @@ std::unique_ptr MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap &valueToPosition) { - // Collect the set of predicates contained within the pattern operations of - // the module. - SmallVector>, 16> - patternsAndPredicates; + // The set of predicates contained within the pattern operations of the + // module. + struct PatternPredicates { + /// A pattern. + pdl::PatternOp pattern; + + /// A root of the pattern chosen among the candidate roots in pdl.rewrite. + Value root; + + /// The extracted predicates for this pattern and root. + std::vector predicates; + + PatternPredicates(pdl::PatternOp pattern, Value root, + std::vector predicates) + : pattern(pattern), root(root), predicates(std::move(predicates)) {} + }; + + SmallVector patternsAndPredicates; for (pdl::PatternOp pattern : module.getOps()) { std::vector predicateList; - buildPredicateList(pattern, builder, predicateList, valueToPosition); - patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); + Value root = + buildPredicateList(pattern, builder, predicateList, valueToPosition); + patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); } // Associate a pattern result with each unique predicate. DenseSet uniqued; for (auto &patternAndPredList : patternsAndPredicates) { - for (auto &predicate : patternAndPredList.second) { + for (auto &predicate : patternAndPredList.predicates) { auto it = uniqued.insert(predicate); - it.first->patternToAnswer.try_emplace(patternAndPredList.first, + it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, predicate.answer); } } @@ -513,8 +777,9 @@ std::vector lists; lists.reserve(patternsAndPredicates.size()); for (auto &patternAndPredList : patternsAndPredicates) { - OrderedPredicateList list(patternAndPredList.first); - for (auto &predicate : patternAndPredList.second) { + OrderedPredicateList list(patternAndPredList.pattern, + patternAndPredList.root); + for (auto &predicate : patternAndPredList.predicates) { OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); list.predicates.insert(orderedPredicate); @@ -580,11 +845,11 @@ // SuccessNode //===----------------------------------------------------------------------===// -SuccessNode::SuccessNode(pdl::PatternOp pattern, +SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode) : MatcherNode(TypeID::get(), /*position=*/nullptr, /*question=*/nullptr, std::move(failureNode)), - pattern(pattern) {} + pattern(pattern), root(root) {} //===----------------------------------------------------------------------===// // SwitchNode diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -45,7 +45,8 @@ return true; // Only the first operand of RewriteOp may be bound to, i.e. the root // operation of the pattern. - if (isa(user) && use.getOperandNumber() == 0) + RewriteOp rewriteOp = dyn_cast(user); + if (rewriteOp && use.getOperandNumber() < rewriteOp.roots().size()) return true; // A result by itself is not binding, it must also be bound. if (isa(user) && @@ -294,13 +295,6 @@ return cast(body().front().getTerminator()); } -/// Return the root operation kind that this pattern matches, or None if -/// there isn't a specific root. -Optional PatternOp::getRootKind() { - OperationOp rootOp = cast(getRewriter().root().getDefiningOp()); - return rootOp.name(); -} - //===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -408,3 +408,69 @@ pdl.rewrite %apply with "rewriter" } } + + +// ----- + +// CHECK-LABEL: module @multi_root +module @multi_root { + // Check the lowering of a simple two-root pattern. + // This checks that we correctly generate the pdl_interp.choose_op operation + // and tie the break between %root1 and %root2 in favor of %root1. + + // CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[ROOT1]] + // CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]] + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_accepting_ops of %[[VAL1]] : !pdl.value at 0 + // CHECK-DAG: %[[ROOT2:.*]] = pdl_interp.choose_op from %[[OPS]] + // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]] + // CHECK-DAG: %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP1]] : !pdl.operation -> ^{{.*}}, ^[[FAIL:.*]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP2]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[VAL1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[VAL2]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation + // CHECK-DAG: pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[FAIL]] + + pdl.pattern @rewrite_multi_root : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1, %root2 with "rewriter" + } +} + + +// ----- + +// CHECK-LABEL: module @overlapping_roots +module @overlapping_roots { + // Check the lowering of a degenerate two-root pattern, where one root + // is in the subtree rooted at another. + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL]] + // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[OP]] + // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[OP]] + // CHECK-DAG: pdl_interp.is_not_null %[[VAL]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value + + pdl.pattern @rewrite_overlapping_roots : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + %val = pdl.result 0 of %op + %root = pdl.operation(%val : !pdl.value) + pdl.rewrite %root, %op with "rewriter" + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -212,7 +212,9 @@ %op = pdl.operation "foo.op" // expected-error@below {{expected rewrite region to be non-empty if external name is not specified}} - "pdl.rewrite"(%op) ({}) : (!pdl.operation) -> () + "pdl.rewrite"(%op) ({}) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- @@ -223,7 +225,9 @@ // expected-error@below {{expected no external arguments when the rewrite is specified inline}} "pdl.rewrite"(%op, %op) ({ ^bb1: - }) : (!pdl.operation, !pdl.operation) -> () + }) { + operand_segment_sizes = dense<1> : vector<2xi32> + }: (!pdl.operation, !pdl.operation) -> () } // ----- @@ -234,7 +238,9 @@ // expected-error@below {{expected no external constant parameters when the rewrite is specified inline}} "pdl.rewrite"(%op) ({ ^bb1: - }) {externalConstParams = []} : (!pdl.operation) -> () + }) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32>, + externalConstParams = []} : (!pdl.operation) -> () } // ----- @@ -245,7 +251,10 @@ // expected-error@below {{expected rewrite region to be empty when rewrite is external}} "pdl.rewrite"(%op) ({ ^bb1: - }) {name = "foo"} : (!pdl.operation) -> () + }) { + name = "foo", + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -42,6 +42,21 @@ // ----- +pdl.pattern @rewrite_multi_root : benefit(2) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1, %root2 with "rewriter"["I am param"](%input1 : !pdl.value) +} + +// ----- + // Check that the result type of an operation within a rewrite can be inferred // from a pdl.replace. pdl.pattern @infer_type_from_operation_replace : benefit(1) {