diff --git a/mlir/lib/Conversion/PDLToPDLInterp/CMakeLists.txt b/mlir/lib/Conversion/PDLToPDLInterp/CMakeLists.txt
--- a/mlir/lib/Conversion/PDLToPDLInterp/CMakeLists.txt
+++ b/mlir/lib/Conversion/PDLToPDLInterp/CMakeLists.txt
@@ -2,6 +2,7 @@
   PDLToPDLInterp.cpp
   Predicate.cpp
   PredicateTree.cpp
+  RootOrdering.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PDLToPDLInterp
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.h b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.h
@@ -0,0 +1,142 @@
+//===- RootOrdering.h - Optimal root ordering ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definition of a cost graph over candidate roots and
+// an implementation of an algorithm to determine the optimal ordering over
+// these roots. Each edge in this graph indicates that the target root can be
+// connected (via a chain of positions) to the source root, and their cost
+// indicates the estimated cost of such traversal. The optimal root ordering
+// is then formulated as that of finding a spanning arborescence (i.e., a
+// directed spanning tree) of minimal weight.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_
+#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_
+
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include <utility>
+#include <vector>
+
+namespace mlir {
+namespace pdl_to_pdl_interp {
+
+/// The information associated with an edge in the cost graph. Each node in
+/// the cost graph corresponds to a candidate root detected in the pdl.pattern,
+/// and each edge in the cost graph corresponds to connecting the two candidate
+/// roots via a chain of operations. The cost of an edge is the smallest number
+/// of upward traversals required to go from the source to the target root, and
+/// the connector is a `Value` in the intersection of the two subtrees rooted at
+/// the source and target root that results in that smallest number of upward
+/// traversals. Consider the following pattern with 3 roots op3, op4, and op5:
+///
+///     argA ---> op1 ---> op2 ---> op3 ---> res3
+///                ^        ^
+///                |        |
+///               argB     argC
+///                |        |
+///                v        v
+///      res4 <--- op4      op5 ---> res5
+///                ^        ^
+///                |        |
+///               op6      op7
+///
+/// The cost of the edge op3 -> op4 is 1 (the upward traversal argB -> op4),
+/// with argB being the connector `Value` and similarly for op3 -> op5 (cost 1,
+/// connector argC). The cost of the edge op4 -> op3 is 3 (upward traversals
+/// argB -> op1 -> op2 -> op3, connector argB), while the cost of edge op5 ->
+/// op3 is 2 (upward traversals argC -> op2 -> op3). There are no edges between
+/// op4 and op5 in the cost graph, because the subtrees rooted at these two
+/// roots do not intersect. It is easy to see that the optimal root for this
+/// pattern is op3, resulting in the spanning arborescence op3 -> {op4, op5}.
+struct RootOrderingCost {
+  /// The depth of the connector `Value` w.r.t. the target root.
+  ///
+  /// This is a pair where the first entry is the actual cost, and the second
+  /// entry is a priority for breaking ties (with 0 being the highest).
+  /// Typically, the priority is a unique edge ID.
+  std::pair<unsigned, unsigned> cost;
+
+  /// The connector value in the intersection of the two subtrees rooted at
+  /// the source and target root that results in that smallest depth w.r.t.
+  /// the target root.
+  Value connector;
+};
+
+/// A directed graph representing the cost of ordering the roots in the
+/// predicate tree. It is represented as an adjacency map, where the outer map
+/// is indexed by the target node, and the inner map is indexed by the source
+/// node. Each edge is associated with a cost and the underlying connector
+/// value.
+using RootOrderingGraph = DenseMap<Value, DenseMap<Value, RootOrderingCost>>;
+
+/// The optimal branching algorithm solver. This solver accepts a graph and the
+/// root in its constructor, and is invoked via the solve() member function.
+/// This is a direct implementation of the Edmonds' algorithm, see
+/// https://en.wikipedia.org/wiki/Edmonds%27_algorithm. The worst-case
+/// computational complexity of this algorithm is O(N^3), for a single root.
+/// The PDL-to-PDLInterp lowering calls this N times (once for each candidate
+/// root), so the overall complexity of root ordering is O(N^4). If needed, this
+/// could be reduced to O(N^3) with a more efficient algorithm. However, note
+/// that the underlying implementation is very efficient, and N in our
+/// instances tends to be very small (<10).
+class OptimalBranching {
+public:
+  /// A list of edges (child, parent).
+  using EdgeList = std::vector<std::pair<Value, Value>>;
+
+  /// Constructs the solver for the given graph and root value.
+  OptimalBranching(RootOrderingGraph graph, Value root);
+
+  /// Runs the Edmonds' algorithm for the current `graph`, returning the total
+  /// cost of the minimum-weight spanning arborescence (sum of the edge costs).
+  /// This function first determines the optimal local choice of the parents
+  /// and stores this choice in the `parents` mapping. If this choice results
+  /// in an acyclic graph, the function returns immediately. Otherwise, it
+  /// takes an arbitrary cycle, contracts it, and recurses on the new graph
+  /// (which is guaranteed to have fewer nodes than we began with). After we
+  /// return from recursion, we redirect the edges to/from the contracted node,
+  /// so the `parents` map contains a valid solution for the current graph.
+  /// Note that contraction rewrites `graph` in place, so calling solve() again
+  /// operates on the contracted graph rather than on the original one.
+  unsigned solve();
+
+  /// Returns the computed parent map. This is the unique predecessor for each
+  /// node (root) in the optimal branching.
+  const DenseMap<Value, Value> &getRootOrderingParents() const {
+    return parents;
+  }
+
+  /// Returns the computed edges (child, parent) in breadth-first order,
+  /// starting with (root, null), so that every edge is listed after the edge
+  /// of its parent. Apart from the root, only the nodes listed in `nodes` are
+  /// visited, and siblings are visited in the order in which they appear in
+  /// `nodes`; this array thus determines the order for breaking any ties.
+  EdgeList preOrderTraversal(ArrayRef<Value> nodes) const;
+
+private:
+  /// The graph whose optimal branching we wish to determine.
+  RootOrderingGraph graph;
+
+  /// The root of the optimal branching.
+  Value root;
+
+  /// The computed parent mapping. This is the unique predecessor for each node
+  /// in the optimal branching. The keys of this map correspond to the keys of
+  /// the outer map of the input graph, and each value is one of the keys of
+  /// the inner map for this node. Also used as an intermediate (possibly
+  /// cyclical) result in the optimal branching algorithm.
+  DenseMap<Value, Value> parents;
+};
+
+} // end namespace pdl_to_pdl_interp
+} // end namespace mlir
+
+#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
@@ -0,0 +1,229 @@
+//===- RootOrdering.cpp - Optimal root ordering ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// An implementation of Edmonds' optimal branching algorithm. This is a
+// directed analogue of the minimum spanning tree problem for a given root.
+//
+//===----------------------------------------------------------------------===//
+
+#include "RootOrdering.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cassert>
+#include <utility>
+
+using namespace mlir;
+using namespace mlir::pdl_to_pdl_interp;
+
+/// Returns the cycle implied by the specified parent relation, starting at the
+/// given node.
+static SmallVector<Value> getCycle(const DenseMap<Value, Value> &parents,
+                                   Value rep) {
+  SmallVector<Value> cycle;
+  Value node = rep;
+  do {
+    cycle.push_back(node);
+    node = parents.lookup(node);
+    assert(node && "got an empty value in the cycle");
+  } while (node != rep);
+  return cycle;
+}
+
+/// Contracts the specified cycle in the given graph in-place.
+/// The parentCosts map specifies, for each node in the cycle, the lowest cost
+/// among the edges entering that node. Then, the nodes in the cycle C are
+/// replaced with a single node v_C (the first node in the cycle). All edges
+/// (u, v) entering the cycle, v \in C, are replaced with a single edge
+/// (u, v_C) with an appropriately chosen cost, and the selected node v is
+/// marked in the output map actualTarget[u]. All edges (u, v) leaving the
+/// cycle, u \in C, are replaced with a single edge (v_C, v), and the selected
+/// node u is marked in the output map actualSource[v].
+static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
+                     const DenseMap<Value, unsigned> &parentCosts,
+                     DenseMap<Value, Value> &actualSource,
+                     DenseMap<Value, Value> &actualTarget) {
+  Value rep = cycle.front();
+  DenseSet<Value> cycleSet(cycle.begin(), cycle.end());
+
+  // Now, contract the cycle, marking the actual sources and targets.
+  DenseMap<Value, RootOrderingCost> repCosts;
+  for (auto outer = graph.begin(), e = graph.end(); outer != e; ++outer) {
+    Value target = outer->first;
+    if (cycleSet.contains(target)) {
+      // Target in the cycle => edges incoming to the cycle or within the cycle.
+      unsigned parentCost = parentCosts.lookup(target);
+      for (const auto &inner : outer->second) {
+        Value source = inner.first;
+        // Ignore edges within the cycle.
+        if (cycleSet.contains(source))
+          continue;
+
+        // Edge incoming to the cycle.
+        std::pair<unsigned, unsigned> cost = inner.second.cost;
+        assert(parentCost <= cost.first && "invalid parent cost");
+
+        // Subtract the cost of the parent within the cycle from the cost of
+        // the edge incoming to the cycle. This update ensures that the cost
+        // of the minimum-weight spanning arborescence of the entire graph is
+        // the cost of arborescence for the contracted graph plus the cost of
+        // the cycle, no matter which edge in the cycle we choose to drop.
+        cost.first -= parentCost;
+        auto it = repCosts.find(source);
+        if (it == repCosts.end() || it->second.cost > cost) {
+          actualTarget[source] = target;
+          // Do not bother populating the connector (the connector is only
+          // relevant for the final traversal, not for the optimal branching).
+          repCosts[source].cost = cost;
+        }
+      }
+      // Erase the node in the cycle.
+      graph.erase(outer);
+    } else {
+      // Target not in cycle => edges going away from or unrelated to the cycle.
+      DenseMap<Value, RootOrderingCost> &costs = outer->second;
+      Value bestSource;
+      std::pair<unsigned, unsigned> bestCost;
+      auto inner = costs.begin(), inner_e = costs.end();
+      while (inner != inner_e) {
+        auto current = inner++; // advance before a possible erase
+        if (cycleSet.contains(current->first)) {
+          // Going-away edge => get its cost and erase it.
+          if (!bestSource || bestCost > current->second.cost) {
+            bestSource = current->first;
+            bestCost = current->second.cost;
+          }
+          costs.erase(current);
+        }
+      }
+
+      // There were going-away edges, contract them.
+      if (bestSource) {
+        costs[rep].cost = bestCost;
+        actualSource[target] = bestSource;
+      }
+    }
+  }
+
+  // Store the edges to the representative.
+  graph[rep] = std::move(repCosts);
+}
+
+OptimalBranching::OptimalBranching(RootOrderingGraph graph, Value root)
+    : graph(std::move(graph)), root(root) {}
+
+unsigned OptimalBranching::solve() {
+  // Initialize the parents and total cost.
+  parents.clear();
+  parents[root] = Value();
+  unsigned totalCost = 0;
+
+  // A map that stores the cost of the optimal local choice for each node
+  // in a directed cycle. This map is cleared every time we seed the search.
+  DenseMap<Value, unsigned> parentCosts;
+  parentCosts.reserve(graph.size());
+
+  // Determine if the optimal local choice results in an acyclic graph. This is
+  // done by computing the optimal local choice and traversing up the computed
+  // parents. On success, `parents` will contain the parent of each node.
+  for (const auto &outer : graph) {
+    Value node = outer.first;
+    if (parents.count(node)) // already visited
+      continue;
+
+    // Follow the trail of best sources until we reach an already visited node.
+    // The code will assert if we cannot reach an already visited node, i.e.,
+    // the graph is not strongly connected.
+    parentCosts.clear();
+    do {
+      auto it = graph.find(node);
+      assert(it != graph.end() && "the graph is not strongly connected");
+
+      // Pick the best local parent by (cost, priority), so that ties between
+      // equal-cost edges do not depend on the map's iteration order.
+      Value &bestSource = parents[node];
+      std::pair<unsigned, unsigned> bestCost;
+      for (const auto &inner : it->second) {
+        if (!bestSource /* initial */ || bestCost > inner.second.cost) {
+          bestSource = inner.first;
+          bestCost = inner.second.cost;
+        }
+      }
+      assert(bestSource && "the graph is not strongly connected");
+      parentCosts[node] = bestCost.first;
+      node = bestSource;
+      totalCost += bestCost.first;
+    } while (!parents.count(node));
+
+    // If we reached a non-root node, we have a cycle.
+    if (parentCosts.count(node)) {
+      // Determine the cycle starting at the representative node.
+      SmallVector<Value> cycle = getCycle(parents, node);
+
+      // The following maps disambiguate the source / target of the edges
+      // going out of / into the cycle.
+      DenseMap<Value, Value> actualSource, actualTarget;
+
+      // Contract the cycle and recurse.
+      contract(graph, cycle, parentCosts, actualSource, actualTarget);
+      totalCost = solve();
+
+      // Redirect the going-away edges.
+      for (auto &p : parents)
+        if (p.second == node)
+          // The parent is the node representing the cycle; replace it
+          // with the actual (best) source in the cycle.
+          p.second = actualSource.lookup(p.first);
+
+      // Redirect the unique incoming edge and copy the cycle.
+      Value parent = parents.lookup(node);
+      Value entry = actualTarget.lookup(parent);
+      cycle.push_back(node); // complete the cycle
+      for (size_t i = 0, e = cycle.size() - 1; i < e; ++i) {
+        totalCost += parentCosts.lookup(cycle[i]);
+        if (cycle[i] == entry)
+          parents[cycle[i]] = parent; // break the cycle
+        else
+          parents[cycle[i]] = cycle[i + 1];
+      }
+
+      // `parents` has a complete solution.
+      break;
+    }
+  }
+
+  return totalCost;
+}
+
+OptimalBranching::EdgeList
+OptimalBranching::preOrderTraversal(ArrayRef<Value> nodes) const {
+  // Invert the parent mapping.
+  DenseMap<Value, std::vector<Value>> children;
+  for (Value node : nodes) {
+    if (node != root) {
+      Value parent = parents.lookup(node);
+      assert(parent && "invalid parent");
+      children[parent].push_back(node);
+    }
+  }
+
+  // The result which simultaneously acts as a queue.
+  EdgeList result;
+  result.reserve(nodes.size());
+  result.emplace_back(root, Value());
+
+  // Perform a BFS, pushing into the queue.
+  for (size_t i = 0; i < result.size(); ++i) {
+    Value node = result[i].first;
+    for (Value child : children[node])
+      result.emplace_back(child, node);
+  }
+
+  return result;
+}
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -5,6 +5,7 @@
 endfunction()
 
 add_subdirectory(Analysis)
+add_subdirectory(Conversion)
 add_subdirectory(Dialect)
 add_subdirectory(ExecutionEngine)
 add_subdirectory(Interfaces)
diff --git a/mlir/unittests/Conversion/CMakeLists.txt b/mlir/unittests/Conversion/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Conversion/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(PDLToPDLInterp)
diff --git a/mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt b/mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Conversion/PDLToPDLInterp/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRPDLToPDLInterpTests
+  RootOrderingTest.cpp
+)
+target_link_libraries(MLIRPDLToPDLInterpTests
+  PRIVATE
+  MLIRStandard
+  MLIRPDLToPDLInterp
+)
diff --git a/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp b/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp
@@ -0,0 +1,106 @@
+//===- RootOrderingTest.cpp - unit tests for optimal branching ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::pdl_to_pdl_interp;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+/// The test fixture for constructing root ordering tests and verifying results.
+/// It creates the test values `v`; each test populates `graph` with the
+/// desired costs and then calls check(), passing the expected optimal cost and
+/// the list of edges in the traversal order of the optimal branching.
+class RootOrderingTest : public ::testing::Test {
+protected:
+  /// Loads the dialect and creates the test values `v`, which merely act as
+  /// vertex IDs. Their defining ops are inserted into `block`, which owns
+  /// them, so they are destroyed together with the fixture.
+  RootOrderingTest() {
+    context.loadDialect<StandardOpsDialect>();
+    OpBuilder builder(&context);
+    builder.setInsertionPointToStart(&block);
+    for (int i = 0; i < 4; ++i)
+      v[i] = builder.create<ConstantOp>(builder.getUnknownLoc(),
+                                        builder.getI32IntegerAttr(i));
+  }
+
+  /// Checks that optimal branching on graph has the given cost and
+  /// its preorder traversal results in the specified edges.
+  void check(unsigned cost, OptimalBranching::EdgeList edges) {
+    OptimalBranching opt(graph, v[0]);
+    EXPECT_EQ(opt.solve(), cost);
+    EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges);
+    for (std::pair<Value, Value> edge : edges)
+      EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second);
+  }
+
+  /// The context for creating the values.
+  MLIRContext context;
+
+  /// Owns the ops defining `v`. Declared after `context`, so destroyed first.
+  Block block;
+
+  /// Values used in the graph definition. We always use leading `n` values.
+  Value v[4];
+
+  /// The graph being tested on.
+  RootOrderingGraph graph;
+};
+
+//===----------------------------------------------------------------------===//
+// Simple 3-node graphs
+//===----------------------------------------------------------------------===//
+
+TEST_F(RootOrderingTest, simpleA) {
+  graph[v[1]][v[0]].cost = {1, 10};
+  graph[v[2]][v[0]].cost = {1, 11};
+  graph[v[1]][v[2]].cost = {2, 12};
+  graph[v[2]][v[1]].cost = {2, 13};
+  check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}});
+}
+
+TEST_F(RootOrderingTest, simpleB) {
+  graph[v[1]][v[0]].cost = {1, 10};
+  graph[v[2]][v[0]].cost = {2, 11};
+  graph[v[1]][v[2]].cost = {1, 12};
+  graph[v[2]][v[1]].cost = {1, 13};
+  check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
+}
+
+TEST_F(RootOrderingTest, simpleC) {
+  graph[v[1]][v[0]].cost = {2, 10};
+  graph[v[2]][v[0]].cost = {2, 11};
+  graph[v[1]][v[2]].cost = {1, 12};
+  graph[v[2]][v[1]].cost = {1, 13};
+  check(3, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
+}
+
+//===----------------------------------------------------------------------===//
+// Graph for testing contraction
+//===----------------------------------------------------------------------===//
+
+TEST_F(RootOrderingTest, contraction) {
+  graph[v[1]][v[0]].cost = {10, 0};
+  graph[v[2]][v[0]].cost = {5, 0};
+  graph[v[2]][v[1]].cost = {1, 0};
+  graph[v[3]][v[2]].cost = {2, 0};
+  graph[v[1]][v[3]].cost = {3, 0};
+  check(10, {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}});
+}
+
+} // end namespace