diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -197,7 +197,7 @@
 }
 
 void VPBlockBase::deleteCFG(VPBlockBase *Entry) {
-  for (VPBlockBase *Block : to_vector(depth_first(Entry)))
+  for (VPBlockBase *Block : to_vector(vp_depth_first_shallow(Entry)))
     delete Block;
 }
 
@@ -504,14 +504,15 @@
 #endif
 
 void VPRegionBlock::dropAllReferences(VPValue *NewValue) {
-  for (VPBlockBase *Block : depth_first(Entry))
+  for (VPBlockBase *Block : vp_depth_first_shallow(Entry))
     // Drop all references in VPBasicBlocks and replace all uses with
     // DummyValue.
     Block->dropAllReferences(NewValue);
 }
 
 void VPRegionBlock::execute(VPTransformState *State) {
-  ReversePostOrderTraversal<VPBlockBase *> RPOT(Entry);
+  ReversePostOrderTraversal<VPBlockNonRecursiveTraversalWrapper<VPBlockBase *>>
+      RPOT(Entry);
 
   if (!isReplicator()) {
     // Create and register the new vector loop.
@@ -565,7 +566,7 @@
                           VPSlotTracker &SlotTracker) const {
   O << Indent << (isReplicator() ? "<xVFxUF> " : "<x1> ") << getName() << ": {";
   auto NewIndent = Indent + "  ";
-  for (auto *BlockBase : depth_first(Entry)) {
+  for (auto *BlockBase : vp_depth_first_shallow(Entry)) {
     O << '\n';
     BlockBase->print(O, NewIndent, SlotTracker);
   }
@@ -580,7 +581,7 @@
 
   if (Entry) {
     VPValue DummyValue;
-    for (VPBlockBase *Block : depth_first(Entry))
+    for (VPBlockBase *Block : vp_depth_first_shallow(Entry))
       Block->dropAllReferences(&DummyValue);
 
     VPBlockBase::deleteCFG(Entry);
@@ -670,7 +671,7 @@
   State->Builder.SetInsertPoint(VectorPreHeader->getTerminator());
 
   // Generate code in the loop pre-header and body.
-  for (VPBlockBase *Block : depth_first(Entry))
+  for (VPBlockBase *Block : vp_depth_first_shallow(Entry))
     Block->execute(State);
 
   VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock();
@@ -756,7 +757,7 @@
     O << " = backedge-taken count\n";
   }
 
-  for (const VPBlockBase *Block : depth_first(getEntry())) {
+  for (const VPBlockBase *Block : vp_depth_first_shallow(getEntry())) {
     O << '\n';
     Block->print(O, "", SlotTracker);
   }
@@ -881,7 +882,7 @@
   OS << "edge [fontname=Courier, fontsize=30]\n";
   OS << "compound=true\n";
 
-  for (const VPBlockBase *Block : depth_first(Plan.getEntry()))
+  for (const VPBlockBase *Block : vp_depth_first_shallow(Plan.getEntry()))
     dumpBlock(Block);
 
   OS << "}\n";
@@ -966,7 +967,7 @@
      << DOT::EscapeString(Region->getName()) << "\"\n";
   // Dump the blocks of the region.
   assert(Region->getEntry() && "Region contains no inner blocks.");
-  for (const VPBlockBase *Block : depth_first(Region->getEntry()))
+  for (const VPBlockBase *Block : vp_depth_first_shallow(Region->getEntry()))
     dumpBlock(Block);
   bumpIndent(-1);
   OS << Indent << "}\n";
@@ -1035,7 +1036,8 @@
 void VPInterleavedAccessInfo::visitRegion(VPRegionBlock *Region,
                                           Old2NewTy &Old2New,
                                           InterleavedAccessInfo &IAI) {
-  ReversePostOrderTraversal<VPBlockBase *> RPOT(Region->getEntry());
+  ReversePostOrderTraversal<VPBlockNonRecursiveTraversalWrapper<VPBlockBase *>>
+      RPOT(Region->getEntry());
   for (VPBlockBase *Base : RPOT) {
     visitBlock(Base, Old2New, IAI);
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanCFG.h b/llvm/lib/Transforms/Vectorize/VPlanCFG.h
--- a/llvm/lib/Transforms/Vectorize/VPlanCFG.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanCFG.h
@@ -271,6 +271,70 @@
   }
 };
 
+/// Wrapper that selects the GraphTraits specializations below, which only
+/// follow VPBlockBase::getSuccessors(); blocks nested inside a VPRegionBlock
+/// are therefore not visited.
+template <typename BlockTy> class VPBlockNonRecursiveTraversalWrapper {
+  BlockTy Entry;
+
+public:
+  // Intentionally implicit: callers build RPOT(Entry) from a block pointer.
+  VPBlockNonRecursiveTraversalWrapper(BlockTy Entry) : Entry(Entry) {}
+  BlockTy getEntry() { return Entry; }
+};
+
+template <>
+struct GraphTraits<VPBlockNonRecursiveTraversalWrapper<VPBlockBase *>> {
+  using NodeRef = VPBlockBase *;
+  using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::iterator;
+
+  static NodeRef
+  getEntryNode(VPBlockNonRecursiveTraversalWrapper<VPBlockBase *> N) {
+    return N.getEntry();
+  }
+
+  static inline ChildIteratorType child_begin(NodeRef N) {
+    return N->getSuccessors().begin();
+  }
+
+  static inline ChildIteratorType child_end(NodeRef N) {
+    return N->getSuccessors().end();
+  }
+};
+
+template <>
+struct GraphTraits<VPBlockNonRecursiveTraversalWrapper<const VPBlockBase *>> {
+  using NodeRef = const VPBlockBase *;
+  using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::const_iterator;
+
+  static NodeRef
+  getEntryNode(VPBlockNonRecursiveTraversalWrapper<const VPBlockBase *> N) {
+    return N.getEntry();
+  }
+
+  static inline ChildIteratorType child_begin(NodeRef N) {
+    return N->getSuccessors().begin();
+  }
+
+  static inline ChildIteratorType child_end(NodeRef N) {
+    return N->getSuccessors().end();
+  }
+};
+
+/// Returns an iterator range to traverse the graph starting at \p G in
+/// depth-first order. The iterator won't traverse through region blocks.
+inline iterator_range<
+    df_iterator<VPBlockNonRecursiveTraversalWrapper<VPBlockBase *>>>
+vp_depth_first_shallow(VPBlockBase *G) {
+  return depth_first(VPBlockNonRecursiveTraversalWrapper<VPBlockBase *>(G));
+}
+inline iterator_range<
+    df_iterator<VPBlockNonRecursiveTraversalWrapper<const VPBlockBase *>>>
+vp_depth_first_shallow(const VPBlockBase *G) {
+  return depth_first(
+      VPBlockNonRecursiveTraversalWrapper<const VPBlockBase *>(G));
+}
+
 } // namespace llvm
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
@@ -44,9 +44,7 @@
 /// \p Region. Checks in this function are generic for VPBlockBases. They are
 /// not specific for VPBasicBlocks or VPRegionBlocks.
 static void verifyBlocksInRegion(const VPRegionBlock *Region) {
-  for (const VPBlockBase *VPB : make_range(
-           df_iterator<const VPBlockBase *>::begin(Region->getEntry()),
-           df_iterator<const VPBlockBase *>::end(Region->getExiting()))) {
+  for (const VPBlockBase *VPB : vp_depth_first_shallow(Region->getEntry())) {
     // Check block's parent.
     assert(VPB->getParent() == Region && "VPBlockBase has wrong parent");
 