diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2201,6 +2201,37 @@
     }
     return Count;
   }
+
+  /// Return an iterator range over \p Iter which only includes VPBasicBlock *.
+  template <typename T> static auto basicBlocksOnly(T Iter) {
+    // We need to first create an iterator range over VPBlockBase & instead of
+    // VPBlockBase * for filter_range to work properly.
+    auto Mapped = map_range(
+        Iter, [](VPBlockBase *Block) -> VPBlockBase & { return *Block; });
+    auto Filter = make_filter_range(
+        Mapped, [](VPBlockBase &Block) { return isa<VPBasicBlock>(&Block); });
+    return map_range(Filter, [](VPBlockBase &Block) -> VPBasicBlock * {
+      return cast<VPBasicBlock>(&Block);
+    });
+  }
+
+  /// Return an iterator range over \p Iter which only includes const
+  /// VPBasicBlock *.
+  template <typename T> static auto constBasicBlocksOnly(T Iter) {
+    // We need to first create an iterator range over const VPBlockBase &
+    // instead of const VPBlockBase * for filter_range to work properly.
+    auto Mapped =
+        map_range(Iter, [](const VPBlockBase *Block) -> const VPBlockBase & {
+          return *Block;
+        });
+    auto Filter = make_filter_range(Mapped, [](const VPBlockBase &Block) {
+      return isa<VPBasicBlock>(&Block);
+    });
+    return map_range(Filter,
+                     [](const VPBlockBase &Block) -> const VPBasicBlock * {
+                       return cast<VPBasicBlock>(&Block);
+                     });
+  }
 };
 
 class VPInterleavedAccessInfo {
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -405,6 +405,18 @@
     EXPECT_EQ(R2BB2, FromIterator[6]);
     EXPECT_EQ(R1BB3, FromIterator[7]);
 
+    // VPBasicBlocks only.
+    FromIterator.clear();
+    copy(VPBlockUtils::constBasicBlocksOnly(depth_first(Start)),
+         std::back_inserter(FromIterator));
+    EXPECT_EQ(6u, FromIterator.size());
+    EXPECT_EQ(R1BB1, FromIterator[0]);
+    EXPECT_EQ(R1BB2, FromIterator[1]);
+    EXPECT_EQ(R1BB4, FromIterator[2]);
+    EXPECT_EQ(R2BB1, FromIterator[3]);
+    EXPECT_EQ(R2BB2, FromIterator[4]);
+    EXPECT_EQ(R1BB3, FromIterator[5]);
+
     // Post-order.
     FromIterator.clear();
     copy(post_order(Start), std::back_inserter(FromIterator));
@@ -599,6 +611,15 @@
     EXPECT_EQ(R3BB1, FromIterator[5]);
     EXPECT_EQ(VPBB2, FromIterator[6]);
 
+    FromIterator.clear();
+    copy(VPBlockUtils::basicBlocksOnly(depth_first(Start)),
+         std::back_inserter(FromIterator));
+    EXPECT_EQ(FromIterator.size(), 4u);
+    EXPECT_EQ(VPBB1, FromIterator[0]);
+    EXPECT_EQ(R2BB1, FromIterator[1]);
+    EXPECT_EQ(R3BB1, FromIterator[2]);
+    EXPECT_EQ(VPBB2, FromIterator[3]);
+
     // Post-order.
     FromIterator.clear();
     copy(post_order(Start), std::back_inserter(FromIterator));
@@ -611,6 +632,15 @@
     EXPECT_EQ(R1, FromIterator[5]);
     EXPECT_EQ(VPBB1, FromIterator[6]);
 
+    FromIterator.clear();
+    copy(VPBlockUtils::basicBlocksOnly(post_order(Start)),
+         std::back_inserter(FromIterator));
+    EXPECT_EQ(FromIterator.size(), 4u);
+    EXPECT_EQ(VPBB2, FromIterator[0]);
+    EXPECT_EQ(R3BB1, FromIterator[1]);
+    EXPECT_EQ(R2BB1, FromIterator[2]);
+    EXPECT_EQ(VPBB1, FromIterator[3]);
+
     // Use Plan to properly clean up created blocks.
     VPlan Plan;
     Plan.setEntry(VPBB1);