diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1744,6 +1744,140 @@
   }
 };
 
+/// Iterator to traverse all successors of a VPBlockBase node. This includes the
+/// entry node of VPRegionBlocks. Exit blocks of a region implicitly have their
+/// parent region's successors. This ensures all blocks in a region are visited
+/// before any blocks in a successor region when doing a reverse post-order
+/// traversal of the graph.
+template <typename BlockTy>
+class VPTraversingIterator
+    : public iterator_facade_base<VPTraversingIterator<BlockTy>,
+                                  std::forward_iterator_tag, VPBlockBase> {
+  BlockTy Block;
+  /// Index of the current successor. For a VPRegionBlock, index 0 denotes the
+  /// region's entry block and index I > 0 denotes element I - 1 of the
+  /// successor list; for other blocks it indexes the successor list directly.
+  size_t Idx;
+
+  /// Walks from \p Current up through getParent() and returns the first block
+  /// that has at least one successor, or nullptr if neither \p Current nor any
+  /// enclosing region has successors.
+  static BlockTy getBlockWithSuccs(BlockTy Current) {
+    while (Current && Current->getNumSuccessors() == 0)
+      Current = Current->getParent();
+    return Current;
+  }
+
+  /// Shared implementation of both operator* overloads. Must not be called on
+  /// an iterator equal to end(Block).
+  BlockTy deref() const {
+    size_t SuccIdx = Idx;
+    if (auto *R = dyn_cast<VPRegionBlock>(Block)) {
+      if (SuccIdx == 0)
+        return R->getEntry();
+      SuccIdx--;
+    }
+
+    // Blocks without successors of their own (region exits) use the
+    // successors of the closest enclosing region that has any.
+    return getBlockWithSuccs(Block)->getSuccessors()[SuccIdx];
+  }
+
+public:
+  VPTraversingIterator(BlockTy Block, size_t Idx = 0)
+      : Block(Block), Idx(Idx) {}
+  VPTraversingIterator(const VPTraversingIterator &Other)
+      : Block(Other.Block), Idx(Other.Idx) {}
+
+  VPTraversingIterator &operator=(const VPTraversingIterator &R) {
+    Block = R.Block;
+    Idx = R.Idx;
+    return *this;
+  }
+
+  /// Returns the past-the-end iterator for the successors of \p Block.
+  static VPTraversingIterator end(BlockTy Block) {
+    BlockTy WithSuccs = getBlockWithSuccs(Block);
+    size_t NumSuccessors = WithSuccs ? WithSuccs->getNumSuccessors() : 0;
+
+    // A region has one extra slot (index 0) for its entry block.
+    if (auto *R = dyn_cast<VPRegionBlock>(Block))
+      return {R, NumSuccessors + 1};
+    return {Block, NumSuccessors};
+  }
+
+  bool operator==(const VPTraversingIterator &R) const {
+    return Block == R.Block && Idx == R.Idx;
+  }
+
+  const VPBlockBase *operator*() const { return deref(); }
+
+  BlockTy operator*() { return deref(); }
+
+  VPTraversingIterator &operator++() {
+    Idx++;
+    return *this;
+  }
+
+  VPTraversingIterator operator++(int) {
+    VPTraversingIterator Orig = *this;
+    Idx++;
+    return Orig;
+  }
+};
+
+/// Helper for GraphTraits specialization that traverses through VPRegionBlocks.
+template <typename BlockTy> class VPBlockRecursiveTraversalWrapper {
+  BlockTy Entry;
+
+public:
+  VPBlockRecursiveTraversalWrapper(BlockTy Entry) : Entry(Entry) {}
+  BlockTy getEntry() { return Entry; }
+};
+
+/// GraphTraits specialization to recursively traverse VPBlockBase nodes,
+/// including traversing through VPRegionBlocks. Exit blocks of a region
+/// implicitly have their parent region's successors. This ensures all blocks in
+/// a region are visited before any blocks in a successor region when doing a
+/// reverse post-order traversal of the graph.
+template <>
+struct GraphTraits<VPBlockRecursiveTraversalWrapper<VPBlockBase *>> {
+  using NodeRef = VPBlockBase *;
+  using ChildIteratorType = VPTraversingIterator<VPBlockBase *>;
+
+  static NodeRef
+  getEntryNode(VPBlockRecursiveTraversalWrapper<VPBlockBase *> N) {
+    return N.getEntry();
+  }
+
+  static inline ChildIteratorType child_begin(NodeRef N) {
+    return ChildIteratorType(N);
+  }
+
+  static inline ChildIteratorType child_end(NodeRef N) {
+    return ChildIteratorType::end(N);
+  }
+};
+
+template <>
+struct GraphTraits<VPBlockRecursiveTraversalWrapper<const VPBlockBase *>> {
+  using NodeRef = const VPBlockBase *;
+  using ChildIteratorType = VPTraversingIterator<const VPBlockBase *>;
+
+  static NodeRef
+  getEntryNode(VPBlockRecursiveTraversalWrapper<const VPBlockBase *> N) {
+    return N.getEntry();
+  }
+
+  static inline ChildIteratorType child_begin(NodeRef N) {
+    return ChildIteratorType(N);
+  }
+
+  static inline ChildIteratorType child_end(NodeRef N) {
+    return ChildIteratorType::end(N);
+  }
+};
+
 /// VPlan models a candidate for vectorization, encoding various decisions take
 /// to produce efficient output IR, including which branches, basic-blocks and
 /// output IR instructions to generate, and their cost. 
VPlan holds a
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -324,6 +324,169 @@
   }
 }
 
+TEST(VPBasicBlockTest, TraversingIteratorTest) {
+  {
+    // VPBasicBlocks only
+    VPBasicBlock *VPBB1 = new VPBasicBlock();
+    VPBasicBlock *VPBB2 = new VPBasicBlock();
+    VPBasicBlock *VPBB3 = new VPBasicBlock();
+    VPBasicBlock *VPBB4 = new VPBasicBlock();
+
+    //   VPBB1
+    //   /   \
+    // VPBB2  VPBB3
+    //    \    /
+    //    VPBB4
+    VPBlockUtils::connectBlocks(VPBB1, VPBB2);
+    VPBlockUtils::connectBlocks(VPBB1, VPBB3);
+    VPBlockUtils::connectBlocks(VPBB2, VPBB4);
+    VPBlockUtils::connectBlocks(VPBB3, VPBB4);
+
+    SmallVector<const VPBlockBase *> FromIterator;
+    for (const VPBlockBase *VPBB : depth_first(
+             VPBlockRecursiveTraversalWrapper<const VPBlockBase *>(VPBB1)))
+      FromIterator.push_back(VPBB);
+    EXPECT_EQ(4u, FromIterator.size());
+    EXPECT_EQ(VPBB1, FromIterator[0]);
+    EXPECT_EQ(VPBB2, FromIterator[1]);
+    EXPECT_EQ(VPBB4, FromIterator[2]);
+    EXPECT_EQ(VPBB3, FromIterator[3]);
+
+    // Use Plan to properly clean up created blocks.
+    VPlan Plan;
+    Plan.setEntry(VPBB1);
+  }
+
+  // 2 consecutive regions.
+  {
+    VPBasicBlock *R1BB1 = new VPBasicBlock();
+    VPBasicBlock *R1BB2 = new VPBasicBlock();
+    VPBasicBlock *R1BB3 = new VPBasicBlock();
+    VPBasicBlock *R1BB4 = new VPBasicBlock();
+    VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB4, "R1");
+    R1BB2->setParent(R1);
+    R1BB3->setParent(R1);
+    VPBlockUtils::connectBlocks(R1BB1, R1BB2);
+    VPBlockUtils::connectBlocks(R1BB1, R1BB3);
+    VPBlockUtils::connectBlocks(R1BB2, R1BB4);
+    VPBlockUtils::connectBlocks(R1BB3, R1BB4);
+    // Cycle.
+    VPBlockUtils::connectBlocks(R1BB3, R1BB3);
+
+    VPBasicBlock *R2BB1 = new VPBasicBlock();
+    VPBasicBlock *R2BB2 = new VPBasicBlock();
+    VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
+    VPBlockUtils::connectBlocks(R2BB1, R2BB2);
+
+    VPBlockUtils::connectBlocks(R1, R2);
+
+    VPBlockRecursiveTraversalWrapper<VPBlockBase *> N = {R1};
+    SmallVector<VPBlockBase *> FromIterator;
+    for (VPBlockBase *VPBB : depth_first(N))
+      FromIterator.push_back(VPBB);
+
+    EXPECT_EQ(8u, FromIterator.size());
+    EXPECT_EQ(R1, FromIterator[0]);
+    EXPECT_EQ(R1BB1, FromIterator[1]);
+    EXPECT_EQ(R1BB2, FromIterator[2]);
+    EXPECT_EQ(R1BB4, FromIterator[3]);
+    EXPECT_EQ(R2, FromIterator[4]);
+    EXPECT_EQ(R2BB1, FromIterator[5]);
+    EXPECT_EQ(R2BB2, FromIterator[6]);
+    EXPECT_EQ(R1BB3, FromIterator[7]);
+
+    VPBlockRecursiveTraversalWrapper<VPBlockBase *> Start(R1);
+    FromIterator.clear();
+    for (VPBlockBase *VPBB : post_order(Start))
+      FromIterator.push_back(VPBB);
+    EXPECT_EQ(8u, FromIterator.size());
+    EXPECT_EQ(R2BB2, FromIterator[0]);
+    EXPECT_EQ(R2BB1, FromIterator[1]);
+    EXPECT_EQ(R2, FromIterator[2]);
+    EXPECT_EQ(R1BB4, FromIterator[3]);
+    EXPECT_EQ(R1BB2, FromIterator[4]);
+    EXPECT_EQ(R1BB3, FromIterator[5]);
+    EXPECT_EQ(R1BB1, FromIterator[6]);
+    EXPECT_EQ(R1, FromIterator[7]);
+
+    FromIterator.clear();
+    FromIterator.append(po_begin(Start), po_end(Start));
+    EXPECT_EQ(8u, FromIterator.size());
+    EXPECT_EQ(R2BB2, FromIterator[0]);
+    EXPECT_EQ(R2BB1, FromIterator[1]);
+    EXPECT_EQ(R2, FromIterator[2]);
+    EXPECT_EQ(R1BB4, FromIterator[3]);
+    EXPECT_EQ(R1BB2, FromIterator[4]);
+    EXPECT_EQ(R1BB3, FromIterator[5]);
+    EXPECT_EQ(R1BB1, FromIterator[6]);
+    EXPECT_EQ(R1, FromIterator[7]);
+
+    // Use Plan to properly clean up created blocks.
+    VPlan Plan;
+    Plan.setEntry(R1);
+  }
+
+  // 2 nested regions.
+  //  VPBB1
+  //    |
+  //  R1 {
+  //         R1BB1
+  //       /        \
+  //   R2 {          |
+  //     R2BB1       |
+  //       |       R1BB2
+  //     R2BB2       |
+  //   }             |
+  //       \        /
+  //         R1BB3
+  //  }
+  //    |
+  //  VPBB2
+  //
+  {
+    VPBasicBlock *R1BB1 = new VPBasicBlock("R1BB1");
+    VPBasicBlock *R1BB2 = new VPBasicBlock("R1BB2");
+    VPBasicBlock *R1BB3 = new VPBasicBlock("R1BB3");
+    VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB3, "R1");
+
+    VPBasicBlock *R2BB1 = new VPBasicBlock("R2BB1");
+    VPBasicBlock *R2BB2 = new VPBasicBlock("R2BB2");
+    VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
+    VPBlockUtils::connectBlocks(R2BB1, R2BB2);
+
+    R2->setParent(R1);
+    VPBlockUtils::connectBlocks(R1BB1, R2);
+    R1BB2->setParent(R1);
+    VPBlockUtils::connectBlocks(R1BB1, R1BB2);
+    VPBlockUtils::connectBlocks(R1BB2, R1BB3);
+    VPBlockUtils::connectBlocks(R2, R1BB3);
+
+    VPBasicBlock *VPBB1 = new VPBasicBlock("VPBB1");
+    VPBlockUtils::connectBlocks(VPBB1, R1);
+    VPBasicBlock *VPBB2 = new VPBasicBlock("VPBB2");
+    VPBlockUtils::connectBlocks(R1, VPBB2);
+
+    SmallVector<const VPBlockBase *> FromIterator(
+        df_begin(VPBlockRecursiveTraversalWrapper<const VPBlockBase *>(VPBB1)),
+        df_end(VPBlockRecursiveTraversalWrapper<const VPBlockBase *>(VPBB1)));
+
+    EXPECT_EQ(9u, FromIterator.size());
+    EXPECT_EQ(VPBB1, FromIterator[0]);
+    EXPECT_EQ(R1, FromIterator[1]);
+    EXPECT_EQ(R1BB1, FromIterator[2]);
+    EXPECT_EQ(R2, FromIterator[3]);
+    EXPECT_EQ(R2BB1, FromIterator[4]);
+    EXPECT_EQ(R2BB2, FromIterator[5]);
+    EXPECT_EQ(R1BB3, FromIterator[6]);
+    EXPECT_EQ(VPBB2, FromIterator[7]);
+    EXPECT_EQ(R1BB2, FromIterator[8]);
+
+    // Use Plan to properly clean up created blocks.
+    VPlan Plan;
+    Plan.setEntry(VPBB1);
+  }
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 TEST(VPBasicBlockTest, print) {
   VPInstruction *I1 = new VPInstruction(Instruction::Add, {});