Index: lib/CodeGen/CodeGenPrepare.cpp =================================================================== --- lib/CodeGen/CodeGenPrepare.cpp +++ lib/CodeGen/CodeGenPrepare.cpp @@ -158,7 +158,8 @@ const TargetRegisterInfo *TRI; const TargetTransformInfo *TTI; const TargetLibraryInfo *TLInfo; - const LoopInfo *LI; + LoopInfo *LI; + std::unique_ptr BFI; std::unique_ptr BPI; @@ -430,6 +431,9 @@ bool isEntry = SinglePred == &SinglePred->getParent()->getEntryBlock(); MergeBasicBlockIntoOnlyPred(BB, nullptr); + if (LI->getLoopFor(SinglePred)) + LI->removeBlock(SinglePred); + if (isEntry && BB != &BB->getParent()->getEntryBlock()) BB->moveBefore(&BB->getParent()->getEntryBlock()); @@ -620,6 +624,10 @@ IndPHI->eraseFromParent(); } + if (Loop * L = LI->getLoopFor(Target)) { + L->addBasicBlockToLoop(BodyBlock, *LI); + L->addBasicBlockToLoop(DirectSucc, *LI); + } Changed = true; } @@ -832,6 +840,9 @@ bool isEntry = SinglePred == &SinglePred->getParent()->getEntryBlock(); MergeBasicBlockIntoOnlyPred(DestBB, nullptr); + if (LI->getLoopFor(SinglePred)) + LI->removeBlock(SinglePred); + if (isEntry && BB != &BB->getParent()->getEntryBlock()) BB->moveBefore(&BB->getParent()->getEntryBlock()); @@ -872,6 +883,9 @@ // The PHIs are now updated, change everything that refers to BB to use // DestBB and remove BB. 
BB->replaceAllUsesWith(DestBB); + if (LI->getLoopFor(BB)) + LI->removeBlock(BB); + BB->eraseFromParent(); ++NumBlocksElim; @@ -1584,7 +1598,7 @@ // %13 = icmp eq i1 %12, true // br i1 %13, label %cond.load4, label %else5 // -static void scalarizeMaskedLoad(CallInst *CI) { +static void scalarizeMaskedLoad(CallInst *CI, LoopInfo &LI) { Value *Ptr = CI->getArgOperand(0); Value *Alignment = CI->getArgOperand(1); Value *Mask = CI->getArgOperand(2); @@ -1648,6 +1662,8 @@ PHINode *Phi = nullptr; Value *PrevPhi = UndefVal; + Loop *L = LI.getLoopFor(CI->getParent()); + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { // Fill the "else" block, created in the previous iteration @@ -1689,6 +1705,12 @@ Builder.SetInsertPoint(InsertPt); Instruction *OldBr = IfBlock->getTerminator(); BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); + + if (L) { + L->addBasicBlockToLoop(CondBlock, LI); + L->addBasicBlockToLoop(NewIfBlock, LI); + } + OldBr->eraseFromParent(); PrevIfBlock = IfBlock; IfBlock = NewIfBlock; @@ -1730,7 +1752,7 @@ // store i32 %8, i32* %9 // br label %else2 // . . . -static void scalarizeMaskedStore(CallInst *CI) { +static void scalarizeMaskedStore(CallInst *CI, LoopInfo &LI) { Value *Src = CI->getArgOperand(0); Value *Ptr = CI->getArgOperand(1); Value *Alignment = CI->getArgOperand(2); @@ -1779,6 +1801,7 @@ return; } + Loop *L = LI.getLoopFor(CI->getParent()); for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { // Fill the "else" block, created in the previous iteration @@ -1813,6 +1836,12 @@ Instruction *OldBr = IfBlock->getTerminator(); BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); OldBr->eraseFromParent(); + + if (L) { + L->addBasicBlockToLoop(CondBlock, LI); + L->addBasicBlockToLoop(NewIfBlock, LI); + } + IfBlock = NewIfBlock; } CI->eraseFromParent(); @@ -1849,7 +1878,7 @@ // . . . 
// % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src // ret <16 x i32> %Result -static void scalarizeMaskedGather(CallInst *CI) { +static void scalarizeMaskedGather(CallInst *CI, LoopInfo &LI) { Value *Ptrs = CI->getArgOperand(0); Value *Alignment = CI->getArgOperand(1); Value *Mask = CI->getArgOperand(2); @@ -1899,6 +1928,7 @@ PHINode *Phi = nullptr; Value *PrevPhi = UndefVal; + Loop *L = LI.getLoopFor(CI->getParent()); for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { // Fill the "else" block, created in the previous iteration @@ -1944,6 +1974,12 @@ Instruction *OldBr = IfBlock->getTerminator(); BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); OldBr->eraseFromParent(); + + if (L) { + L->addBasicBlockToLoop(CondBlock, LI); + L->addBasicBlockToLoop(NewIfBlock, LI); + } + PrevIfBlock = IfBlock; IfBlock = NewIfBlock; } @@ -1984,7 +2020,7 @@ // store i32 % Elt1, i32* % Ptr1, align 4 // br label %else2 // . . . -static void scalarizeMaskedScatter(CallInst *CI) { +static void scalarizeMaskedScatter(CallInst *CI, LoopInfo &LI) { Value *Src = CI->getArgOperand(0); Value *Ptrs = CI->getArgOperand(1); Value *Alignment = CI->getArgOperand(2); @@ -2021,6 +2057,8 @@ CI->eraseFromParent(); return; } + + Loop *L = LI.getLoopFor(CI->getParent()); for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { // Fill the "else" block, created in the previous iteration // @@ -2057,6 +2095,12 @@ Instruction *OldBr = IfBlock->getTerminator(); BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); OldBr->eraseFromParent(); + + if (L) { + L->addBasicBlockToLoop(CondBlock, LI); + L->addBasicBlockToLoop(NewIfBlock, LI); + } + IfBlock = NewIfBlock; } CI->eraseFromParent(); @@ -2082,7 +2126,8 @@ static bool despeculateCountZeros(IntrinsicInst *CountZeros, const TargetLowering *TLI, const DataLayout *DL, - bool &ModifiedDT) { + bool &ModifiedDT, + LoopInfo &LI) { if (!TLI || !DL) return false; @@ -2112,6 +2157,8 @@ BasicBlock::iterator SplitPt = 
++(BasicBlock::iterator(CountZeros)); BasicBlock *EndBlock = CallBlock->splitBasicBlock(SplitPt, "cond.end"); + Loop *L = LI.getLoopFor(StartBlock); + // Set up a builder to create a compare, conditional branch, and PHI. IRBuilder<> Builder(CountZeros->getContext()); Builder.SetInsertPoint(StartBlock->getTerminator()); @@ -2137,6 +2184,12 @@ // undefined zero argument to 'true'. This will also prevent reprocessing the // intrinsic; we only despeculate when a zero input is defined. CountZeros->setArgOperand(1, Builder.getTrue()); + + if (L) { + L->addBasicBlockToLoop(CallBlock, LI); + L->addBasicBlockToLoop(EndBlock, LI); + } + ModifiedDT = true; return true; } @@ -2245,7 +2298,7 @@ case Intrinsic::masked_load: { // Scalarize unsupported vector masked load if (!TTI->isLegalMaskedLoad(CI->getType())) { - scalarizeMaskedLoad(CI); + scalarizeMaskedLoad(CI, *LI); ModifiedDT = true; return true; } @@ -2253,7 +2306,7 @@ } case Intrinsic::masked_store: { if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) { - scalarizeMaskedStore(CI); + scalarizeMaskedStore(CI, *LI); ModifiedDT = true; return true; } @@ -2261,7 +2314,7 @@ } case Intrinsic::masked_gather: { if (!TTI->isLegalMaskedGather(CI->getType())) { - scalarizeMaskedGather(CI); + scalarizeMaskedGather(CI, *LI); ModifiedDT = true; return true; } @@ -2269,7 +2322,7 @@ } case Intrinsic::masked_scatter: { if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) { - scalarizeMaskedScatter(CI); + scalarizeMaskedScatter(CI, *LI); ModifiedDT = true; return true; } @@ -2296,7 +2349,7 @@ case Intrinsic::cttz: case Intrinsic::ctlz: // If counting zeros is expensive, try to avoid it. - return despeculateCountZeros(II, TLI, DL, ModifiedDT); + return despeculateCountZeros(II, TLI, DL, ModifiedDT, *LI); } if (TLI) { @@ -2460,8 +2513,11 @@ } // If we eliminated all predecessors of the block, delete the block now. 
- if (Changed && !BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) + if (Changed && !BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) { + if (LI->getLoopFor(BB)) + LI->removeBlock(BB); BB->eraseFromParent(); + } return Changed; } @@ -5148,7 +5204,8 @@ /// Returns true if a SelectInst should be turned into an explicit branch. static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI, const TargetLowering *TLI, - SelectInst *SI) { + SelectInst *SI, + const LoopInfo *LI) { // If even a predictable select is cheap, then a branch can't be cheaper. if (!TLI->isPredictableSelectExpensive()) return false; @@ -5177,6 +5234,13 @@ if (!Cmp || !Cmp->hasOneUse()) return false; + // If the select is in a critical path of a loop, we aggressively turn it into + // a branch so that we rely more on the branch predictor. + if (Loop *L = LI->getLoopFor(SI->getParent())) { + if (L->getLoopLatch() == SI->getParent()) + return true; + } + // If either operand of the select is expensive and only needed on one side // of the select, we should form a branch. if (sinkSelectOperand(TTI, SI->getTrueValue()) || @@ -5241,7 +5305,7 @@ SelectKind = TargetLowering::ScalarValSelect; if (TLI->isSelectSupported(SelectKind) && - !isFormingBranchFromSelectProfitable(TTI, TLI, SI)) + !isFormingBranchFromSelectProfitable(TTI, TLI, SI, LI)) return false; ModifiedDT = true; @@ -5273,6 +5337,10 @@ BasicBlock *StartBlock = SI->getParent(); BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI)); BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end"); + Loop *L = LI->getLoopFor(StartBlock); + + if (L) + L->addBasicBlockToLoop(EndBlock, *LI); // Delete the unconditional branch that was just created by the split. 
StartBlock->getTerminator()->eraseFromParent(); @@ -5291,6 +5359,9 @@ if (TrueBlock == nullptr) { TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink", EndBlock->getParent(), EndBlock); + if (L) + L->addBasicBlockToLoop(TrueBlock, *LI); + TrueBranch = BranchInst::Create(EndBlock, TrueBlock); } auto *TrueInst = cast(SI->getTrueValue()); @@ -5300,6 +5371,9 @@ if (FalseBlock == nullptr) { FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink", EndBlock->getParent(), EndBlock); + if (L) + L->addBasicBlockToLoop(FalseBlock, *LI); + FalseBranch = BranchInst::Create(EndBlock, FalseBlock); } auto *FalseInst = cast(SI->getFalseValue()); @@ -5315,6 +5389,9 @@ FalseBlock = BasicBlock::Create(SI->getContext(), "select.false", EndBlock->getParent(), EndBlock); + if (L) + L->addBasicBlockToLoop(FalseBlock, *LI); + BranchInst::Create(EndBlock, FalseBlock); } @@ -6371,6 +6448,9 @@ MadeChange = true; + if (Loop *L = LI->getLoopFor(TmpBB)) + L->addBasicBlockToLoop(TmpBB, *LI); + DEBUG(dbgs() << "After branch condition splitting\n"; BB.dump(); TmpBB->dump()); } Index: test/CodeGen/AArch64/aarch64-aggressive-select-to-branch.ll =================================================================== --- /dev/null +++ test/CodeGen/AArch64/aarch64-aggressive-select-to-branch.ll @@ -0,0 +1,69 @@ +; RUN: opt -codegenprepare -S -mcpu=kryo < %s | FileCheck %s + +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" +target triple = "aarch64-linaro-linux-gnueabi" + +; This tests if a SelectInst in a loop latch is turned into an explicit branch +; by the codegenprepare pass. This test case was an input IR of codegenprepare +; when compiling the C code below in -O3. 
+; +;struct s1 {int idx; int cost;}; +;void foo(struct s1 **G, int n) { +; int j, i; +; while (j < n) { +; if (G[j+1]->cost < G[j]->cost) +; j++; +; if (G[j]->cost > G[i]->cost) +; break; +; } +;} +; + +%struct.s1 = type { i32, i32 } +; +; CHECK-LABEL: @foo +; CHECK-LABEL: while.body: +; CHECK-NOT: select i1 +; CHECK: [[CMP:%.*]] = icmp slt i32 %l2, %l4 +; CHECK: br i1 [[CMP]], label %select.end, label %select.false +; +define void @foo(i32 %n, %struct.s1** nocapture %G) local_unnamed_addr #0 { +entry: + ;%0 = load i32, i32* @n, align 4 + ;%1 = load %struct.s1**, %struct.s1*** @G, align 8 + br label %while.cond + +while.cond: ; preds = %while.body, %entry + %j.0 = phi i32 [ undef, %entry ], [ %add.j.0, %while.body ] + %cmp = icmp slt i32 %j.0, %n + br i1 %cmp, label %while.body, label %while.end + +while.body: ; preds = %while.cond + %add = add nsw i32 %j.0, 1 + %idxprom = sext i32 %add to i64 + %arrayidx = getelementptr inbounds %struct.s1*, %struct.s1** %G, i64 %idxprom + %l = load %struct.s1*, %struct.s1** %arrayidx, align 8 + %cost = getelementptr inbounds %struct.s1, %struct.s1* %l, i64 0, i32 0 + %l2 = load i32, i32* %cost, align 4 + %idxprom1 = sext i32 %j.0 to i64 + %arrayidx2 = getelementptr inbounds %struct.s1*, %struct.s1** %G, i64 %idxprom1 + %l3 = load %struct.s1*, %struct.s1** %arrayidx2, align 8 + %cost3 = getelementptr inbounds %struct.s1, %struct.s1* %l3, i64 0, i32 0 + %l4 = load i32, i32* %cost3, align 4 + %cmp4 = icmp slt i32 %l2, %l4 + %add.j.0 = select i1 %cmp4, i32 %add, i32 %j.0 + %idxprom5 = sext i32 %add.j.0 to i64 + %arrayidx6 = getelementptr inbounds %struct.s1*, %struct.s1** %G, i64 %idxprom5 + %l5 = load %struct.s1*, %struct.s1** %arrayidx6, align 8 + %cost7 = getelementptr inbounds %struct.s1, %struct.s1* %l5, i64 0, i32 0 + %l6 = load i32, i32* %cost7, align 4 + %l7 = load %struct.s1*, %struct.s1** %G, align 8 + %cost10 = getelementptr inbounds %struct.s1, %struct.s1* %l7, i64 0, i32 0 + %l8 = load i32, i32* %cost10, align 4 + %cmp11 
= icmp sgt i32 %l6, %l8 + br i1 %cmp11, label %while.end, label %while.cond + +while.end: ; preds = %while.body, %while.cond + ret void +} +