diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -427,19 +427,16 @@ ``` }]; - let arguments = (ins OptionalAttr:$rootKind, - Confined:$benefit, + let arguments = (ins Confined:$benefit, OptionalAttr:$sym_name); let regions = (region SizedRegion<1>:$body); let assemblyFormat = [{ - ($sym_name^)? `:` `benefit` `(` $benefit `)` - (`,` `root` `(` $rootKind^ `)`)? attr-dict-with-keyword $body + ($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $body }]; let builders = [ - OpBuilder<(ins CArg<"Optional", "llvm::None">:$rootKind, - CArg<"Optional", "1">:$benefit, - CArg<"Optional", "llvm::None">:$name)>, + OpBuilder<(ins CArg<"Optional", "1">:$benefit, + CArg<"Optional", "llvm::None">:$name)>, ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// @@ -451,10 +448,6 @@ /// Returns the rewrite operation of this pattern. RewriteOp getRewriter(); - - /// Return the root operation kind that this pattern matches, or None if - /// there isn't a specific root. - Optional getRootKind(); }]; } @@ -579,19 +572,25 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [ Terminator, HasParent<"pdl::PatternOp">, NoTerminator, NoRegionArguments, - SingleBlock + SingleBlock, AttrSizedOperandSegments ]> { let summary = "Specify the rewrite of a matched pattern"; let description = [{ `pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify - the main rewrite of a `pdl.pattern`, on the specified root operation. The + the main rewrite of a `pdl.pattern`, on the optional root operation. The rewrite is specified either via a string name (`name`) to a native rewrite function, or via the region body. The rewrite region, if specified, must contain a single block. 
If the rewrite is external it functions similarly to `pdl.apply_native_rewrite`, and takes a set of constant parameters and a set of additional positional values defined within the matcher as arguments. If the rewrite is external, the root operation is - passed to the native function as the first argument. + passed to the native function as the leading arguments. The root operation, + if provided, specifies the starting point in the pattern for the subgraph + isomorphism search. Pattern matching will proceed from this node downward + (towards the defining operation) or upward (towards the users) until all + the operations in the pattern have been matched. If the root is omitted, + the pdl_interp lowering will automatically select the best root of the + pdl.rewrite among all the operations in the pattern. Example: @@ -599,23 +598,31 @@ // Specify an external rewrite function: pdl.rewrite %root with "myExternalRewriter"(%value : !pdl.value) - // Specify the rewrite inline using PDL: + // Specify a rewrite inline using PDL with the given root: pdl.rewrite %root { %op = pdl.operation "foo.op"(%arg0, %arg1) pdl.replace %root with %op } + + // Specify a rewrite inline using PDL, automatically selecting root: + pdl.rewrite { + %op1 = pdl.operation "foo.op"(%arg0, %arg1) + %op2 = pdl.operation "bar.op"(%arg0, %arg1) + pdl.replace %root1 with %op1 + pdl.replace %root2 with %op2 + } ``` }]; - let arguments = (ins PDL_Operation:$root, + let arguments = (ins Optional:$root, OptionalAttr:$name, Variadic:$externalArgs, OptionalAttr:$externalConstParams); let regions = (region AnyRegion:$body); let assemblyFormat = [{ - $root (`with` $name^ ($externalConstParams^)? - (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? - ($body^)? + ($root^)? (`with` $name^ ($externalConstParams^)? + (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? + ($body^)? 
attr-dict-with-keyword }]; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -43,27 +43,27 @@ using ValueMapScope = llvm::ScopedHashTableScope; /// Generate interpreter operations for the tree rooted at the given matcher - /// node. - Block *generateMatcher(MatcherNode &node); + /// node, in the specified region. + Block *generateMatcher(MatcherNode &node, Region ®ion); - /// Get or create an access to the provided positional value within the - /// current block. - Value getValueAt(Block *cur, Position *pos); + /// Get or create an access to the provided positional value in the current + /// block. This operation may mutate the provided block pointer if nested + /// regions (i.e., pdl_interp.iterate) are required. + Value getValueAt(Block *¤tBlock, Position *pos); - /// Create an interpreter predicate operation, branching to the provided true - /// and false destinations. - void generatePredicate(Block *currentBlock, Qualifier *question, - Qualifier *answer, Value val, Block *trueDest, - Block *falseDest); + /// Create the interpreter predicate operations. This operation may mutate the + /// provided current block pointer if nested regions (iterates) are required. + void generate(BoolNode *boolNode, Block *¤tBlock, Value val); - /// Create an interpreter switch predicate operation, with a provided default - /// and several case destinations. - void generateSwitch(SwitchNode *switchNode, Block *currentBlock, - Qualifier *question, Value val, Block *defaultDest); + /// Create the interpreter switch / predicate operations, with several case + /// destinations. This operation never mutates the provided current block + /// pointer, because the switch operation does not need Values beyond `val`. 
+ void generate(SwitchNode *switchNode, Block *currentBlock, Value val); - /// Create the interpreter operations to record a successful pattern match. - void generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern); + /// Create the interpreter operations to record a successful pattern match + /// using the contained root operation. This operation may mutate the current + /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. + void generate(SuccessNode *successNode, Block *¤tBlock); /// Generate a rewriter function for the given pattern operation, and returns /// a reference to that function. @@ -156,7 +156,8 @@ // Generate a root matcher node from the provided PDL module. std::unique_ptr root = MatcherNode::generateMatcherTree( module, predicateBuilder, valueToPosition); - Block *firstMatcherBlock = generateMatcher(*root); + Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); + assert(failureBlockStack.empty() && "failed to empty the stack"); // After generation, merged the first matched block into the entry. matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), @@ -164,9 +165,9 @@ firstMatcherBlock->erase(); } -Block *PatternLowering::generateMatcher(MatcherNode &node) { +Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { // Push a new scope for the values used by this matcher. - Block *block = matcherFunc.addBlock(); + Block *block = ®ion.emplaceBlock(); ValueMapScope scope(values); // If this is the return node, simply insert the corresponding interpreter @@ -177,66 +178,114 @@ return block; } - // If this node contains a position, get the corresponding value for this - // block. - Position *position = node.getPosition(); - Value val = position ? getValueAt(block, position) : Value(); - // Get the next block in the match sequence. 
+ // This is intentionally executed first, before we get the value for the + // position associated with the node, so that we preserve an "there exist" + // semantics: if getting a value requires an upward traversal (going from a + // value to its consumers), we want to perform the check on all the consumers + // before we pass control to the failure node. std::unique_ptr &failureNode = node.getFailureNode(); - Block *nextBlock; + Block *failureBlock; if (failureNode) { - nextBlock = generateMatcher(*failureNode); - failureBlockStack.push_back(nextBlock); + failureBlock = generateMatcher(*failureNode, region); + failureBlockStack.push_back(failureBlock); } else { assert(!failureBlockStack.empty() && "expected valid failure block"); - nextBlock = failureBlockStack.back(); + failureBlock = failureBlockStack.back(); } + // If this node contains a position, get the corresponding value for this + // block. + Block *currentBlock = block; + Position *position = node.getPosition(); + Value val = position ? getValueAt(currentBlock, position) : Value(); + // If this value corresponds to an operation, record that we are going to use // its location as part of a fused location. bool isOperationValue = val && val.getType().isa(); if (isOperationValue) locOps.insert(val); - // Generate code for a boolean predicate node. - if (auto *boolNode = dyn_cast(&node)) { - auto *child = generateMatcher(*boolNode->getSuccessNode()); - generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val, - child, nextBlock); - - // Generate code for a switch node. - } else if (auto *switchNode = dyn_cast(&node)) { - generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock); - - // Generate code for a success node. - } else if (auto *successNode = dyn_cast(&node)) { - generateRecordMatch(block, nextBlock, successNode->getPattern()); + // Dispatch to the correct method based on derived node type. 
+ TypeSwitch(&node) + .Case( + [&](auto *derivedNode) { generate(derivedNode, currentBlock, val); }) + .Case([&](SuccessNode *successNode) { + generate(successNode, currentBlock); + }); + + // Pop all the failure blocks that were inserted due to nesting of + // pdl_interp.iterate. + while (failureBlockStack.back() != failureBlock) { + failureBlockStack.pop_back(); + assert(!failureBlockStack.empty() && "unable to locate failure block"); } + // Pop the new failure block. if (failureNode) failureBlockStack.pop_back(); + if (isOperationValue) locOps.remove(val); + return block; } -Value PatternLowering::getValueAt(Block *cur, Position *pos) { +Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { if (Value val = values.lookup(pos)) return val; // Get the value for the parent position. - Value parentVal = getValueAt(cur, pos->getParent()); + Value parentVal = getValueAt(currentBlock, pos->getParent()); // TODO: Use a location from the position. Location loc = parentVal.getLoc(); - builder.setInsertionPointToEnd(cur); + builder.setInsertionPointToEnd(currentBlock); Value value; switch (pos->getKind()) { - case Predicates::OperationPos: - value = builder.create( - loc, builder.getType(), parentVal); + case Predicates::OperationPos: { + auto *operationPos = cast(pos); + if (!operationPos->isUpward()) { + // Standard (downward) traversal which directly follows the defining op. + value = builder.create( + loc, builder.getType(), parentVal); + break; + } + + // The first operation retrieves the representative value of a range. + // This applies only when the parent is a range of values. + if (parentVal.getType().isa()) + value = builder.create(loc, parentVal, 0); + else + value = parentVal; + + // The second operation retrieves the users. + value = builder.create(loc, value); + + // The third operation iterates over them. 
+ assert(!failureBlockStack.empty() && "expected valid failure block"); + auto foreach = builder.create( + loc, value, failureBlockStack.back(), /*initLoop=*/true); + value = foreach.getLoopVariable(); + + // Create the success and continuation blocks. + Block *successBlock = builder.createBlock(&foreach.region()); + Block *continueBlock = builder.createBlock(successBlock); + builder.create(loc); + failureBlockStack.push_back(continueBlock); + + // The fourth operation extracts the operand(s) of the user at the specified + // index (which can be None, indicating all operands). + builder.setInsertionPointToStart(&foreach.region().front()); + Value operands = builder.create( + loc, parentVal.getType(), value, operationPos->getIndex()); + + // The fifth operation compares the operands to the parent value / range. + builder.create(loc, parentVal, operands, + successBlock, continueBlock); + currentBlock = successBlock; break; + } case Predicates::OperandPos: { auto *operandPos = cast(pos); value = builder.create( @@ -285,41 +334,60 @@ llvm_unreachable("Generating unknown Position getter"); break; } + values.insert(pos, value); return value; } -void PatternLowering::generatePredicate(Block *currentBlock, - Qualifier *question, Qualifier *answer, - Value val, Block *trueDest, - Block *falseDest) { - builder.setInsertionPointToEnd(currentBlock); +void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, + Value val) { Location loc = val.getLoc(); + Qualifier *question = boolNode->getQuestion(); + Qualifier *answer = boolNode->getAnswer(); + Region *region = currentBlock->getParent(); + + // Execute the getValue queries first, so that we create success + // matcher in the correct (possibly nested) region. 
+ SmallVector args; + if (auto *equalToQuestion = dyn_cast(question)) { + args = {getValueAt(currentBlock, equalToQuestion->getValue())}; + } else if (auto *cstQuestion = dyn_cast(question)) { + for (Position *position : std::get<1>(cstQuestion->getValue())) + args.push_back(getValueAt(currentBlock, position)); + } + + // Generate the matcher in the current (potentially nested) region + // and get the failure successor. + Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); + Block *failure = failureBlockStack.back(); + + // Finally, create the predicate. + builder.setInsertionPointToEnd(currentBlock); Predicates::Kind kind = question->getKind(); switch (kind) { case Predicates::IsNotNullQuestion: - builder.create(loc, val, trueDest, falseDest); + builder.create(loc, val, success, failure); break; case Predicates::OperationNameQuestion: { auto *opNameAnswer = cast(answer); builder.create( - loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest); + loc, val, opNameAnswer->getValue().getStringRef(), success, failure); break; } case Predicates::TypeQuestion: { auto *ans = cast(answer); if (val.getType().isa()) builder.create( - loc, val, ans->getValue().cast(), trueDest, falseDest); + loc, val, ans->getValue().cast(), success, failure); else builder.create( - loc, val, ans->getValue().cast(), trueDest, falseDest); + loc, val, ans->getValue().cast(), success, failure); break; } case Predicates::AttributeQuestion: { auto *ans = cast(answer); builder.create(loc, val, ans->getValue(), - trueDest, falseDest); + success, failure); break; } case Predicates::OperandCountAtLeastQuestion: @@ -327,31 +395,27 @@ builder.create( loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, - trueDest, falseDest); + success, failure); break; case Predicates::ResultCountAtLeastQuestion: case Predicates::ResultCountQuestion: builder.create( loc, val, cast(answer)->getValue(), /*compareAtLeast=*/kind == 
Predicates::ResultCountAtLeastQuestion, - trueDest, falseDest); + success, failure); break; case Predicates::EqualToQuestion: { - auto *equalToQuestion = cast(question); - builder.create( - loc, val, getValueAt(currentBlock, equalToQuestion->getValue()), - trueDest, falseDest); + bool trueAnswer = isa(answer); + builder.create(loc, val, args.front(), + trueAnswer ? success : failure, + trueAnswer ? failure : success); break; } case Predicates::ConstraintQuestion: { - auto *cstQuestion = cast(question); - SmallVector args; - for (Position *position : std::get<1>(cstQuestion->getValue())) - args.push_back(getValueAt(currentBlock, position)); + auto value = cast(question)->getValue(); builder.create( - loc, std::get<0>(cstQuestion->getValue()), args, - std::get<2>(cstQuestion->getValue()).cast(), trueDest, - falseDest); + loc, std::get<0>(value), args, std::get<2>(value).cast(), + success, failure); break; } default: @@ -373,9 +437,12 @@ builder.create(val.getLoc(), val, values, defaultDest, blocks); } -void PatternLowering::generateSwitch(SwitchNode *switchNode, - Block *currentBlock, Qualifier *question, - Value val, Block *defaultDest) { +void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, + Value val) { + Qualifier *question = switchNode->getQuestion(); + Region *region = currentBlock->getParent(); + Block *defaultDest = failureBlockStack.back(); + // If the switch question is not an exact answer, i.e. for the `at_least` // cases, we generate a special block sequence. Predicates::Kind kind = question->getKind(); @@ -407,12 +474,25 @@ // ... 
// failureBlockStack.push_back(defaultDest); + Location loc = val.getLoc(); for (unsigned idx : sortedChildren) { auto &child = switchNode->getChild(idx); - Block *childBlock = generateMatcher(*child.second); + Block *childBlock = generateMatcher(*child.second, *region); Block *predicateBlock = builder.createBlock(childBlock); - generatePredicate(predicateBlock, question, child.first, val, childBlock, - defaultDest); + builder.setInsertionPointToEnd(predicateBlock); + unsigned ans = cast(child.first)->getValue(); + switch (kind) { + case Predicates::OperandCountAtLeastQuestion: + builder.create( + loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + break; + case Predicates::ResultCountAtLeastQuestion: + builder.create( + loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + break; + default: + llvm_unreachable("Generating invalid AtLeast operation"); + } failureBlockStack.back() = predicateBlock; } Block *firstPredicateBlock = failureBlockStack.pop_back_val(); @@ -426,7 +506,7 @@ // switch. llvm::MapVector children; for (auto &it : switchNode->getChildren()) - children.insert({it.first, generateMatcher(*it.second)}); + children.insert({it.first, generateMatcher(*it.second, *region)}); builder.setInsertionPointToEnd(currentBlock); switch (question->getKind()) { @@ -455,8 +535,10 @@ } } -void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock, - pdl::PatternOp pattern) { +void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { + pdl::PatternOp pattern = successNode->getPattern(); + Value root = successNode->getRoot(); + // Generate a rewriter for the pattern this success node represents, and track // any values used from the match region. SmallVector usedMatchValues; @@ -478,14 +560,15 @@ // Grab the root kind if present. 
StringAttr rootKindAttr; - if (Optional rootKind = pattern.getRootKind()) - rootKindAttr = builder.getStringAttr(*rootKind); + if (pdl::OperationOp rootOp = root.getDefiningOp()) + if (Optional rootKind = rootOp.name()) + rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); builder.create( pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), - nextBlock); + failureBlockStack.back()); } SymbolRefAttr PatternLowering::generateRewriter( @@ -535,8 +618,10 @@ // method. pdl::RewriteOp rewriter = pattern.getRewriter(); if (StringAttr rewriteName = rewriter.nameAttr()) { + SmallVector args; + if (rewriter.root()) + args.push_back(mapRewriteValue(rewriter.root())); auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); - SmallVector args(1, mapRewriteValue(rewriter.root())); args.append(mappedArgs.begin(), mappedArgs.end()); builder.create( rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -65,8 +65,9 @@ // Answers. AttributeAnswer, - TrueAnswer, + FalseAnswer, OperationNameAnswer, + TrueAnswer, TypeAnswer, UnsignedAnswer, }; @@ -216,24 +217,45 @@ /// An operation position describes an operation node in the IR. Other position /// kinds are formed with respect to an operation position. -struct OperationPosition : public PredicateBase, - Predicates::OperationPos> { +struct OperationPosition + : public PredicateBase, unsigned>, + Predicates::OperationPos> { + static constexpr unsigned kDown = std::numeric_limits::max(); + explicit OperationPosition(const KeyTy &key) : Base(key) { - parent = key.first; + parent = std::get<0>(key); + } + + /// Returns a hash suitable for the given keytype. 
+ static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); } /// Gets the root position. static OperationPosition *getRoot(StorageUniquer &uniquer) { - return Base::get(uniquer, nullptr, 0); + return Base::get(uniquer, nullptr, kDown, 0); } - /// Gets an operation position with the given parent. + + /// Gets an downward operation position with the given parent. static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { - return Base::get(uniquer, parent, parent->getOperationDepth() + 1); + return Base::get(uniquer, parent, kDown, parent->getOperationDepth() + 1); } + /// Gets an upward operation position with the given parent and operand. + static OperationPosition *get(StorageUniquer &uniquer, Position *parent, + Optional operand) { + return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1); + } + + /// Returns the operand index for an upward operation position. + Optional getIndex() const { return std::get<1>(key); } + + /// Returns if this operation position is upward, accepting an input. + bool isUpward() const { return getIndex().getValueOr(0) != kDown; } + /// Returns the depth of this position. - unsigned getDepth() const { return key.second; } + unsigned getDepth() const { return std::get<2>(key); } /// Returns if this operation position corresponds to the root. bool isRoot() const { return getDepth() == 0; } @@ -346,6 +368,12 @@ using Base::Base; }; +/// An Answer representing a boolean 'false' value. +struct FalseAnswer + : PredicateBase { + using Base::Base; +}; + /// An Answer representing a `Type` value. The value is stored as either a /// TypeAttr, or an ArrayAttr of TypeAttr. struct TypeAnswer : public PredicateBase(); registerParametricStorageType(); registerParametricStorageType(); + registerSingletonStorageType(); registerSingletonStorageType(); // Register the types of Answers with the uniquer. 
@@ -485,6 +514,14 @@ return OperationPosition::get(uniquer, p); } + /// Returns the position of operation using the value at the given index. + OperationPosition *getUsersOp(Position *p, Optional operand) { + assert((isa(p)) && + "expected result position"); + return OperationPosition::get(uniquer, p, operand); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); @@ -536,11 +573,16 @@ AttributeAnswer::get(uniquer, attr)}; } - /// Create a predicate comparing two values. + /// Create a predicate checking if two values are equal. Predicate getEqualTo(Position *pos) { return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; } + /// Create a predicate checking if two values are not equal. + Predicate getNotEqualTo(Position *pos) { + return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)}; + } + /// Create a predicate that applies a generic constraint. Predicate getConstraint(StringRef name, ArrayRef pos, Attribute params) { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h @@ -152,7 +152,7 @@ /// matched. This does not terminate the matcher, as there may be multiple /// successful matches. struct SuccessNode : public MatcherNode { - explicit SuccessNode(pdl::PatternOp pattern, + explicit SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode); /// Returns if the given matcher node is an instance of this class, used to @@ -164,10 +164,16 @@ /// Return the high level pattern operation that is matched with this node. pdl::PatternOp getPattern() const { return pattern; } + /// Return the chosen root of the pattern. 
+ Value getRoot() const { return root; } + private: /// The high level pattern operation that was successfully matched with this /// node. pdl::PatternOp pattern; + + /// The chosen root of the pattern. + Value root; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -7,12 +7,19 @@ //===----------------------------------------------------------------------===// #include "PredicateTree.h" +#include "RootOrdering.h" + #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "pdl-predicate-tree" using namespace mlir; using namespace mlir::pdl_to_pdl_interp; @@ -102,7 +109,8 @@ static void getTreePredicates(std::vector &predList, Value val, PredicateBuilder &builder, DenseMap &inputs, - OperationPosition *pos) { + OperationPosition *pos, + Optional ignoreOperand = llvm::None) { assert(val.getType().isa() && "expected operation"); pdl::OperationOp op = cast(val.getDefiningOp()); OperationPosition *opPos = cast(pos); @@ -158,6 +166,11 @@ bool isVariadic = operandIt.value().getType().isa(); foundVariableLength |= isVariadic; + // Ignore the specified operand, usually because this position was + // visited in an upward traversal via an iterative choice. + if (ignoreOperand && *ignoreOperand == operandIt.index()) + continue; + Position *pos = foundVariableLength ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic) @@ -300,15 +313,302 @@ } } +namespace { + +/// An op accepting a value at an optional index. 
+struct OpIndex { + Value parent; + Optional index; +}; + +/// The parent and operand index of each operation for each root, stored +/// as a nested map [root][operation]. +using ParentMaps = DenseMap>; + +} // namespace + +/// Given a pattern, determines the set of roots present in this pattern. +/// These are the operations whose results are not consumed by other operations. +static SmallVector detectRoots(pdl::PatternOp pattern) { + // First, collect all the operations that are used as operands + // to other operations. These are not roots by default. + DenseSet used; + for (auto operationOp : pattern.body().getOps()) { + for (Value operand : operationOp.operands()) + TypeSwitch(operand.getDefiningOp()) + .Case( + [&used](auto resultOp) { used.insert(resultOp.parent()); }); + } + + // Remove the specified root from the use set, so that we can + // always select it as a root, even if it is used by other operations. + if (Value root = pattern.getRewriter().root()) + used.erase(root); + + // Finally, collect all the unused operations. + SmallVector roots; + for (Value operationOp : pattern.body().getOps()) + if (!used.contains(operationOp)) + roots.push_back(operationOp); + + return roots; +} + +/// Given a list of candidate roots, builds the cost graph for connecting them. +/// The graph is formed by traversing the DAG of operations starting from each +/// root and marking the depth of each connector value (operand). Then we join +/// the candidate roots based on the common connector values, taking the one +/// with the minimum depth. Along the way, we compute, for each candidate root, +/// a mapping from each operation (in the DAG underneath this root) to its +/// parent operation and the corresponding operand index. +static void buildCostGraph(ArrayRef roots, RootOrderingGraph &graph, + ParentMaps &parentMaps) { + + // The entry of a queue. 
The entry consists of the following items: + // * the value in the DAG underneath the root; + // * the parent of the value; + // * the operand index of the value in its parent; + // * the depth of the visited value. + struct Entry { + Entry(Value value, Value parent, Optional index, unsigned depth) + : value(value), parent(parent), index(index), depth(depth) {} + + Value value; + Value parent; + Optional index; + unsigned depth; + }; + + // A root of a value and its depth (distance from root to the value). + struct RootDepth { + Value root; + unsigned depth = 0; + }; + + // Map from candidate connector values to their roots and depths. Using a + // small vector with 1 entry because most values belong to a single root. + llvm::MapVector> connectorsRootsDepths; + + // Perform a breadth-first traversal of the op DAG rooted at each root. + for (Value root : roots) { + // The queue of visited values. A value may be present multiple times in + // the queue, for multiple parents. We only accept the first occurrence, + // which is guaranteed to have the lowest depth. + std::queue toVisit; + toVisit.emplace(root, Value(), 0, 0); + + // The map from value to its parent for the current root. + DenseMap &parentMap = parentMaps[root]; + + while (!toVisit.empty()) { + Entry entry = toVisit.front(); + toVisit.pop(); + // Skip if already visited. + if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second) + continue; + + // Mark the root and depth of the value. + connectorsRootsDepths[entry.value].push_back({root, entry.depth}); + + // Traverse the operands of an operation and result ops. + // We intentionally do not traverse attributes and types, because those + // are expensive to join on. + TypeSwitch(entry.value.getDefiningOp()) + .Case([&](auto operationOp) { + OperandRange operands = operationOp.operands(); + // Special case when we pass all the operands in one range. + // For those, the index is empty. 
+ if (operands.size() == 1 && + operands[0].getType().isa()) { + toVisit.emplace(operands[0], entry.value, llvm::None, + entry.depth + 1); + return; + } + + // Default case: visit all the operands. + for (auto p : llvm::enumerate(operationOp.operands())) + toVisit.emplace(p.value(), entry.value, p.index(), + entry.depth + 1); + }) + .Case([&](auto resultOp) { + toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(), + entry.depth); + }); + } + } + + // Now build the cost graph. + // This is simply a minimum over all depths for the target root. + unsigned nextID = 0; + for (const auto &connectorRootsDepths : connectorsRootsDepths) { + Value value = connectorRootsDepths.first; + ArrayRef rootsDepths = connectorRootsDepths.second; + // If there is only one root for this value, this will not trigger + // any edges in the cost graph (a perf optimization). + if (rootsDepths.size() == 1) + continue; + + for (const RootDepth &p : rootsDepths) { + for (const RootDepth &q : rootsDepths) { + if (&p == &q) + continue; + // Insert or retrieve the property of edge from p to q. + RootOrderingCost &cost = graph[q.root][p.root]; + if (!cost.connector /* new edge */ || cost.cost.first > q.depth) { + if (!cost.connector) + cost.cost.second = nextID++; + cost.cost.first = q.depth; + cost.connector = value; + } + } + } + } + + assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) && + "the pattern contains a candidate root disconnected from the others"); +} + +/// Visit a node during upward traversal. +void visitUpward(std::vector &predList, OpIndex opIndex, + PredicateBuilder &builder, + DenseMap &valueToPosition, Position *&pos, + bool &first) { + Value value = opIndex.parent; + TypeSwitch(value.getDefiningOp()) + .Case([&](auto operationOp) { + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index); + + // Guard against traversing back to where we came from. 
+ if (first) { + Position *parent = pos->getParent(); + predList.emplace_back(opPos, builder.getNotEqualTo(parent)); + first = false; + } + + // Guard against duplicate upward visits. These are not possible, + // because if this value was already visited, it would have been + // cheaper to start the traversal at this value rather than at the + // `connector`, violating the optimality of our spanning tree. + bool inserted = valueToPosition.try_emplace(value, opPos).second; + assert(inserted && "duplicate upward visit"); + + // Obtain the tree predicates at the current value. + getTreePredicates(predList, value, builder, valueToPosition, opPos, + opIndex.index); + + // Update the position + pos = opPos; + }) + .Case([&](auto resultOp) { + // Traverse up an individual result. + auto *opPos = dyn_cast(pos); + assert(opPos && "operations and results must be interleaved"); + pos = builder.getResult(opPos, *opIndex.index); + }) + .Case([&](auto resultOp) { + // Traverse up a group of results. + auto *opPos = dyn_cast(pos); + assert(opPos && "operations and results must be interleaved"); + bool isVariadic = value.getType().isa(); + if (opIndex.index) + pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); + else + pos = builder.getAllResults(opPos); + }); +} + /// Given a pattern operation, build the set of matcher predicates necessary to /// match this pattern. -static void buildPredicateList(pdl::PatternOp pattern, - PredicateBuilder &builder, - std::vector &predList, - DenseMap &valueToPosition) { - getTreePredicates(predList, pattern.getRewriter().root(), builder, - valueToPosition, builder.getRoot()); +static Value buildPredicateList(pdl::PatternOp pattern, + PredicateBuilder &builder, + std::vector &predList, + DenseMap &valueToPosition) { + SmallVector roots = detectRoots(pattern); + + // Build the root ordering graph and compute the parent maps. 
+ RootOrderingGraph graph; + ParentMaps parentMaps; + buildCostGraph(roots, graph, parentMaps); + LLVM_DEBUG({ + llvm::dbgs() << "Graph:\n"; + for (auto &target : graph) { + llvm::dbgs() << " * " << target.first << "\n"; + for (auto &source : target.second) { + RootOrderingCost c = source.second; + llvm::dbgs() << " <- " << source.first << ": " << c.cost.first + << ":" << c.cost.second << " via " << c.connector.getLoc() + << "\n"; + } + } + }); + + // Solve the optimal branching problem for each candidate root, or use the + // provided one. + Value bestRoot = pattern.getRewriter().root(); + OptimalBranching::EdgeList bestEdges; + if (!bestRoot) { + unsigned bestCost = 0; + LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); + for (Value root : roots) { + OptimalBranching solver(graph, root); + unsigned cost = solver.solve(); + LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); + if (!bestRoot || bestCost > cost) { + bestCost = cost; + bestRoot = root; + bestEdges = solver.preOrderTraversal(roots); + } + } + } else { + OptimalBranching solver(graph, bestRoot); + solver.solve(); + bestEdges = solver.preOrderTraversal(roots); + } + + LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); + + // The best root is the starting point for the traversal. Get the tree + // predicates for the DAG rooted at bestRoot. + getTreePredicates(predList, bestRoot, builder, valueToPosition, + builder.getRoot()); + + // Traverse the selected optimal branching. For all edges in order, traverse + // up starting from the connector, until the candidate root is reached, and + // call getTreePredicates at every node along the way. + for (const std::pair &edge : bestEdges) { + Value target = edge.first; + Value source = edge.second; + + // Check if we already visited the target root. 
This happens in two cases: + // 1) the initial root (bestRoot); + // 2) a root that is dominated by (contained in the subtree rooted at) an + // already visited root. + if (valueToPosition.count(target)) + continue; + + // Determine the connector. + Value connector = graph[target][source].connector; + assert(connector && "invalid edge"); + LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); + auto &parentMap = parentMaps[target]; + Position *pos = valueToPosition.lookup(connector); + assert(pos && "The value has not been traversed yet"); + bool first = true; + + // Traverse from the connector upwards towards the target root. + for (Value value = connector; value != target;) { + OpIndex opIndex = parentMap.lookup(value); + assert(opIndex.parent && "missing parent"); + visitUpward(predList, opIndex, builder, valueToPosition, pos, first); + value = opIndex.parent; + } + } + getNonTreePredicates(pattern, predList, builder, valueToPosition); + + return bestRoot; } //===----------------------------------------------------------------------===// @@ -382,9 +682,11 @@ /// This class wraps a set of ordered predicates that are used within a specific /// pattern operation. struct OrderedPredicateList { - OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} + OrderedPredicateList(pdl::PatternOp pattern, Value root) + : pattern(pattern), root(root) {} pdl::PatternOp pattern; + Value root; DenseSet predicates; }; } // end anonymous namespace @@ -421,7 +723,8 @@ std::vector::iterator end) { if (current == end) { // We've hit the end of a pattern, so create a successful result node. - node = std::make_unique(list.pattern, std::move(node)); + node = + std::make_unique(list.pattern, list.root, std::move(node)); // If the pattern doesn't contain this predicate, ignore it.
} else if (list.predicates.find(*current) == list.predicates.end()) { @@ -489,22 +792,37 @@ std::unique_ptr MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap &valueToPosition) { - // Collect the set of predicates contained within the pattern operations of - // the module. - SmallVector>, 16> - patternsAndPredicates; + // The set of predicates contained within the pattern operations of the + // module. + struct PatternPredicates { + PatternPredicates(pdl::PatternOp pattern, Value root, + std::vector predicates) + : pattern(pattern), root(root), predicates(std::move(predicates)) {} + + /// A pattern. + pdl::PatternOp pattern; + + /// A root of the pattern chosen among the candidate roots in pdl.rewrite. + Value root; + + /// The extracted predicates for this pattern and root. + std::vector predicates; + }; + + SmallVector patternsAndPredicates; for (pdl::PatternOp pattern : module.getOps()) { std::vector predicateList; - buildPredicateList(pattern, builder, predicateList, valueToPosition); - patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); + Value root = + buildPredicateList(pattern, builder, predicateList, valueToPosition); + patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList)); } // Associate a pattern result with each unique predicate. 
DenseSet uniqued; for (auto &patternAndPredList : patternsAndPredicates) { - for (auto &predicate : patternAndPredList.second) { + for (auto &predicate : patternAndPredList.predicates) { auto it = uniqued.insert(predicate); - it.first->patternToAnswer.try_emplace(patternAndPredList.first, + it.first->patternToAnswer.try_emplace(patternAndPredList.pattern, predicate.answer); } } @@ -513,8 +831,9 @@ std::vector lists; lists.reserve(patternsAndPredicates.size()); for (auto &patternAndPredList : patternsAndPredicates) { - OrderedPredicateList list(patternAndPredList.first); - for (auto &predicate : patternAndPredList.second) { + OrderedPredicateList list(patternAndPredList.pattern, + patternAndPredList.root); + for (auto &predicate : patternAndPredList.predicates) { OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); list.predicates.insert(orderedPredicate); @@ -580,11 +899,11 @@ // SuccessNode //===----------------------------------------------------------------------===// -SuccessNode::SuccessNode(pdl::PatternOp pattern, +SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr failureNode) : MatcherNode(TypeID::get(), /*position=*/nullptr, /*question=*/nullptr, std::move(failureNode)), - pattern(pattern) {} + pattern(pattern), root(root) {} //===----------------------------------------------------------------------===// // SwitchNode diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -11,7 +11,8 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/InferTypeOpInterface.h" -#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::pdl; @@ -34,41 +35,55 @@ // PDL Operations //===----------------------------------------------------------------------===// -/// Returns true if the given operation is 
used by a "binding" pdl operation -/// within the main matcher body of a `pdl.pattern`. -static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) { - for (OpOperand &use : op->getUses()) { - Operation *user = use.getOwner(); - if (user->getBlock() != matcherBlock) - continue; - if (isa(user)) - return true; - // Only the first operand of RewriteOp may be bound to, i.e. the root - // operation of the pattern. - if (isa(user) && use.getOperandNumber() == 0) - return true; +/// Returns true if the given operation is used by a "binding" pdl operation. +static bool hasBindingUse(Operation *op) { + for (Operation *user : op->getUsers()) // A result by itself is not binding, it must also be bound. - if (isa(user) && - hasBindingUseInMatcher(user, matcherBlock)) + if (!isa(user) || hasBindingUse(user)) return true; - } return false; } -/// Returns success if the given operation is used by a "binding" pdl operation -/// within the main matcher body of a `pdl.pattern`. On failure, emits an error -/// with the given context message. -static LogicalResult -verifyHasBindingUseInMatcher(Operation *op, - StringRef bindableContextStr = "`pdl.operation`") { - // If the pattern is not a pattern, there is nothing to do. +/// Returns success if the given operation is not in the main matcher body or +/// is used by a "binding" operation. On failure, emits an error. +static LogicalResult verifyHasBindingUse(Operation *op) { + // If the parent is not a pattern, there is nothing to do. if (!isa(op->getParentOp())) return success(); - if (hasBindingUseInMatcher(op, op->getBlock())) + if (hasBindingUse(op)) return success(); - return op->emitOpError() - << "expected a bindable (i.e. 
" << bindableContextStr - << ") user when defined in the matcher body of a `pdl.pattern`"; + return op->emitOpError( + "expected a bindable user when defined in the matcher body of a " + "`pdl.pattern`"); +} + +/// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) +/// connected to the given operation. +static void visit(Operation *op, DenseSet &visited) { + // If the parent is not a pattern, there is nothing to do. + if (!isa(op->getParentOp()) || isa(op)) + return; + + // Ignore if already visited. + if (visited.contains(op)) + return; + + // Mark as visited. + visited.insert(op); + + // Traverse the operands / parent. + TypeSwitch(op) + .Case([&visited](auto operation) { + for (Value operand : operation.operands()) + visit(operand.getDefiningOp(), visited); + }) + .Case([&visited](auto result) { + visit(result.parent().getDefiningOp(), visited); + }); + + // Traverse the users. + for (Operation *user : op->getUsers()) + visit(user, visited); } //===----------------------------------------------------------------------===// @@ -104,24 +119,20 @@ "`pdl.rewrite`"); if (attrValue && attrType) return op.emitOpError("expected only one of [`type`, `value`] to be set"); - return verifyHasBindingUseInMatcher(op); + return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // pdl::OperandOp //===----------------------------------------------------------------------===// -static LogicalResult verify(OperandOp op) { - return verifyHasBindingUseInMatcher(op); -} +static LogicalResult verify(OperandOp op) { return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // pdl::OperandsOp //===----------------------------------------------------------------------===// -static LogicalResult verify(OperandsOp op) { - return verifyHasBindingUseInMatcher(op); -} +static LogicalResult verify(OperandsOp op) { return verifyHasBindingUse(op); } 
//===----------------------------------------------------------------------===// // pdl::OperationOp @@ -237,7 +248,7 @@ return failure(); } - return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`"); + return verifyHasBindingUse(op); } bool OperationOp::hasTypeInference() { @@ -256,15 +267,16 @@ static LogicalResult verify(PatternOp pattern) { Region &body = pattern.body(); - auto *term = body.front().getTerminator(); - if (!isa(term)) { + Operation *term = body.front().getTerminator(); + auto rewriteOp = dyn_cast(term); + if (!rewriteOp) { return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") .attachNote(term->getLoc()) .append("see terminator defined here"); } - // Check that all values defined in the top-level pattern are referenced at - // least once in the source tree. + // Check that all values defined in the top-level pattern belong to the PDL + // dialect. WalkResult result = body.walk([&](Operation *op) -> WalkResult { if (!isa_and_nonnull(op->getDialect())) { pattern @@ -275,15 +287,61 @@ } return WalkResult::advance(); }); - return failure(result.wasInterrupted()); + if (result.wasInterrupted()) + return failure(); + + // Check that there is at least one operation. + if (body.front().getOps().empty()) + return pattern.emitOpError( + "the pattern must contain at least one `pdl.operation`"); + + // Determine if the operations within the pdl.pattern form a connected + // component. This is determined by starting the search from the first + // operand/result/operation and visiting their users / parents / operands. + // We limit our attention to operations that have a user in pdl.rewrite, + // those that do not will be detected via other means (expected bindable + // user). + bool first = true; + DenseSet visited; + for (Operation &op : body.front()) { + // The following are the operations forming the connected component. + if (!isa(op)) + continue; + + // Determine if the operation has a user in `pdl.rewrite`.
+ bool hasUserInRewrite = false; + for (Operation *user : op.getUsers()) { + Region *region = user->getParentRegion(); + if (isa(user) || + (region && isa(region->getParentOp()))) { + hasUserInRewrite = true; + break; + } + } + + // If the operation does not have a user in `pdl.rewrite`, ignore it. + if (!hasUserInRewrite) + continue; + + if (first) { + // For the first operation, invoke visit. + visit(&op, visited); + first = false; + } else if (!visited.count(&op)) { + // For the subsequent operations, check if already visited. + return pattern + .emitOpError("the operations must form a connected component") + .attachNote(op.getLoc()) + .append("see a disconnected value / operation here"); + } + } + + return success(); } void PatternOp::build(OpBuilder &builder, OperationState &state, - Optional rootKind, Optional benefit, - Optional name) { - build(builder, state, - rootKind ? builder.getStringAttr(*rootKind) : StringAttr(), - builder.getI16IntegerAttr(benefit ? *benefit : 0), + Optional benefit, Optional name) { + build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0), name ? builder.getStringAttr(*name) : StringAttr()); state.regions[0]->emplaceBlock(); } @@ -293,13 +351,6 @@ return cast(body().front().getTerminator()); } -/// Return the root operation kind that this pattern matches, or None if -/// there isn't a specific root. 
-Optional PatternOp::getRootKind() { - OperationOp rootOp = cast(getRewriter().root().getDefiningOp()); - return rootOp.name(); -} - //===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// @@ -380,18 +431,13 @@ // pdl::TypeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypeOp op) { - return verifyHasBindingUseInMatcher( - op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`"); -} +static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // pdl::TypesOp //===----------------------------------------------------------------------===// -static LogicalResult verify(TypesOp op) { - return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`"); -} +static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -384,7 +384,7 @@ // ----- // CHECK-LABEL: module @predicate_ordering -module @predicate_ordering { +module @predicate_ordering { // Check that the result is checked for null first, before applying the // constraint. The null check is prevalent in both patterns, so should be // prioritized first. @@ -408,3 +408,168 @@ pdl.rewrite %apply with "rewriter" } } + + +// ----- + +// CHECK-LABEL: module @multi_root +module @multi_root { + // Check the lowering of a simple two-root pattern. 
+ // This checks that we correctly generate the pdl_interp.choose_op operation + // and tie the break between %root1 and %root2 in favor of %root1. + + // CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[ROOT1]] + // CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]] + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL1]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[OPS]] + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 0 of %[[ROOT2]] + // CHECK-DAG: pdl_interp.are_equal %[[VAL1]], %[[OPERANDS]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: pdl_interp.continue + // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]] + // CHECK-DAG: %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP1]] : !pdl.operation -> ^{{.*}}, ^[[CONTINUE]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP2]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[VAL1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[VAL2]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation + // CHECK-DAG: pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[CONTINUE]] + + pdl.pattern @rewrite_multi_root : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1 with "rewriter"(%root2 : !pdl.operation) + } +} + + +// ----- + +// CHECK-LABEL: module @overlapping_roots +module @overlapping_roots { + // Check the lowering of a degenerate two-root pattern, where one root + // is in the subtree rooted at another. 
+ + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL]] + // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[OP]] + // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[OP]] + // CHECK-DAG: pdl_interp.is_not_null %[[VAL]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value + + pdl.pattern @rewrite_overlapping_roots : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + %val = pdl.result 0 of %op + %root = pdl.operation(%val : !pdl.value) + pdl.rewrite with "rewriter"(%root : !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @force_overlapped_root +module @force_overlapped_root { + // Check the lowering of a degenerate two-root pattern, where one root + // is in the subtree rooted at another, and we are forced to use this + // root as the root of the search tree. 
+ + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_result 0 of %[[ROOT]] + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 2 + // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT]] is 1 + // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operand 1 of %[[ROOT]] + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT2]] : !pdl.value + // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] + // CHECK-DAG: pdl_interp.is_not_null %[[INPUT1]] : !pdl.value + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[OP:.*]] : !pdl.operation in %[[OPS]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.check_operand_count of %[[OP]] is 1 + + pdl.pattern @rewrite_forced_overlapped_root : benefit(1) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %root = pdl.operation(%input1, %input2 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + %val = pdl.result 0 of %root + %op = pdl.operation(%val : !pdl.value) + pdl.rewrite %root with "rewriter"(%op : !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @variadic_results_all +module @variadic_results_all { + // Check the correct lowering when using all results of an operation + // and passing it them as operands to another operation. 
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is 0 + // CHECK-DAG: %[[VALS:.*]] = pdl_interp.get_results of %[[ROOT]] : !pdl.range + // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS]] + // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[OP:.*]] : !pdl.operation in %[[OPS]] + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands of %[[OP]] + // CHECK-DAG: pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] : !pdl.range -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] + // CHECK-DAG: pdl_interp.check_result_count of %[[OP]] is 0 + pdl.pattern @variadic_results_all : benefit(1) { + %types = pdl.types + %root = pdl.operation -> (%types : !pdl.range) + %vals = pdl.results of %root + %op = pdl.operation(%vals : !pdl.range) + pdl.rewrite %root with "rewriter"(%op : !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @variadic_results_at +module @variadic_results_at { + // Check the correct lowering when using selected results of an operation + // and passing them as an operand to another operation.
+ + // CHECK: func @matcher(%[[ROOT1:.*]]: !pdl.operation) + // CHECK-DAG: %[[VALS:.*]] = pdl_interp.get_operands 0 of %[[ROOT1]] : !pdl.range + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VALS]] : !pdl.range + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] : !pdl.operation + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT1]] is at_least 1 + // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT1]] is 0 + // CHECK-DAG: %[[VAL:.*]] = pdl_interp.get_operands 1 of %[[ROOT1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[VAL]] + // CHECK-DAG: pdl_interp.is_not_null %[[VALS]] + // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS]] + // CHECK-DAG: %[[ROOTS2:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[ROOTS2]] { + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 1 of %[[ROOT2]] + // CHECK-DAG: pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] : !pdl.range -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] + // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT2]] is at_least 1 + // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT2]] is 0 + // CHECK-DAG: pdl_interp.check_operand_count of %[[OP]] is 0 + // CHECK-DAG: pdl_interp.check_result_count of %[[OP]] is at_least 1 + pdl.pattern @variadic_results_at : benefit(1) { + %type = pdl.type + %types = pdl.types + %val = pdl.operand + %op = pdl.operation -> (%types, %type : !pdl.range, !pdl.type) + %vals = pdl.results 0 of %op -> !pdl.range + %root1 = pdl.operation(%vals, %val : !pdl.range, !pdl.value) + %root2 = pdl.operation(%val, %vals : !pdl.value, !pdl.range) + pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation) + } +} diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -67,7 +67,7 @@ // ----- pdl.pattern : benefit(1) { - // 
expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.attribute %op = pdl.operation "foo.op" @@ -81,7 +81,7 @@ //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.operand %op = pdl.operation "foo.op" @@ -95,7 +95,7 @@ //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.operands %op = pdl.operation "foo.op" @@ -143,7 +143,7 @@ // ----- pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. 
`pdl.operation` or `pdl.rewrite`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.operation "foo.op" %op = pdl.operation "foo.op" @@ -164,6 +164,12 @@ // ----- +// expected-error@below {{the pattern must contain at least one `pdl.operation`}} +pdl.pattern : benefit(1) { + pdl.rewrite with "foo" +} + +// ----- // expected-error@below {{expected only `pdl` operations within the pattern body}} pdl.pattern : benefit(1) { // expected-note@below {{see non-`pdl` operation defined here}} @@ -173,6 +179,32 @@ pdl.rewrite %root with "foo" } +// ----- +// expected-error@below {{the operations must form a connected component}} +pdl.pattern : benefit(1) { + %op1 = pdl.operation "foo.op" + %op2 = pdl.operation "bar.op" + // expected-note@below {{see a disconnected value / operation here}} + %val = pdl.result 0 of %op2 + pdl.rewrite %op1 with "foo"(%val : !pdl.value) +} + +// ----- +// expected-error@below {{the operations must form a connected component}} +pdl.pattern : benefit(1) { + %type = pdl.type + %op1 = pdl.operation "foo.op" -> (%type : !pdl.type) + %val = pdl.result 0 of %op1 + %op2 = pdl.operation "bar.op"(%val : !pdl.value) + // expected-note@below {{see a disconnected value / operation here}} + %op3 = pdl.operation "baz.op" + pdl.rewrite { + pdl.erase %op1 + pdl.erase %op2 + pdl.erase %op3 + } +} + // ----- pdl.pattern : benefit(1) { @@ -212,7 +244,9 @@ %op = pdl.operation "foo.op" // expected-error@below {{expected rewrite region to be non-empty if external name is not specified}} - "pdl.rewrite"(%op) ({}) : (!pdl.operation) -> () + "pdl.rewrite"(%op) ({}) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- @@ -223,7 +257,9 @@ // expected-error@below {{expected no external arguments when the rewrite is specified inline}} "pdl.rewrite"(%op, %op) ({ ^bb1: - }) : (!pdl.operation, !pdl.operation) 
-> () + }) { + operand_segment_sizes = dense<1> : vector<2xi32> + }: (!pdl.operation, !pdl.operation) -> () } // ----- @@ -234,7 +270,9 @@ // expected-error@below {{expected no external constant parameters when the rewrite is specified inline}} "pdl.rewrite"(%op) ({ ^bb1: - }) {externalConstParams = []} : (!pdl.operation) -> () + }) { + operand_segment_sizes = dense<[1,0]> : vector<2xi32>, + externalConstParams = []} : (!pdl.operation) -> () } // ----- @@ -245,7 +283,10 @@ // expected-error@below {{expected rewrite region to be empty when rewrite is external}} "pdl.rewrite"(%op) ({ ^bb1: - }) {name = "foo"} : (!pdl.operation) -> () + }) { + name = "foo", + operand_segment_sizes = dense<[1,0]> : vector<2xi32> + } : (!pdl.operation) -> () } // ----- @@ -255,7 +296,7 @@ //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. `pdl.attribute`, `pdl.operand`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.type %op = pdl.operation "foo.op" @@ -269,7 +310,7 @@ //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { - // expected-error@below {{expected a bindable (i.e. 
`pdl.operands`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + // expected-error@below {{expected a bindable user when defined in the matcher body of a `pdl.pattern`}} %unused = pdl.types %op = pdl.operation "foo.op" diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -42,6 +42,36 @@ // ----- +pdl.pattern @rewrite_multi_root_optimal : benefit(2) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation) +} + +// ----- + +pdl.pattern @rewrite_multi_root_forced : benefit(2) { + %input1 = pdl.operand + %input2 = pdl.operand + %type = pdl.type + %op1 = pdl.operation(%input1 : !pdl.value) -> (%type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %root1 = pdl.operation(%val1 : !pdl.value) + %op2 = pdl.operation(%input2 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %root2 = pdl.operation(%val1, %val2 : !pdl.value, !pdl.value) + pdl.rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation) +} + +// ----- + // Check that the result type of an operation within a rewrite can be inferred // from a pdl.replace. pdl.pattern @infer_type_from_operation_replace : benefit(1) {