diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -584,8 +584,6 @@ Value *CountRoundDown, Value *EndValue, BasicBlock *MiddleBlock); - void createLatchTerminator(Loop *L); - /// Handle all cross-iteration phis in the header. void fixCrossIterationPHIs(VPTransformState &State); @@ -760,6 +758,8 @@ /// The original loop. Loop *OrigLoop; + MDNode *OrigLoopID; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. Applies /// dynamic knowledge to simplify SCEV expressions and converts them to a /// more usable form. @@ -3077,27 +3077,6 @@ PredicatedInstructions.push_back(Cloned); } -void InnerLoopVectorizer::createLatchTerminator(Loop *L) { - BasicBlock *Header = L->getHeader(); - BasicBlock *Latch = L->getLoopLatch(); - // As we're just creating this loop, it's possible no latch exists - // yet. If so, use the header as this will be a single block loop. - if (!Latch) - Latch = Header; - - IRBuilder<> B(&*Header->getFirstInsertionPt()); - Instruction *OldInst = getDebugLocFromInstOrOperands(OldInduction); - - B.SetInsertPoint(Latch->getTerminator()); - setDebugLocFromInst(OldInst, &B); - - // Create the compare. - B.CreateCondBr(B.getTrue(), L->getUniqueExitBlock(), Header); - - // Now we have two terminators. Remove the old one from the block. - Latch->getTerminator()->eraseFromParent(); -} - Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { if (TripCount) return TripCount; @@ -3612,24 +3591,10 @@ "Inconsistent vector loop preheader"); Builder.SetInsertPoint(&*LoopVectorBody->getFirstInsertionPt()); - Optional VectorizedLoopID = - makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, - LLVMLoopVectorizeFollowupVectorized}); - if (VectorizedLoopID.hasValue()) { - L->setLoopID(VectorizedLoopID.getValue()); - - // Do not setAlreadyVectorized if loop attributes have been defined - // explicitly. - return LoopVectorPreHeader; - } - // Keep all loop hints from the original loop on the vector loop (we'll // replace the vectorizer-specific hints below). - if (MDNode *LID = OrigLoop->getLoopID()) - L->setLoopID(LID); - - LoopVectorizeHints Hints(L, true, *ORE); - Hints.setAlreadyVectorized(); + if (OrigLoopID) + OrigLoop->setLoopID(OrigLoopID); #ifdef EXPENSIVE_CHECKS assert(DT->verify(DominatorTree::VerificationLevel::Fast)); @@ -3720,7 +3685,6 @@ // times the unroll factor (num of SIMD instructions). Builder.SetInsertPoint(&*Lp->getHeader()->getFirstInsertionPt()); Value *CountRoundDown = getOrCreateVectorTripCount(Lp); - createLatchTerminator(Lp); // Emit phis for the new starting index of the scalar loop. createInductionResumeValues(Lp, CountRoundDown); @@ -8219,6 +8183,27 @@ State.setVFandUF(BestVF, BestUF); BestVPlan.execute(&State); + // Keep all loop hints from the original loop on the vector loop (we'll + // replace the vectorizer-specific hints below). + MDNode *OrigLoopID = OrigLoop->getLoopID(); + + Optional VectorizedLoopID = + makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, + LLVMLoopVectorizeFollowupVectorized}); + + Loop *L = LI->getLoopFor(State.CFG.PrevBB); + if (VectorizedLoopID.hasValue()) + L->setLoopID(VectorizedLoopID.getValue()); + + LoopVectorizeHints Hints(L, true, *ORE); + Hints.setAlreadyVectorized(); + + /* if (OrigLoopID)*/ + /*OrigLoop->setLoopID(OrigLoopID);*/ + + /*LoopVectorizeHints Hints(OrigLoop, true, *ORE);*/ + /*Hints.setAlreadyVectorized();*/ + // 3. Fix the vectorized code: take care of header phi's, live-outs, // predication, updating analyses. ILV.fixVectorizedLoop(State); @@ -8384,7 +8369,6 @@ OldInduction = Legal->getPrimaryInduction(); Value *CountRoundDown = getOrCreateVectorTripCount(Lp); EPI.VectorTripCount = CountRoundDown; - createLatchTerminator(Lp); // Skip induction resume value creation here because they will be created in // the second pass. If we created them here, they wouldn't be used anyway, @@ -8540,7 +8524,6 @@ // Generate the induction variable. OldInduction = Legal->getPrimaryInduction(); - createLatchTerminator(Lp); // Generate induction resume values. These variables save the new starting // indexes for the scalar loop. They are used to test if there are any tail @@ -9279,6 +9262,7 @@ DFS.perform(LI); VPBasicBlock *VPBB = nullptr; + VPRecipeBase *PrimaryInd = nullptr; for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { // Relevant instructions from basic block BB will be grouped into VPRecipe // ingredients and fill a new VPBasicBlock. @@ -9290,6 +9274,12 @@ auto *TopRegion = new VPRegionBlock("vector loop"); TopRegion->setEntry(FirstVPBBForBB); Plan->setEntry(TopRegion); + Type *IdxTy = Legal->getWidestInductionType(); + Value *StartIdx = ConstantInt::get(IdxTy, 0); + auto *StartV = Plan->getOrAddVPValue(StartIdx); + + PrimaryInd = new VPCanonicalIVRecipe(StartV); + FirstVPBBForBB->appendRecipe(PrimaryInd); } VPBB = FirstVPBBForBB; Builder.setInsertPoint(VPBB); @@ -9327,6 +9317,8 @@ } // Otherwise, add the new recipe. VPRecipeBase *Recipe = RecipeOrValue.get(); + if (auto *IndR = dyn_cast(Recipe)) + IndR->addOperand(PrimaryInd->getVPSingleValue()); for (auto *Def : Recipe->definedValues()) { auto *UV = Def->getUnderlyingValue(); if (UV) @@ -9499,10 +9491,6 @@ Value *StartIdx = ConstantInt::get(IdxTy, 0); auto *StartV = Plan->getOrAddVPValue(StartIdx); - auto *PrimaryInd = new VPCanonicalIVRecipe(StartV); - PrimaryInd->insertBefore( - &*TopRegion->getEntry()->getEntryBasicBlock()->begin()); - auto *InductionIncrement = cast(new VPInstruction( !CM.foldTailByMasking() ? VPInstruction::InductionIncrementNUW : VPInstruction::InductionIncrement, @@ -9510,6 +9498,16 @@ PrimaryInd->addOperand(InductionIncrement->getVPSingleValue()); VPBB->appendRecipe(InductionIncrement); + VPValue *BTC = Plan->getOrCreateScalarBackedgeTakenCount(); + auto *ExitCheck = cast( + new VPInstruction(VPInstruction::ICmpEQ, {InductionIncrement, BTC})); + VPBB->appendRecipe(ExitCheck); + + auto *Branch = cast( + new VPInstruction(VPInstruction::ExitBranch, {ExitCheck})); + + VPBB->appendRecipe(Branch); + // From this point onwards, VPlan-to-VPlan transformations may change the plan // in ways that accessing values using original IR values is incorrect. Plan->disableValue2VPValue(); diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -310,6 +310,7 @@ /// The last IR BasicBlock in the output IR. Set to the new latch /// BasicBlock, used for placing the newly created BasicBlocks. BasicBlock *LastBB = nullptr; + BasicBlock *ExitBB = nullptr; /// The IR BasicBlock that is the preheader of the vector loop in the output /// IR. @@ -788,11 +789,13 @@ // values of a first-order recurrence. Not, ICmpULE, + ICmpEQ, SLPLoad, SLPStore, ActiveLaneMask, InductionIncrement, InductionIncrementNUW, + ExitBranch, }; private: @@ -2158,6 +2161,7 @@ /// Represents the backedge taken count of the original loop, for folding /// the tail. VPValue *BackedgeTakenCount = nullptr; + VPValue *ScalarBackedgeTakenCount = nullptr; /// Holds a mapping between Values and their corresponding VPValue inside /// VPlan. @@ -2192,6 +2196,8 @@ delete VPV; if (BackedgeTakenCount) delete BackedgeTakenCount; + if (ScalarBackedgeTakenCount) + delete ScalarBackedgeTakenCount; for (VPValue *Def : VPExternalDefs) delete Def; } @@ -2215,6 +2221,12 @@ return BackedgeTakenCount; } + VPValue *getOrCreateScalarBackedgeTakenCount() { + if (!ScalarBackedgeTakenCount) + ScalarBackedgeTakenCount = new VPValue(); + return ScalarBackedgeTakenCount; + } + /// Mark the plan to indicate that using Value2VPValue is not safe any /// longer, because it may be stale. void disableValue2VPValue() { Value2VPValueEnabled = false; } diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -327,7 +327,7 @@ UnreachableInst *Terminator = State->Builder.CreateUnreachable(); State->Builder.SetInsertPoint(Terminator); // Register NewBB in its loop. In innermost loops its the same for all BB's. - Loop *L = State->LI->getLoopFor(State->CFG.LastBB); + Loop *L = State->LI->getLoopFor(State->CFG.PrevBB); L->addBasicBlockToLoop(NewBB, *State->LI); State->CFG.PrevBB = NewBB; } @@ -684,6 +684,26 @@ State.set(this, V, Part); break; } + case VPInstruction::ICmpEQ: { + if (Part == 0) { + Value *IV = State.get(getOperand(0), Part); + Value *TC = State.get(getOperand(1), Part); + Value *V = Builder.CreateICmpEQ(IV, TC); + State.set(this, V, Part); + } + break; + } + case VPInstruction::ExitBranch: { + if (Part == 0) { + Value *C = State.get(getOperand(0), Part); + auto *Plan = getParent()->getPlan(); + VPRegionBlock *TopRegion = Plan->getVectorLoopRegion(); + VPBasicBlock *Header = TopRegion->getEntry()->getEntryBasicBlock(); + Builder.CreateCondBr(C, State.CFG.ExitBB, State.CFG.VPBB2IRBB[Header]); + Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); + } + break; + } case Instruction::Select: { Value *Cond = State.get(getOperand(0), Part); Value *Op1 = State.get(getOperand(1), Part); @@ -744,6 +764,7 @@ break; } + default: llvm_unreachable("Unsupported opcode for instruction"); } @@ -822,6 +843,10 @@ for (unsigned Part = 0, UF = State->UF; Part < UF; ++Part) State->set(BackedgeTakenCount, VTCMO, Part); } + if (ScalarBackedgeTakenCount && ScalarBackedgeTakenCount->getNumUsers()) { + for (unsigned Part = 0, UF = State->UF; Part < UF; ++Part) + State->set(ScalarBackedgeTakenCount, State->VectorTripCount, Part); + } // 0. Set the reverse mapping from VPValues to Values for code generation. for (auto &Entry : Value2VPValue) @@ -833,10 +858,12 @@ assert(VectorHeaderBB && "Loop preheader does not have a single successor."); // 1. Make room to generate basic-blocks inside loop body if needed. - BasicBlock *VectorLatchBB = VectorHeaderBB->splitBasicBlock( - VectorHeaderBB->getFirstInsertionPt(), "vector.body.latch"); Loop *L = State->LI->getLoopFor(VectorHeaderBB); - L->addBasicBlockToLoop(VectorLatchBB, *State->LI); + State->CFG.PrevVPBB = nullptr; + State->CFG.PrevBB = VectorHeaderBB; + State->CFG.LastBB = L->getExitBlock(); + State->CFG.ExitBB = L->getExitBlock(); + // Remove the edge between Header and Latch to allow other connections. // Temporarily terminate with unreachable until CFG is rewired. // Note: this asserts the generated code's assumption that @@ -847,13 +874,27 @@ State->Builder.SetInsertPoint(Terminator); // 2. Generate code in loop body. - State->CFG.PrevVPBB = nullptr; - State->CFG.PrevBB = VectorHeaderBB; - State->CFG.LastBB = VectorLatchBB; - for (VPBlockBase *Block : depth_first(Entry)) Block->execute(State); + if (!State->CFG.VPBBsToFix.empty()) { + State->CFG.VPBB2IRBB[cast(getEntry()->getExitBasicBlock())] = + State->CFG.ExitBB; + VPBasicBlock *OuterLatch = cast( + getEntry()->getExitBasicBlock()->getSinglePredecessor()); + BasicBlock *BB = State->CFG.VPBB2IRBB[OuterLatch]; + assert(BB && "Unexpected null basic block for VPBB"); + + unsigned Idx = 0; + auto *BBTerminator = BB->getTerminator(); + + for (VPBlockBase *SuccVPBlock : OuterLatch->getSuccessors()) { + VPBasicBlock *SuccVPBB = SuccVPBlock->getEntryBasicBlock(); + BBTerminator->setSuccessor(Idx, State->CFG.VPBB2IRBB[SuccVPBB]); + ++Idx; + } + } + // Setup branch terminator successors for VPBBs in VPBBsToFix based on // VPBB's successors. for (auto VPBB : State->CFG.VPBBsToFix) { @@ -872,23 +913,7 @@ } } - // 3. Merge the temporary latch created with the last basic-block filled. - BasicBlock *LastBB = State->CFG.PrevBB; - // Connect LastBB to VectorLatchBB to facilitate their merge. - assert((EnableVPlanNativePath || - isa(LastBB->getTerminator())) && - "Expected InnerLoop VPlan CFG to terminate with unreachable"); - assert((!EnableVPlanNativePath || isa(LastBB->getTerminator())) && - "Expected VPlan CFG to terminate with branch in NativePath"); - LastBB->getTerminator()->eraseFromParent(); - BranchInst::Create(VectorLatchBB, LastBB); - - // Merge LastBB with Latch. - bool Merged = MergeBlockIntoPredecessor(VectorLatchBB, nullptr, State->LI); - (void)Merged; - assert(Merged && "Could not merge last basic block with latch."); - VectorLatchBB = LastBB; - + BasicBlock *VectorLatchBB = State->CFG.PrevBB; // Fix the latch value of reduction and first-order recurrences phis in the // vector loop. VPBasicBlock *Header = Entry->getEntryBasicBlock(); @@ -902,13 +927,6 @@ auto *P = cast(State->get(Ind->getVPSingleValue(), 0)); BasicBlock *LatchBB = State->CFG.VPBB2IRBB[BackedgeValue->getParent()]; P->addIncoming(State->get(BackedgeValue, 0), LatchBB); - auto *Next = cast(P->getIncomingValueForBlock(LatchBB)); - auto *TermBr = cast(LatchBB->getTerminator()); - State->Builder.SetInsertPoint(TermBr); - auto *ICmp = cast( - State->Builder.CreateICmpEQ(Next, State->VectorTripCount)); - TermBr->setCondition(ICmp); - Next->moveBefore(ICmp); continue; }