Index: lib/CodeGen/CodeGenPrepare.cpp
===================================================================
--- lib/CodeGen/CodeGenPrepare.cpp
+++ lib/CodeGen/CodeGenPrepare.cpp
@@ -4582,10 +4582,45 @@
   return false;
 }
 
+/// Return the true value of \p SI if \p isTrue is set, otherwise its false
+/// value. While that value is itself a select contained in \p Selects, look
+/// through it (taking the same side) until it is not defined in \p Selects.
+/// Returns nullptr if \p SI itself is not a member of \p Selects.
+static Value *getTrueOrFalseValue(
+    SelectInst *SI, bool isTrue,
+    const SmallPtrSet<const Instruction *, 2> &Selects) {
+  Value *V = nullptr;
+
+  for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI);
+       DefSI = dyn_cast<SelectInst>(V)) {
+    assert(DefSI->getCondition() == SI->getCondition() &&
+           "The condition of DefSI does not match with SI");
+    V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
+  }
+  return V;
+}
 
 /// If we have a SelectInst that will likely profit from branch prediction,
 /// turn it into a branch.
 bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
+  // Find all consecutive select instructions that share the same condition.
+  SmallVector<SelectInst *, 2> ASI;
+  ASI.push_back(SI);
+  for (BasicBlock::iterator It = ++BasicBlock::iterator(SI);
+       It != SI->getParent()->end(); ++It) {
+    SelectInst *I = dyn_cast<SelectInst>(&*It);
+    if (I && SI->getCondition() == I->getCondition()) {
+      ASI.push_back(I);
+    } else {
+      break;
+    }
+  }
+
+  SelectInst *LastSI = ASI.back();
+  // Increment the current iterator to skip all the rest of select instructions
+  // because they will be either "not lowered" or "all lowered" to branch.
+  CurInstIterator = std::next(LastSI->getIterator());
+
   bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
 
   // Can we convert the 'select' to CF ?
@@ -4632,7 +4667,7 @@
 
   // First, we split the block containing the select into 2 blocks.
   BasicBlock *StartBlock = SI->getParent();
-  BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(SI));
+  BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI));
   BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end");
 
   // Delete the unconditional branch that was just created by the split.
@@ -4642,22 +4677,30 @@
   // At least one will become an actual new basic block.
   BasicBlock *TrueBlock = nullptr;
   BasicBlock *FalseBlock = nullptr;
+  BranchInst *TrueBranch = nullptr;
+  BranchInst *FalseBranch = nullptr;
 
   // Sink expensive instructions into the conditional blocks to avoid executing
   // them speculatively.
-  if (sinkSelectOperand(TTI, SI->getTrueValue())) {
-    TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink",
-                                   EndBlock->getParent(), EndBlock);
-    auto *TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
-    auto *TrueInst = cast<Instruction>(SI->getTrueValue());
-    TrueInst->moveBefore(TrueBranch);
-  }
-  if (sinkSelectOperand(TTI, SI->getFalseValue())) {
-    FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink",
-                                    EndBlock->getParent(), EndBlock);
-    auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
-    auto *FalseInst = cast<Instruction>(SI->getFalseValue());
-    FalseInst->moveBefore(FalseBranch);
+  for (SelectInst *SI : ASI) {
+    if (sinkSelectOperand(TTI, SI->getTrueValue())) {
+      if (TrueBlock == nullptr) {
+        TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink",
+                                       EndBlock->getParent(), EndBlock);
+        TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
+      }
+      auto *TrueInst = cast<Instruction>(SI->getTrueValue());
+      TrueInst->moveBefore(TrueBranch);
+    }
+    if (sinkSelectOperand(TTI, SI->getFalseValue())) {
+      if (FalseBlock == nullptr) {
+        FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink",
+                                        EndBlock->getParent(), EndBlock);
+        FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
+      }
+      auto *FalseInst = cast<Instruction>(SI->getFalseValue());
+      FalseInst->moveBefore(FalseBranch);
+    }
   }
 
   // If there was nothing to sink, then arbitrarily choose the 'false' side
@@ -4686,18 +4729,27 @@
     BranchInst::Create(TrueBlock, FalseBlock, SI->getCondition(), SI);
   }
 
-  // The select itself is replaced with a PHI Node.
-  PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
-  PN->takeName(SI);
-  PN->addIncoming(SI->getTrueValue(), TrueBlock);
-  PN->addIncoming(SI->getFalseValue(), FalseBlock);
-
-  SI->replaceAllUsesWith(PN);
-  SI->eraseFromParent();
+  SmallPtrSet<const Instruction *, 2> INS;
+  INS.insert(ASI.begin(), ASI.end());
+  // Use reverse iterator because later select may use the value of the
+  // earlier select, and we need to propagate value through earlier select
+  // to get the PHI operand.
+  for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
+    SelectInst *SI = *It;
+    // The select itself is replaced with a PHI Node.
+    PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
+    PN->takeName(SI);
+    PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock);
+    PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock);
+
+    SI->replaceAllUsesWith(PN);
+    SI->eraseFromParent();
+    INS.erase(SI);
+    ++NumSelectsExpanded;
+  }
 
   // Instruct OptimizeBlock to skip to the next block.
   CurInstIterator = StartBlock->end();
-  ++NumSelectsExpanded;
   return true;
 }
 
Index: test/CodeGen/X86/pseudo_cmov_lower2.ll
===================================================================
--- test/CodeGen/X86/pseudo_cmov_lower2.ll
+++ test/CodeGen/X86/pseudo_cmov_lower2.ll
@@ -98,3 +98,47 @@
   %d5 = fdiv double %d4, %d3
   ret double %d5
 }
+
+; This test checks that only a single jb gets generated in the final code
+; for lowering the CMOV pseudos that get created for this IR. The tricky part
+; of this test is that it tests the special code in CodeGenPrepare.
+;
+; CHECK-LABEL: foo5:
+; CHECK: jb
+; CHECK-NOT: jb
+define double @foo5(float %p1, double %p2, double %p3) nounwind {
+entry:
+  %c1 = fcmp oge float %p1, 0.000000e+00
+  %d0 = fadd double %p2, 1.25e0
+  %d1 = fadd double %p3, 1.25e0
+  %d2 = select i1 %c1, double %d0, double %d1, !prof !0
+  %d3 = select i1 %c1, double %d2, double %p2, !prof !0
+  %d4 = select i1 %c1, double %d3, double %p3, !prof !0
+  %d5 = fsub double %d2, %d3
+  %d6 = fadd double %d5, %d4
+  ret double %d6
+}
+
+; We should expand select instructions into 3 conditional branches as their
+; conditions are different.
+;
+; CHECK-LABEL: foo6:
+; CHECK: jb
+; CHECK: jae
+; CHECK: jae
+define double @foo6(float %p1, double %p2, double %p3) nounwind {
+entry:
+  %c1 = fcmp oge float %p1, 0.000000e+00
+  %c2 = fcmp oge float %p1, 1.000000e+00
+  %c3 = fcmp oge float %p1, 2.000000e+00
+  %d0 = fadd double %p2, 1.25e0
+  %d1 = fadd double %p3, 1.25e0
+  %d2 = select i1 %c1, double %d0, double %d1, !prof !0
+  %d3 = select i1 %c2, double %d2, double %p2, !prof !0
+  %d4 = select i1 %c3, double %d3, double %p3, !prof !0
+  %d5 = fsub double %d2, %d3
+  %d6 = fadd double %d5, %d4
+  ret double %d6
+}
+
+!0 = !{!"branch_weights", i32 1, i32 2000}