diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -248,45 +248,43 @@ switch (pos->getKind()) { case Predicates::OperationPos: { auto *operationPos = cast(pos); - if (!operationPos->isUpward()) { + if (operationPos->isOperandDefiningOp()) // Standard (downward) traversal which directly follows the defining op. value = builder.create( loc, builder.getType(), parentVal); - break; - } + else + // A passthrough operation position. + value = parentVal; + break; + } + case Predicates::UsersPos: { + auto *usersPos = cast(pos); // The first operation retrieves the representative value of a range. - // This applies only when the parent is a range of values. - if (parentVal.getType().isa()) + // This applies only when the parent is a range of values and we were + // requested to use a representative value (e.g., upward traversal). + if (parentVal.getType().isa() && + usersPos->useRepresentative()) value = builder.create(loc, parentVal, 0); else value = parentVal; // The second operation retrieves the users. value = builder.create(loc, value); - - // The third operation iterates over them. + break; + } + case Predicates::ForEachPos: { assert(!failureBlockStack.empty() && "expected valid failure block"); auto foreach = builder.create( - loc, value, failureBlockStack.back(), /*initLoop=*/true); + loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); value = foreach.getLoopVariable(); - // Create the success and continuation blocks. - Block *successBlock = builder.createBlock(&foreach.region()); - Block *continueBlock = builder.createBlock(successBlock); + // Create the continuation block. 
+ Block *continueBlock = builder.createBlock(&foreach.region()); builder.create(loc); failureBlockStack.push_back(continueBlock); - // The fourth operation extracts the operand(s) of the user at the specified - // index (which can be None, indicating all operands). - builder.setInsertionPointToStart(&foreach.region().front()); - Value operands = builder.create( - loc, parentVal.getType(), value, operationPos->getIndex()); - - // The fifth operation compares the operands to the parent value / range. - builder.create(loc, parentVal, operands, - successBlock, continueBlock); - currentBlock = successBlock; + currentBlock = &foreach.region().front(); break; } case Predicates::OperandPos: { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -52,6 +52,8 @@ TypePos, AttributeLiteralPos, TypeLiteralPos, + UsersPos, + ForEachPos, // Questions, ordered by dependency and decreasing priority. IsNotNullQuestion, @@ -186,6 +188,20 @@ }; //===----------------------------------------------------------------------===// +// ForEachPosition + +/// A position describing an iterative choice of an operation. +struct ForEachPosition : public PredicateBase, + Predicates::ForEachPos> { + explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; } + + /// Returns the ID, for differentiating various loops. + /// For upward traversals, this is the index of the root. + unsigned getID() const { return key.second; } +}; + +//===----------------------------------------------------------------------===// // OperandPosition /// A position describing an operand of an operation. @@ -229,14 +245,11 @@ /// An operation position describes an operation node in the IR. Other position /// kinds are formed with respect to an operation position. 
-struct OperationPosition - : public PredicateBase, unsigned>, - Predicates::OperationPos> { - static constexpr unsigned kDown = std::numeric_limits::max(); - +struct OperationPosition : public PredicateBase, + Predicates::OperationPos> { explicit OperationPosition(const KeyTy &key) : Base(key) { - parent = std::get<0>(key); + parent = key.first; } /// Returns a hash suitable for the given keytype. @@ -246,31 +259,22 @@ /// Gets the root position. static OperationPosition *getRoot(StorageUniquer &uniquer) { - return Base::get(uniquer, nullptr, kDown, 0); + return Base::get(uniquer, nullptr, 0); } - /// Gets an downward operation position with the given parent. + /// Gets an operation position with the given parent. static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { - return Base::get(uniquer, parent, kDown, parent->getOperationDepth() + 1); - } - - /// Gets an upward operation position with the given parent and operand. - static OperationPosition *get(StorageUniquer &uniquer, Position *parent, - Optional operand) { - return Base::get(uniquer, parent, operand, parent->getOperationDepth() + 1); + return Base::get(uniquer, parent, parent->getOperationDepth() + 1); } - /// Returns the operand index for an upward operation position. - Optional getIndex() const { return std::get<1>(key); } - - /// Returns if this operation position is upward, accepting an input. - bool isUpward() const { return getIndex().getValueOr(0) != kDown; } - /// Returns the depth of this position. - unsigned getDepth() const { return std::get<2>(key); } + unsigned getDepth() const { return key.second; } /// Returns if this operation position corresponds to the root. bool isRoot() const { return getDepth() == 0; } + + /// Returns if this operation represents an operand defining op. 
+ bool isOperandDefiningOp() const; }; //===----------------------------------------------------------------------===// @@ -341,6 +345,26 @@ }; //===----------------------------------------------------------------------===// +// UsersPosition + +/// A position describing the users of a value or a range of values. The second +/// value in the key indicates whether we choose users of a representative for +/// a range (this is true, e.g., in the upward traversals). +struct UsersPosition + : public PredicateBase, + Predicates::UsersPos> { + explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; } + + /// Returns a hash suitable for the given keytype. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + /// Indicates whether to compute a range of a representative. + bool useRepresentative() const { return key.second; } +}; + +//===----------------------------------------------------------------------===// // Qualifiers //===----------------------------------------------------------------------===// @@ -496,6 +520,7 @@ // Register the types of Positions with the uniquer. registerParametricStorageType(); registerParametricStorageType(); + registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); @@ -503,6 +528,7 @@ registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); + registerParametricStorageType(); // Register the types of Questions with the uniquer. registerParametricStorageType(); @@ -550,12 +576,10 @@ return OperationPosition::get(uniquer, p); } - /// Returns the position of operation using the value at the given index. - OperationPosition *getUsersOp(Position *p, Optional operand) { - assert((isa(p)) && - "expected result position"); - return OperationPosition::get(uniquer, p, operand); + /// Returns the operation position equivalent to the given position. 
+ OperationPosition *getPassthroughOp(Position *p) { + assert((isa(p)) && "expected foreach position"); + return OperationPosition::get(uniquer, p); } /// Returns an attribute position for an attribute of the given operation. @@ -568,6 +592,10 @@ return AttributeLiteralPosition::get(uniquer, attr); } + Position *getForEach(Position *p, unsigned id) { + return ForEachPosition::get(uniquer, p, id); + } + /// Returns an operand position for an operand of the given operation. Position *getOperand(OperationPosition *p, unsigned operand) { return OperandPosition::get(uniquer, p, operand); @@ -605,6 +633,14 @@ return TypeLiteralPosition::get(uniquer, attr); } + /// Returns the users of the given value(s), optionally via a representative. + UsersPosition *getUsers(Position *p, bool useRepresentative) { + assert((isa(p)) && + "expected result position"); + return UsersPosition::get(uniquer, p, useRepresentative); + } + //===--------------------------------------------------------------------===// // Qualifiers //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp @@ -48,4 +48,6 @@ //===----------------------------------------------------------------------===// // OperationPosition -constexpr unsigned OperationPosition::kDown; +bool OperationPosition::isOperandDefiningOp() const { + return isa_and_nonnull(parent); +} diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -158,8 +158,11 @@ // group, we treat it as all of the operands/results of the operation. /// Operands.
if (operands.size() == 1 && operands[0].getType().isa()) { - getTreePredicates(predList, operands.front(), builder, inputs, - builder.getAllOperands(opPos)); + // Ignore the operands if we are performing an upward traversal (in that + // case, they have already been visited). + if (opPos->isRoot() || opPos->isOperandDefiningOp()) + getTreePredicates(predList, operands.front(), builder, inputs, + builder.getAllOperands(opPos)); } else { bool foundVariableLength = false; for (auto operandIt : llvm::enumerate(operands)) { @@ -502,23 +505,47 @@ "the pattern contains a candidate root disconnected from the others"); } +/// Returns true if the operand at the given index needs to be queried using an +/// operand group, i.e., if it is variadic itself or follows a variadic operand. +static bool useOperandGroup(pdl::OperationOp op, unsigned index) { + OperandRange operands = op.operands(); + assert(index < operands.size() && "operand index out of range"); + for (unsigned i = 0; i <= index; ++i) + if (operands[i].getType().isa()) + return true; + return false; +} + /// Visit a node during upward traversal. -void visitUpward(std::vector &predList, OpIndex opIndex, - PredicateBuilder &builder, - DenseMap &valueToPosition, Position *&pos, - bool &first) { +static void visitUpward(std::vector &predList, + OpIndex opIndex, PredicateBuilder &builder, + DenseMap &valueToPosition, + Position *&pos, unsigned rootID) { Value value = opIndex.parent; TypeSwitch(value.getDefiningOp()) .Case([&](auto operationOp) { LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); - OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index); - // Guard against traversing back to where we came from. - if (first) { - Position *parent = pos->getParent(); - predList.emplace_back(opPos, builder.getNotEqualTo(parent)); - first = false; + // Get users and iterate over them. 
+ Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true); + Position *foreachPos = builder.getForEach(usersPos, rootID); + OperationPosition *opPos = builder.getPassthroughOp(foreachPos); + + // Compare the operand(s) of the user against the input value(s). + Position *operandPos; + if (!opIndex.index) { + // We are querying all the operands of the operation. + operandPos = builder.getAllOperands(opPos); + } else if (useOperandGroup(operationOp, *opIndex.index)) { + // We are querying an operand group. + Type type = operationOp.operands()[*opIndex.index].getType(); + bool variadic = type.isa(); + operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic); + } else { + // We are querying an individual operand. + operandPos = builder.getOperand(opPos, *opIndex.index); } + predList.emplace_back(operandPos, builder.getEqualTo(pos)); // Guard against duplicate upward visits. These are not possible, // because if this value was already visited, it would have been @@ -540,6 +567,9 @@ auto *opPos = dyn_cast(pos); assert(opPos && "operations and results must be interleaved"); pos = builder.getResult(opPos, *opIndex.index); + + // Insert the result position in case we have not visited it yet. + valueToPosition.try_emplace(value, pos); }) .Case([&](auto resultOp) { // Traverse up a group of results. @@ -550,6 +580,9 @@ pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); else pos = builder.getAllResults(opPos); + + // Insert the result position in case we have not visited it yet. 
+ valueToPosition.try_emplace(value, pos); }); } @@ -568,7 +601,8 @@ LLVM_DEBUG({ llvm::dbgs() << "Graph:\n"; for (auto &target : graph) { - llvm::dbgs() << " * " << target.first << "\n"; + llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first + << "\n"; for (auto &source : target.second) { RootOrderingEntry &entry = source.second; llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first @@ -601,6 +635,17 @@ bestEdges = solver.preOrderTraversal(roots); } + // Print the best solution. + LLVM_DEBUG({ + llvm::dbgs() << "Best tree:\n"; + for (const std::pair &edge : bestEdges) { + llvm::dbgs() << " * " << edge.first; + if (edge.second) + llvm::dbgs() << " <- " << edge.second; + llvm::dbgs() << "\n"; + } + }); + LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); @@ -612,9 +657,9 @@ // Traverse the selected optimal branching. For all edges in order, traverse // up starting from the connector, until the candidate root is reached, and // call getTreePredicates at every node along the way. - for (const std::pair &edge : bestEdges) { - Value target = edge.first; - Value source = edge.second; + for (auto it : llvm::enumerate(bestEdges)) { + Value target = it.value().first; + Value source = it.value().second; // Check if we already visited the target root. This happens in two cases: // 1) the initial root (bestRoot); @@ -629,14 +674,13 @@ LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); DenseMap parentMap = parentMaps.lookup(target); Position *pos = valueToPosition.lookup(connector); - assert(pos && "The value has not been traversed yet"); - bool first = true; + assert(pos && "connector has not been traversed yet"); // Traverse from the connector upwards towards the target root. 
for (Value value = connector; value != target;) { OpIndex opIndex = parentMap.lookup(value); assert(opIndex.parent && "missing parent"); - visitUpward(predList, opIndex, builder, valueToPosition, pos, first); + visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index()); value = opIndex.parent; } } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -423,8 +423,8 @@ // CHECK-DAG: %[[OP1:.*]] = pdl_interp.get_defining_op of %[[VAL1]] // CHECK-DAG: %[[OPS:.*]] = pdl_interp.get_users of %[[VAL1]] : !pdl.value // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[OPS]] - // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 0 of %[[ROOT2]] - // CHECK-DAG: pdl_interp.are_equal %[[VAL1]], %[[OPERANDS]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operand 0 of %[[ROOT2]] + // CHECK-DAG: pdl_interp.are_equal %[[OPERANDS]], %[[VAL1]] : !pdl.value -> ^{{.*}}, ^[[CONTINUE:.*]] // CHECK-DAG: pdl_interp.continue // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 1 of %[[ROOT2]] // CHECK-DAG: %[[OP2:.*]] = pdl_interp.get_defining_op of %[[VAL2]] @@ -433,7 +433,6 @@ // CHECK-DAG: pdl_interp.is_not_null %[[VAL1]] : !pdl.value // CHECK-DAG: pdl_interp.is_not_null %[[VAL2]] : !pdl.value // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] : !pdl.operation - // CHECK-DAG: pdl_interp.are_equal %[[ROOT2]], %[[ROOT1]] : !pdl.operation -> ^[[CONTINUE]] pdl.pattern @rewrite_multi_root : benefit(1) { %input1 = pdl.operand @@ -556,7 +555,7 @@ // CHECK-DAG: %[[ROOTS2:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value // CHECK-DAG: pdl_interp.foreach %[[ROOT2:.*]] : !pdl.operation in %[[ROOTS2]] { // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands 1 of %[[ROOT2]] - // 
CHECK-DAG: pdl_interp.are_equal %[[VALS]], %[[OPERANDS]] : !pdl.range -> ^{{.*}}, ^[[CONTINUE:.*]] + // CHECK-DAG: pdl_interp.are_equal %[[OPERANDS]], %[[VALS]] : !pdl.range -> ^{{.*}}, ^[[CONTINUE:.*]] // CHECK-DAG: pdl_interp.is_not_null %[[ROOT2]] // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT2]] is at_least 1 // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT2]] is 0 @@ -612,3 +611,83 @@ } } +// ----- + +// CHECK-LABEL: module @common_connector +module @common_connector { + // Check the correct lowering when multiple roots are using the same + // connector. + + // CHECK: func @matcher(%[[ROOTC:.*]]: !pdl.operation) + // CHECK-DAG: %[[VAL2:.*]] = pdl_interp.get_operand 0 of %[[ROOTC]] + // CHECK-DAG: %[[INTER:.*]] = pdl_interp.get_defining_op of %[[VAL2]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[INTER]] : !pdl.operation -> ^bb2, ^bb1 + // CHECK-DAG: %[[VAL1:.*]] = pdl_interp.get_operand 0 of %[[INTER]] + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VAL1]] : !pdl.value + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] + // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.get_result 0 of %[[OP]] + // CHECK-DAG: %[[ROOTS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[ROOTA:.*]] : !pdl.operation in %[[ROOTS]] { + // CHECK-DAG: pdl_interp.is_not_null %[[ROOTA]] : !pdl.operation -> ^{{.*}}, ^[[CONTA:.*]] + // CHECK-DAG: pdl_interp.continue + // CHECK-DAG: pdl_interp.foreach %[[ROOTB:.*]] : !pdl.operation in %[[ROOTS]] { + // CHECK-DAG: pdl_interp.is_not_null %[[ROOTB]] : !pdl.operation -> ^{{.*}}, ^[[CONTB:.*]] + // CHECK-DAG: %[[ROOTA_OP:.*]] = pdl_interp.get_operand 0 of %[[ROOTA]] + // CHECK-DAG: pdl_interp.are_equal %[[ROOTA_OP]], %[[VAL0]] : !pdl.value + // CHECK-DAG: %[[ROOTB_OP:.*]] = pdl_interp.get_operand 0 of %[[ROOTB]] + // CHECK-DAG: pdl_interp.are_equal %[[ROOTB_OP]], %[[VAL0]] : !pdl.value + // CHECK-DAG: } -> ^[[CONTA:.*]] + pdl.pattern @common_connector : benefit(1) { + %type = 
pdl.type + %op = pdl.operation -> (%type, %type : !pdl.type, !pdl.type) + %val0 = pdl.result 0 of %op + %val1 = pdl.result 1 of %op + %rootA = pdl.operation (%val0 : !pdl.value) + %rootB = pdl.operation (%val0 : !pdl.value) + %inter = pdl.operation (%val1 : !pdl.value) -> (%type : !pdl.type) + %val2 = pdl.result 0 of %inter + %rootC = pdl.operation (%val2 : !pdl.value) + pdl.rewrite with "rewriter"(%rootA, %rootB, %rootC : !pdl.operation, !pdl.operation, !pdl.operation) + } +} + +// ----- + +// CHECK-LABEL: module @common_connector_range +module @common_connector_range { + // Check the correct lowering when multiple roots are using the same + // connector range. + + // CHECK: func @matcher(%[[ROOTC:.*]]: !pdl.operation) + // CHECK-DAG: %[[VALS2:.*]] = pdl_interp.get_operands of %[[ROOTC]] : !pdl.range + // CHECK-DAG: %[[INTER:.*]] = pdl_interp.get_defining_op of %[[VALS2]] : !pdl.range + // CHECK-DAG: pdl_interp.is_not_null %[[INTER]] : !pdl.operation -> ^bb2, ^bb1 + // CHECK-DAG: %[[VALS1:.*]] = pdl_interp.get_operands of %[[INTER]] : !pdl.range + // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[VALS1]] : !pdl.range + // CHECK-DAG: pdl_interp.is_not_null %[[OP]] + // CHECK-DAG: %[[VALS0:.*]] = pdl_interp.get_results 0 of %[[OP]] + // CHECK-DAG: %[[VAL0:.*]] = pdl_interp.extract 0 of %[[VALS0]] : !pdl.value + // CHECK-DAG: %[[ROOTS:.*]] = pdl_interp.get_users of %[[VAL0]] : !pdl.value + // CHECK-DAG: pdl_interp.foreach %[[ROOTA:.*]] : !pdl.operation in %[[ROOTS]] { + // CHECK-DAG: pdl_interp.is_not_null %[[ROOTA]] : !pdl.operation -> ^{{.*}}, ^[[CONTA:.*]] + // CHECK-DAG: pdl_interp.continue + // CHECK-DAG: pdl_interp.foreach %[[ROOTB:.*]] : !pdl.operation in %[[ROOTS]] { + // CHECK-DAG: pdl_interp.is_not_null %[[ROOTB]] : !pdl.operation -> ^{{.*}}, ^[[CONTB:.*]] + // CHECK-DAG: %[[ROOTA_OPS:.*]] = pdl_interp.get_operands of %[[ROOTA]] + // CHECK-DAG: pdl_interp.are_equal %[[ROOTA_OPS]], %[[VALS0]] : !pdl.range + // CHECK-DAG: %[[ROOTB_OPS:.*]] = 
pdl_interp.get_operands of %[[ROOTB]] + // CHECK-DAG: pdl_interp.are_equal %[[ROOTB_OPS]], %[[VALS0]] : !pdl.range + // CHECK-DAG: } -> ^[[CONTA:.*]] + pdl.pattern @common_connector_range : benefit(1) { + %types = pdl.types + %op = pdl.operation -> (%types, %types : !pdl.range, !pdl.range) + %vals0 = pdl.results 0 of %op -> !pdl.range + %vals1 = pdl.results 1 of %op -> !pdl.range + %rootA = pdl.operation (%vals0 : !pdl.range) + %rootB = pdl.operation (%vals0 : !pdl.range) + %inter = pdl.operation (%vals1 : !pdl.range) -> (%types : !pdl.range) + %vals2 = pdl.results of %inter + %rootC = pdl.operation (%vals2 : !pdl.range) + pdl.rewrite with "rewriter"(%rootA, %rootB, %rootC : !pdl.operation, !pdl.operation, !pdl.operation) + } +}