diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7193,8 +7193,9 @@
   // ---------------------------------------------------------------------------
 
   // Create a dummy pre-entry VPBasicBlock to start building the VPlan.
+  auto Plan = std::make_unique<VPlan>();
   VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry");
-  auto Plan = std::make_unique<VPlan>(VPBB);
+  Plan->setEntry(VPBB);
 
   // Represent values that will have defs inside VPlan.
   for (Value *V : NeedDef)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -365,6 +365,10 @@
   /// Current block predicate - null if the block does not need a predicate.
   VPValue *Predicate = nullptr;
 
+  /// VPlan containing the block. Can only be set on the entry block of the
+  /// plan.
+  VPlan *Plan = nullptr;
+
   /// Add \p Successor as the last successor to this block.
   void appendSuccessor(VPBlockBase *Successor) {
     assert(Successor && "Cannot add nullptr successor!");
@@ -418,6 +422,14 @@
   VPRegionBlock *getParent() { return Parent; }
   const VPRegionBlock *getParent() const { return Parent; }
 
+  /// \return A pointer to the plan containing the current block.
+  VPlan *getPlan();
+  const VPlan *getPlan() const;
+
+  /// Sets the pointer of the plan containing the block. The block must be the
+  /// entry block into the VPlan.
+  void setPlan(VPlan *ParentPlan);
+
   void setParent(VPRegionBlock *P) { Parent = P; }
 
   /// \return the VPBasicBlock that is the entry of this VPBlockBase,
@@ -1402,7 +1414,15 @@
   VPBlockBase *getEntry() { return Entry; }
   const VPBlockBase *getEntry() const { return Entry; }
 
-  VPBlockBase *setEntry(VPBlockBase *Block) { return Entry = Block; }
+  /// Set \p Block as the entry of this plan and, if \p Block is not null,
+  /// record this plan in it so VPBlockBase::getPlan() can find the plan from
+  /// the blocks of its CFG.
+  VPBlockBase *setEntry(VPBlockBase *Block) {
+    Entry = Block;
+    if (Block)
+      Block->setPlan(this);
+    return Entry;
+  }
 
   /// The backedge taken count of the original loop.
   VPValue *getOrCreateBackedgeTakenCount() {
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -56,6 +56,36 @@
   return OS;
 }
 
+// Get the top-most entry block of \p Start. This is the entry block of the
+// containing VPlan. This function is templated to support both const and
+// non-const blocks.
+template <typename T> static T *getPlanEntry(T *Start) {
+  // Walk up to the outermost region enclosing \p Start.
+  T *Next = Start;
+  T *Current = Start;
+  while ((Next = Next->getParent()))
+    Current = Next;
+
+  // Search backwards along predecessor edges; the first block found without
+  // predecessors is taken as the entry of the plan.
+  SmallSetVector<T *, 8> WorkList;
+  WorkList.insert(Current);
+
+  for (unsigned i = 0; i < WorkList.size(); i++) {
+    T *Block = WorkList[i];
+    if (Block->getNumPredecessors() == 0)
+      return Block;
+    auto &Predecessors = Block->getPredecessors();
+    WorkList.insert(Predecessors.begin(), Predecessors.end());
+  }
+
+  llvm_unreachable("VPlan without any entry node without predecessors");
+}
+
+VPlan *VPBlockBase::getPlan() { return getPlanEntry(this)->Plan; }
+
+const VPlan *VPBlockBase::getPlan() const { return getPlanEntry(this)->Plan; }
+
 /// \return the VPBasicBlock that is the entry of Block, possibly indirectly.
 const VPBasicBlock *VPBlockBase::getEntryBasicBlock() const {
   const VPBlockBase *Block = this;
@@ -71,6 +101,12 @@
   return cast<VPBasicBlock>(Block);
 }
 
+void VPBlockBase::setPlan(VPlan *ParentPlan) {
+  assert(ParentPlan->getEntry() == this &&
+         "Can only set plan on its entry block.");
+  Plan = ParentPlan;
+}
+
 /// \return the VPBasicBlock that is the exit of Block, possibly indirectly.
 const VPBasicBlock *VPBlockBase::getExitBasicBlock() const {
   const VPBlockBase *Block = this;
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp
--- a/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanHCFGTest.cpp
@@ -50,6 +50,7 @@
   EXPECT_EQ(7u, VecBB->size());
   EXPECT_EQ(2u, VecBB->getNumPredecessors());
   EXPECT_EQ(2u, VecBB->getNumSuccessors());
+  EXPECT_EQ(&*Plan, VecBB->getPlan());
 
   auto Iter = VecBB->begin();
   VPInstruction *Phi = dyn_cast<VPInstruction>(&*Iter++);
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp
--- a/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp
@@ -102,6 +102,13 @@
   EXPECT_EQ(InnerLoopLinSucc, OuterIf);
   EXPECT_EQ(OuterIfLinSucc, InnerIf);
   EXPECT_EQ(InnerIfLinSucc, InnerLoopLatch);
+
+  // Check that the containing VPlan is set correctly.
+  EXPECT_EQ(&*Plan, InnerLoopLinSucc->getPlan());
+  EXPECT_EQ(&*Plan, OuterIfLinSucc->getPlan());
+  EXPECT_EQ(&*Plan, InnerIfLinSucc->getPlan());
+  EXPECT_EQ(&*Plan, InnerIf->getPlan());
+  EXPECT_EQ(&*Plan, InnerLoopLatch->getPlan());
 }
 
 // Test generation of Not and Or during predication.
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -86,5 +86,97 @@
   EXPECT_EQ(I3->getParent(), I4->getParent());
 }
 
+TEST(VPBasicBlockTest, getPlan) {
+  {
+    VPBasicBlock *VPBB1 = new VPBasicBlock();
+    VPBasicBlock *VPBB2 = new VPBasicBlock();
+    VPBasicBlock *VPBB3 = new VPBasicBlock();
+    VPBasicBlock *VPBB4 = new VPBasicBlock();
+
+    // Diamond-shaped CFG:
+    //   VPBB1 -> VPBB2 -> VPBB4
+    //   VPBB1 -> VPBB3 -> VPBB4
+    VPBlockUtils::connectBlocks(VPBB1, VPBB2);
+    VPBlockUtils::connectBlocks(VPBB1, VPBB3);
+    VPBlockUtils::connectBlocks(VPBB2, VPBB4);
+    VPBlockUtils::connectBlocks(VPBB3, VPBB4);
+
+    VPlan Plan;
+    Plan.setEntry(VPBB1);
+
+    EXPECT_EQ(&Plan, VPBB1->getPlan());
+    EXPECT_EQ(&Plan, VPBB2->getPlan());
+    EXPECT_EQ(&Plan, VPBB3->getPlan());
+    EXPECT_EQ(&Plan, VPBB4->getPlan());
+  }
+
+  {
+    // A region block is the entry into the VPlan.
+    VPBasicBlock *R1BB1 = new VPBasicBlock();
+    VPBasicBlock *R1BB2 = new VPBasicBlock();
+    VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
+    VPBlockUtils::connectBlocks(R1BB1, R1BB2);
+
+    VPlan Plan;
+    Plan.setEntry(R1);
+    EXPECT_EQ(&Plan, R1->getPlan());
+    EXPECT_EQ(&Plan, R1BB1->getPlan());
+    EXPECT_EQ(&Plan, R1BB2->getPlan());
+  }
+
+  {
+    // VPBasicBlock is the entry into the VPlan, followed by a region.
+    VPBasicBlock *R1BB1 = new VPBasicBlock();
+    VPBasicBlock *R1BB2 = new VPBasicBlock();
+    VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
+    VPBlockUtils::connectBlocks(R1BB1, R1BB2);
+
+    VPBasicBlock *VPBB1 = new VPBasicBlock();
+    VPBlockUtils::connectBlocks(VPBB1, R1);
+
+    VPlan Plan;
+    Plan.setEntry(VPBB1);
+    EXPECT_EQ(&Plan, VPBB1->getPlan());
+    EXPECT_EQ(&Plan, R1->getPlan());
+    EXPECT_EQ(&Plan, R1BB1->getPlan());
+    EXPECT_EQ(&Plan, R1BB2->getPlan());
+  }
+
+  {
+    // VPBasicBlock is the entry into the VPlan, branching to two regions that
+    // join again in VPBB2:
+    //   VPBB1 -> R1 -> VPBB2
+    //   VPBB1 -> R2 -> VPBB2
+    VPBasicBlock *R1BB1 = new VPBasicBlock();
+    VPBasicBlock *R1BB2 = new VPBasicBlock();
+    VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
+    VPBlockUtils::connectBlocks(R1BB1, R1BB2);
+
+    VPBasicBlock *R2BB1 = new VPBasicBlock();
+    VPBasicBlock *R2BB2 = new VPBasicBlock();
+    VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
+    VPBlockUtils::connectBlocks(R2BB1, R2BB2);
+
+    VPBasicBlock *VPBB1 = new VPBasicBlock();
+    VPBlockUtils::connectBlocks(VPBB1, R1);
+    VPBlockUtils::connectBlocks(VPBB1, R2);
+
+    VPBasicBlock *VPBB2 = new VPBasicBlock();
+    VPBlockUtils::connectBlocks(R1, VPBB2);
+    VPBlockUtils::connectBlocks(R2, VPBB2);
+
+    VPlan Plan;
+    Plan.setEntry(VPBB1);
+    EXPECT_EQ(&Plan, VPBB1->getPlan());
+    EXPECT_EQ(&Plan, R1->getPlan());
+    EXPECT_EQ(&Plan, R1BB1->getPlan());
+    EXPECT_EQ(&Plan, R1BB2->getPlan());
+    EXPECT_EQ(&Plan, R2->getPlan());
+    EXPECT_EQ(&Plan, R2BB1->getPlan());
+    EXPECT_EQ(&Plan, R2BB2->getPlan());
+    EXPECT_EQ(&Plan, VPBB2->getPlan());
+  }
+}
+
 } // namespace
 } // namespace llvm