diff --git a/llvm/include/llvm/ADT/iterator.h b/llvm/include/llvm/ADT/iterator.h --- a/llvm/include/llvm/ADT/iterator.h +++ b/llvm/include/llvm/ADT/iterator.h @@ -333,6 +333,11 @@ PointerIteratorT(std::end(std::forward(Range)))); } +template ())>::type, + typename T2 = typename std::add_pointer::type> +using raw_pointer_iterator = pointer_iterator, T2>; + // Wrapper iterator over iterator ItType, adding DataRef to the type of ItType, // to create NodeRef = std::pair. template diff --git a/llvm/test/TableGen/GICombinerEmitter/parse-match-pattern.td b/llvm/test/TableGen/GICombinerEmitter/parse-match-pattern.td --- a/llvm/test/TableGen/GICombinerEmitter/parse-match-pattern.td +++ b/llvm/test/TableGen/GICombinerEmitter/parse-match-pattern.td @@ -132,8 +132,8 @@ // CHECK-NEXT: (MOV 0:dst, 1:src1):$__anon3_0 // $s=getOperand(0), $s2=getOperand(1) // CHECK-NEXT: (MOV 0:dst, 1:src1):$__anon3_2 // $d1=getOperand(0), $s=getOperand(1) // CHECK-NEXT: (MOV 0:dst, 1:src1):$__anon3_4 // $d2=getOperand(0), $s=getOperand(1) -// CHECK-NEXT: __anon3_2[src1] --[s]--> __anon3_0[dst] -// CHECK-NEXT: __anon3_4[src1] --[s]--> __anon3_0[dst] +// CHECK-NEXT: __anon3_0[dst] --[s]--> __anon3_2[src1] +// CHECK-NEXT: __anon3_0[dst] --[s]--> __anon3_4[src1] // CHECK-NEXT: <<$mi.getOpcode() == MOV>>:$__anonpred3_1 // CHECK-NEXT: <<$mi.getOpcode() == MOV>>:$__anonpred3_3 // CHECK-NEXT: <<$mi.getOpcode() == MOV>>:$__anonpred3_5 @@ -150,8 +150,8 @@ // CHECK-NEXT: Node[[N1:0x[0-9a-f]+]] [shape=record,label="{{[{]{}}#0 $dst|#1 $src1}|__anon3_0|MOV|Match starts here|$s=getOperand(0), $s2=getOperand(1)|{{0x[0-9a-f]+}}|{#0 $dst|#1 $src1}}",color=red] // CHECK-NEXT: Node[[N2:0x[0-9a-f]+]] [shape=record,label="{{[{]{}}#0 $dst|#1 $src1}|__anon3_2|MOV|$d1=getOperand(0), $s=getOperand(1)|{{0x[0-9a-f]+}}|{#0 $dst|#1 $src1}}"] // CHECK-NEXT: Node[[N3:0x[0-9a-f]+]] [shape=record,label="{{[{]{}}#0 $dst|#1 $src1}|__anon3_4|MOV|$d2=getOperand(0), $s=getOperand(1)|{{0x[0-9a-f]+}}|{#0 $dst|#1 $src1}}"] -// 
CHECK-NEXT: Node[[N2]]:s1:n -> Node[[N1]]:d0:s [label="$s"] -// CHECK-NEXT: Node[[N3]]:s1:n -> Node[[N1]]:d0:s [label="$s"] +// CHECK-NEXT: Node[[N2]]:s1:n -> Node[[N1]]:d0:s [label="$s",dir=back,arrowtail=crow] +// CHECK-NEXT: Node[[N3]]:s1:n -> Node[[N1]]:d0:s [label="$s",dir=back,arrowtail=crow] // CHECK-NEXT: Pred[[P1:0x[0-9a-f]+]] [shape=record,label="{{[{]{}}#0 $$|#1 $mi}|__anonpred3_1|$mi.getOpcode() == MOV|{{0x[0-9a-f]+}}|{#0 $$|#1 $mi}}",style=dotted] // CHECK-NEXT: Pred[[P2:0x[0-9a-f]+]] [shape=record,label="{{[{]{}}#0 $$|#1 $mi}|__anonpred3_3|$mi.getOpcode() == MOV|{{0x[0-9a-f]+}}|{#0 $$|#1 $mi}}",style=dotted] // CHECK-NEXT: Pred[[P3:0x[0-9a-f]+]] [shape=record,label="{{[{]{}}#0 $$|#1 $mi}|__anonpred3_5|$mi.getOpcode() == MOV|{{0x[0-9a-f]+}}|{#0 $$|#1 $mi}}",style=dotted] diff --git a/llvm/utils/TableGen/GICombinerEmitter.cpp b/llvm/utils/TableGen/GICombinerEmitter.cpp --- a/llvm/utils/TableGen/GICombinerEmitter.cpp +++ b/llvm/utils/TableGen/GICombinerEmitter.cpp @@ -11,6 +11,7 @@ /// //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" @@ -137,6 +138,60 @@ iterator_range roots() const { return llvm::make_range(Roots.begin(), Roots.end()); } + + /// The matcher starts at the roots and walks edges to cover the whole DAG, + /// so this reverses edges as needed to make everything reachable from a root + /// (edges that never become reachable are left as-is; a root reachable from + /// another root is an error). This prepares for flattening into a tree. 
+  void reorientToRoots() {
+    SmallSet<const GIMatchDagInstr *, 5> Roots;
+    SmallSet<const GIMatchDagInstr *, 5> Visited;
+    SmallSet<GIMatchDagEdge *, 20> EdgesRemaining;
+
+    for (auto &I : MatchDag.roots()) {
+      Roots.insert(I);
+      Visited.insert(I);
+    }
+    for (auto &I : MatchDag.edges())
+      EdgesRemaining.insert(I);
+
+    bool Progressed = false;
+    SmallSet<GIMatchDagEdge *, 20> EdgesToRemove;
+    while (!EdgesRemaining.empty()) {
+      // Edges leaving a visited instruction are already oriented correctly.
+      for (GIMatchDagEdge *E : EdgesRemaining) {
+        if (Visited.count(E->getFromMI())) {
+          if (Roots.count(E->getToMI()))
+            PrintError(TheDef.getLoc(), "One or more roots are unnecessary");
+          Visited.insert(E->getToMI());
+          EdgesToRemove.insert(E);
+          Progressed = true;
+        }
+      }
+      for (GIMatchDagEdge *ToRemove : EdgesToRemove)
+        EdgesRemaining.erase(ToRemove);
+      EdgesToRemove.clear();
+
+      // Flip edges entering a visited instruction; erase only after the scan.
+      for (GIMatchDagEdge *E : EdgesRemaining) {
+        if (Visited.count(E->getToMI())) {
+          E->reverse();
+          Visited.insert(E->getToMI());
+          EdgesToRemove.insert(E);
+          Progressed = true;
+        }
+      }
+      for (GIMatchDagEdge *ToRemove : EdgesToRemove)
+        EdgesRemaining.erase(ToRemove);
+      EdgesToRemove.clear();
+
+      if (!Progressed) {
+        LLVM_DEBUG(dbgs() << "No progress\n");
+        return;
+      }
+      Progressed = false;
+    }
+  }
 };
 
 /// A convenience function to check that an Init refers to a specific def.
This @@ -450,6 +505,9 @@ return nullptr; if (!Rule->parseMatcher(Target)) return nullptr; + + Rule->reorientToRoots(); + LLVM_DEBUG(dbgs() << "Parsed rule defs/match for '" << Rule->getName() << "'\n"); LLVM_DEBUG(Rule->getMatchDag().dump()); diff --git a/llvm/utils/TableGen/GlobalISel/GIMatchDag.h b/llvm/utils/TableGen/GlobalISel/GIMatchDag.h --- a/llvm/utils/TableGen/GlobalISel/GIMatchDag.h +++ b/llvm/utils/TableGen/GlobalISel/GIMatchDag.h @@ -53,6 +53,9 @@ public: using InstrNodesVec = std::vector>; using EdgesVec = std::vector>; + using edge_iterator = raw_pointer_iterator; + using const_edge_iterator = raw_pointer_iterator; + using PredicateNodesVec = std::vector>; using PredicateDependencyEdgesVec = @@ -73,6 +76,30 @@ GIMatchDag(const GIMatchDag &) = delete; GIMatchDagContext &getContext() const { return Ctx; } + edge_iterator edges_begin() { + return raw_pointer_iterator(Edges.begin()); + } + edge_iterator edges_end() { + return raw_pointer_iterator(Edges.end()); + } + const_edge_iterator edges_begin() const { + return raw_pointer_iterator(Edges.begin()); + } + const_edge_iterator edges_end() const { + return raw_pointer_iterator(Edges.end()); + } + iterator_range edges() { + return make_range(edges_begin(), edges_end()); + } + iterator_range edges() const { + return make_range(edges_begin(), edges_end()); + } + iterator_range::iterator> roots() { + return make_range(MatchRoots.begin(), MatchRoots.end()); + } + iterator_range::const_iterator> roots() const { + return make_range(MatchRoots.begin(), MatchRoots.end()); + } template GIMatchDagInstr *addInstrNode(Args &&...
args) { auto Obj = diff --git a/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.h b/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.h --- a/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.h +++ b/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.h @@ -51,6 +51,9 @@ const GIMatchDagInstr *getToMI() const { return ToMI; } const GIMatchDagOperand *getToMO() const { return ToMO; } + /// Flip the direction of the edge (swaps FromMI/ToMI and FromMO/ToMO). + void reverse(); + LLVM_DUMP_METHOD void print(raw_ostream &OS) const; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) diff --git a/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.cpp b/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.cpp --- a/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.cpp +++ b/llvm/utils/TableGen/GlobalISel/GIMatchDagEdge.cpp @@ -18,3 +18,8 @@ << "]"; } +void GIMatchDagEdge::reverse() { + std::swap(FromMI, ToMI); + std::swap(FromMO, ToMO); +} +