Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -2009,9 +2009,9 @@ */ BasicBlock *OldBasicBlock = OrigLoop->getHeader(); - BasicBlock *BypassBlock = OrigLoop->getLoopPreheader(); + BasicBlock *VectorPreHeader = OrigLoop->getLoopPreheader(); BasicBlock *ExitBlock = OrigLoop->getExitBlock(); - assert(BypassBlock && "Invalid loop structure"); + assert(VectorPreHeader && "Invalid loop structure"); assert(ExitBlock && "Must have an exit block"); // Some loops have a single integer induction variable, while other loops @@ -2045,50 +2045,21 @@ // Notice that the pre-header does not change, only the loop body. SCEVExpander Exp(*SE, DL, "induction"); - // We need to test whether the backedge-taken count is uint##_max. Adding one - // to it will cause overflow and an incorrect loop trip count in the vector - // body. In case of overflow we want to directly jump to the scalar remainder - // loop. - Value *BackedgeCount = - Exp.expandCodeFor(BackedgeTakeCount, BackedgeTakeCount->getType(), - BypassBlock->getTerminator()); - if (BackedgeCount->getType()->isPointerTy()) - BackedgeCount = CastInst::CreatePointerCast(BackedgeCount, IdxTy, - "backedge.ptrcnt.to.int", - BypassBlock->getTerminator()); - Instruction *CheckBCOverflow = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, BackedgeCount, - Constant::getAllOnesValue(BackedgeCount->getType()), - "backedge.overflow", BypassBlock->getTerminator()); - // The loop index does not have to start at Zero. Find the original start // value from the induction PHI node. If we don't have an induction variable // then we know that it starts at zero. - Builder.SetInsertPoint(BypassBlock->getTerminator()); - Value *StartIdx = ExtendedIdx = OldInduction ? - Builder.CreateZExt(OldInduction->getIncomingValueForBlock(BypassBlock), - IdxTy): - ConstantInt::get(IdxTy, 0); - - // We need an instruction to anchor the overflow check on. StartIdx needs to - // be defined before the overflow check branch. Because the scalar preheader - // is going to merge the start index and so the overflow branch block needs to - // contain a definition of the start index. - Instruction *OverflowCheckAnchor = BinaryOperator::CreateAdd( - StartIdx, ConstantInt::get(IdxTy, 0), "overflow.check.anchor", - BypassBlock->getTerminator()); - - // Count holds the overall loop count (N). - Value *Count = Exp.expandCodeFor(ExitCount, ExitCount->getType(), - BypassBlock->getTerminator()); + Builder.SetInsertPoint(VectorPreHeader->getTerminator()); + Value *StartIdx = ExtendedIdx = + OldInduction + ? Builder.CreateZExt( + OldInduction->getIncomingValueForBlock(VectorPreHeader), IdxTy) + : ConstantInt::get(IdxTy, 0); - LoopBypassBlocks.push_back(BypassBlock); + LoopBypassBlocks.push_back(VectorPreHeader); // Split the single block loop into the two loop structure described above. - BasicBlock *VectorPH = - BypassBlock->splitBasicBlock(BypassBlock->getTerminator(), "vector.ph"); - BasicBlock *VecBody = - VectorPH->splitBasicBlock(VectorPH->getTerminator(), "vector.body"); + BasicBlock *VecBody = VectorPreHeader->splitBasicBlock( + VectorPreHeader->getTerminator(), "vector.body"); BasicBlock *MiddleBlock = VecBody->splitBasicBlock(VecBody->getTerminator(), "middle.block"); BasicBlock *ScalarPH = @@ -2103,29 +2074,61 @@ if (ParentLoop) { ParentLoop->addChildLoop(Lp); ParentLoop->addBasicBlockToLoop(ScalarPH, *LI); - ParentLoop->addBasicBlockToLoop(VectorPH, *LI); ParentLoop->addBasicBlockToLoop(MiddleBlock, *LI); } else { LI->addTopLevelLoop(Lp); } Lp->addBasicBlockToLoop(VecBody, *LI); - // Use this IR builder to create the loop instructions (Phi, Br, Cmp) - // inside the loop. - Builder.SetInsertPoint(VecBody->getFirstNonPHI()); - - // Generate the induction variable. - setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); - Induction = Builder.CreatePHI(IdxTy, 2, "index"); // The loop step is equal to the vectorization factor (num of SIMD elements) // times the unroll factor (num of SIMD instructions). Constant *Step = ConstantInt::get(IdxTy, VF * UF); // This is the IR builder that we use to add all of the logic for bypassing // the new vector loop. - IRBuilder<> BypassBuilder(BypassBlock->getTerminator()); + IRBuilder<> BypassBuilder(VectorPreHeader->getTerminator()); + + // Generate code to check that the loops trip count that we computed by adding + // one to the backedge-taken count will not overflow. + + // We need to test whether the backedge-taken count is uint##_max. Adding one + // to it will cause overflow and an incorrect loop trip count in the vector + // body. In case of overflow we want to directly jump to the scalar remainder + // loop. + Value *BackedgeCount = + Exp.expandCodeFor(BackedgeTakeCount, BackedgeTakeCount->getType(), + VectorPreHeader->getTerminator()); + if (BackedgeCount->getType()->isPointerTy()) + BackedgeCount = BypassBuilder.CreatePointerCast(BackedgeCount, IdxTy, + "backedge.ptrcnt.to.int"); + Value *CheckBCOverflow = BypassBuilder.CreateICmpEQ( + BackedgeCount, Constant::getAllOnesValue(BackedgeCount->getType()), + "backedge.overflow"); + + BasicBlock *CheckBlock = VectorPreHeader->splitBasicBlock( + VectorPreHeader->getTerminator(), "overflow.checked"); + if (ParentLoop) + ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); + + BypassBuilder.SetInsertPoint(VectorPreHeader->getTerminator()); + BypassBuilder.CreateCondBr(CheckBCOverflow, ScalarPH, CheckBlock); + VectorPreHeader->getTerminator()->eraseFromParent(); + BypassBuilder.SetInsertPoint(CheckBlock->getTerminator()); + VectorPreHeader = CheckBlock; + setDebugLocFromInst(BypassBuilder, getDebugLocFromInstOrOperands(OldInduction)); + // Use this IR builder to create the loop instructions (Phi, Br, Cmp) + // inside the loop. + Builder.SetInsertPoint(VecBody->getFirstNonPHI()); + + // Generate the induction variable. + setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); + Induction = Builder.CreatePHI(IdxTy, 2, "index"); + + // Count holds the overall loop count (N). + Value *Count = Exp.expandCodeFor(ExitCount, ExitCount->getType(), + VectorPreHeader->getTerminator()); // We may need to extend the index in case there is a type mismatch. // We know that the count starts at zero and does not overflow. @@ -2153,23 +2156,16 @@ Value *Cmp = BypassBuilder.CreateICmpEQ(IdxEndRoundDown, StartIdx, "cmp.zero"); - BasicBlock *LastBypassBlock = BypassBlock; - - // Generate code to check that the loops trip count that we computed by adding - // one to the backedge-taken count will not overflow. - { - auto PastOverflowCheck = - std::next(BasicBlock::iterator(OverflowCheckAnchor)); - BasicBlock *CheckBlock = - LastBypassBlock->splitBasicBlock(PastOverflowCheck, "overflow.checked"); - if (ParentLoop) - ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); - LoopBypassBlocks.push_back(CheckBlock); - Instruction *OldTerm = LastBypassBlock->getTerminator(); - BranchInst::Create(ScalarPH, CheckBlock, CheckBCOverflow, OldTerm); - OldTerm->eraseFromParent(); - LastBypassBlock = CheckBlock; - } + CheckBlock = VectorPreHeader->splitBasicBlock( + VectorPreHeader->getTerminator(), "vector.ph"); + if (ParentLoop) + ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); + LoopBypassBlocks.push_back(VectorPreHeader); + BypassBuilder.SetInsertPoint(VectorPreHeader->getTerminator()); + BypassBuilder.CreateCondBr(Cmp, MiddleBlock, CheckBlock); + VectorPreHeader->getTerminator()->eraseFromParent(); + BypassBuilder.SetInsertPoint(CheckBlock->getTerminator()); + VectorPreHeader = CheckBlock; // Generate the code to check that the strides we assumed to be one are really // one. We want the new basic block to start at the first instruction in a @@ -2177,24 +2173,24 @@ Instruction *StrideCheck; Instruction *FirstCheckInst; std::tie(FirstCheckInst, StrideCheck) = - addStrideCheck(LastBypassBlock->getTerminator()); + addStrideCheck(VectorPreHeader->getTerminator()); if (StrideCheck) { AddedSafetyChecks = true; // Create a new block containing the stride check. - BasicBlock *CheckBlock = - LastBypassBlock->splitBasicBlock(FirstCheckInst, "vector.stridecheck"); + VectorPreHeader->setName("vector.stridecheck"); + BasicBlock *CheckBlock = VectorPreHeader->splitBasicBlock( + VectorPreHeader->getTerminator(), "vector.ph"); if (ParentLoop) ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); - LoopBypassBlocks.push_back(CheckBlock); + LoopBypassBlocks.push_back(VectorPreHeader); // Replace the branch into the memory check block with a conditional branch // for the "few elements case". - Instruction *OldTerm = LastBypassBlock->getTerminator(); - BranchInst::Create(MiddleBlock, CheckBlock, Cmp, OldTerm); - OldTerm->eraseFromParent(); - - Cmp = StrideCheck; - LastBypassBlock = CheckBlock; + BypassBuilder.SetInsertPoint(VectorPreHeader->getTerminator()); + BypassBuilder.CreateCondBr(StrideCheck, MiddleBlock, CheckBlock); + VectorPreHeader->getTerminator()->eraseFromParent(); + BypassBuilder.SetInsertPoint(CheckBlock->getTerminator()); + VectorPreHeader = CheckBlock; } // Generate the code that checks in runtime if arrays overlap. We put the @@ -2202,30 +2198,24 @@ // faster. Instruction *MemRuntimeCheck; std::tie(FirstCheckInst, MemRuntimeCheck) = - Legal->getLAI()->addRuntimeCheck(LastBypassBlock->getTerminator()); + Legal->getLAI()->addRuntimeCheck(VectorPreHeader->getTerminator()); if (MemRuntimeCheck) { AddedSafetyChecks = true; - // Create a new block containing the memory check. - BasicBlock *CheckBlock = - LastBypassBlock->splitBasicBlock(FirstCheckInst, "vector.memcheck"); + VectorPreHeader->setName("vector.memcheck"); + BasicBlock *CheckBlock = VectorPreHeader->splitBasicBlock( + VectorPreHeader->getTerminator(), "vector.ph"); + BypassBuilder.SetInsertPoint(CheckBlock->getTerminator()); if (ParentLoop) ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); - LoopBypassBlocks.push_back(CheckBlock); - - // Replace the branch into the memory check block with a conditional branch - // for the "few elements case". - Instruction *OldTerm = LastBypassBlock->getTerminator(); - BranchInst::Create(MiddleBlock, CheckBlock, Cmp, OldTerm); - OldTerm->eraseFromParent(); + LoopBypassBlocks.push_back(VectorPreHeader); - Cmp = MemRuntimeCheck; - LastBypassBlock = CheckBlock; + BypassBuilder.SetInsertPoint(VectorPreHeader->getTerminator()); + BypassBuilder.CreateCondBr(MemRuntimeCheck, MiddleBlock, CheckBlock); + VectorPreHeader->getTerminator()->eraseFromParent(); + BypassBuilder.SetInsertPoint(CheckBlock->getTerminator()); + VectorPreHeader = CheckBlock; } - LastBypassBlock->getTerminator()->eraseFromParent(); - BranchInst::Create(MiddleBlock, VectorPH, Cmp, - LastBypassBlock); - // We are going to resume the execution of the scalar loop. // Go over all of the induction variables that we found and fix the // PHIs that are left in the scalar version of the loop. @@ -2365,7 +2355,7 @@ // Create i+1 and fill the PHINode. Value *NextIdx = Builder.CreateAdd(Induction, Step, "index.next"); - Induction->addIncoming(StartIdx, VectorPH); + Induction->addIncoming(StartIdx, VectorPreHeader); Induction->addIncoming(NextIdx, VecBody); // Create the compare. Value *ICmp = Builder.CreateICmpEQ(NextIdx, IdxEndRoundDown); @@ -2378,7 +2368,7 @@ Builder.SetInsertPoint(VecBody->getFirstInsertionPt()); // Save the state. - LoopVectorPreHeader = VectorPH; + LoopVectorPreHeader = VectorPreHeader; LoopScalarPreHeader = ScalarPH; LoopMiddleBlock = MiddleBlock; LoopExitBlock = ExitBlock; Index: test/Transforms/LoopVectorize/induction.ll =================================================================== --- test/Transforms/LoopVectorize/induction.ll +++ test/Transforms/LoopVectorize/induction.ll @@ -113,8 +113,7 @@ ; condition and branch directly to the scalar loop. ; CHECK-LABEL: max_i32_backedgetaken -; CHECK: %backedge.overflow = icmp eq i32 -1, -1 -; CHECK: br i1 %backedge.overflow, label %scalar.ph, label %overflow.checked +; CHECK: br i1 true, label %scalar.ph, label %overflow.checked ; CHECK: scalar.ph: ; CHECK: %bc.resume.val = phi i32 [ %resume.val, %middle.block ], [ 0, %0 ]