Index: lib/CodeGen/CodeGenPrepare.cpp
===================================================================
--- lib/CodeGen/CodeGenPrepare.cpp
+++ lib/CodeGen/CodeGenPrepare.cpp
@@ -4541,6 +4541,32 @@
          TTI->getUserCost(I) >= TargetTransformInfo::TCC_Expensive;
 }
 
+/// Collect SI and the select instructions immediately following it that share
+/// its condition into ASI, in program order. Return true if no collected
+/// select uses an earlier one in ASI as its true or false value.
+static bool findAdjacentSelectInstructions(SelectInst *SI,
+                                           SmallVectorImpl<SelectInst *> &ASI) {
+  for (BasicBlock::iterator It = BasicBlock::iterator(SI);
+       It != SI->getParent()->end(); ++It) {
+    SelectInst *I = dyn_cast<SelectInst>(&*It);
+    if (I && SI->getCondition() == I->getCondition()) {
+      ASI.push_back(I);
+    } else {
+      break;
+    }
+  }
+  SmallPtrSet<const Instruction *, 2> INS;
+  for (const SelectInst *Sel : ASI) {
+    const Instruction *I1 = dyn_cast<Instruction>(Sel->getTrueValue());
+    const Instruction *I2 = dyn_cast<Instruction>(Sel->getFalseValue());
+    // INS only holds the selects that precede Sel in ASI.
+    if (INS.count(I1) || INS.count(I2))
+      return false;
+    INS.insert(Sel);
+  }
+  return true;
+}
+
 /// Returns true if a SelectInst should be turned into an explicit branch.
 static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI,
                                                 const TargetLowering *TLI,
@@ -4586,6 +4612,16 @@
 /// If we have a SelectInst that will likely profit from branch prediction,
 /// turn it into a branch.
 bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
+  SmallVector<SelectInst *, 2> ASI;
+  bool ShouldOptimize = findAdjacentSelectInstructions(SI, ASI);
+
+  // Advance past the whole group of selects: they are either all lowered to
+  // a single branch below or all left in place, so none is visited again.
+  CurInstIterator = std::next(ASI.back()->getIterator());
+  // Bail out if one select in the group feeds another as a true/false value.
+  if (!ShouldOptimize)
+    return false;
+
   bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
 
   // Can we convert the 'select' to CF ?
@@ -4632,7 +4668,7 @@
 
   // First, we split the block containing the select into 2 blocks.
BasicBlock *StartBlock = SI->getParent(); - BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(SI)); + BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(ASI.back())); BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end"); // Delete the unconditional branch that was just created by the split. @@ -4642,22 +4678,30 @@ // At least one will become an actual new basic block. BasicBlock *TrueBlock = nullptr; BasicBlock *FalseBlock = nullptr; + BranchInst *TrueBranch = nullptr; + BranchInst *FalseBranch = nullptr; // Sink expensive instructions into the conditional blocks to avoid executing // them speculatively. - if (sinkSelectOperand(TTI, SI->getTrueValue())) { - TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink", - EndBlock->getParent(), EndBlock); - auto *TrueBranch = BranchInst::Create(EndBlock, TrueBlock); - auto *TrueInst = cast(SI->getTrueValue()); - TrueInst->moveBefore(TrueBranch); - } - if (sinkSelectOperand(TTI, SI->getFalseValue())) { - FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink", - EndBlock->getParent(), EndBlock); - auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock); - auto *FalseInst = cast(SI->getFalseValue()); - FalseInst->moveBefore(FalseBranch); + for (SelectInst *SI : ASI) { + if (sinkSelectOperand(TTI, SI->getTrueValue())) { + if (TrueBlock == nullptr) { + TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink", + EndBlock->getParent(), EndBlock); + TrueBranch = BranchInst::Create(EndBlock, TrueBlock); + } + auto *TrueInst = cast(SI->getTrueValue()); + TrueInst->moveBefore(TrueBranch); + } + if (sinkSelectOperand(TTI, SI->getFalseValue())) { + if (FalseBlock == nullptr) { + FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink", + EndBlock->getParent(), EndBlock); + FalseBranch = BranchInst::Create(EndBlock, FalseBlock); + } + auto *FalseInst = cast(SI->getFalseValue()); + FalseInst->moveBefore(FalseBranch); + } } // If there was nothing to 
sink, then arbitrarily choose the 'false' side @@ -4686,18 +4730,21 @@ BranchInst::Create(TrueBlock, FalseBlock, SI->getCondition(), SI); } - // The select itself is replaced with a PHI Node. - PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front()); - PN->takeName(SI); - PN->addIncoming(SI->getTrueValue(), TrueBlock); - PN->addIncoming(SI->getFalseValue(), FalseBlock); - - SI->replaceAllUsesWith(PN); - SI->eraseFromParent(); + for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) { + SelectInst *SI = *It; + // The select itself is replaced with a PHI Node. + PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front()); + PN->takeName(SI); + PN->addIncoming(SI->getTrueValue(), TrueBlock); + PN->addIncoming(SI->getFalseValue(), FalseBlock); + + SI->replaceAllUsesWith(PN); + SI->eraseFromParent(); + ++NumSelectsExpanded; + } // Instruct OptimizeBlock to skip to the next block. CurInstIterator = StartBlock->end(); - ++NumSelectsExpanded; return true; } Index: test/CodeGen/X86/pseudo_cmov_lower2.ll =================================================================== --- test/CodeGen/X86/pseudo_cmov_lower2.ll +++ test/CodeGen/X86/pseudo_cmov_lower2.ll @@ -98,3 +98,47 @@ %d5 = fdiv double %d4, %d3 ret double %d5 } + +; This test checks that only a single jae gets generated in the final code +; for lowering the CMOV pseudos that get created for this IR. The tricky part +; of this test is that it tests the special code in CodeGenPrepare. 
+;
+; CHECK-LABEL: foo5:
+; CHECK: jae
+; CHECK-NOT: jae
+define double @foo5(float %p1, double %p2, double %p3) nounwind {
+entry:
+  %c1 = fcmp oge float %p1, 0.000000e+00
+  %d0 = fadd double %p2, 1.25e0
+  %d1 = fadd double %p3, 1.25e0
+  %d2 = select i1 %c1, double %d0, double %d1, !prof !0
+  %d3 = select i1 %c1, double %d2, double %p2, !prof !0
+  %d4 = select i1 %c1, double %d3, double %p3, !prof !0
+  %d5 = fsub double %d2, %d3
+  %d6 = fadd double %d5, %d4
+  ret double %d6
+}
+
+; The three selects below use different conditions, so CodeGenPrepare
+; should expand each of them into its own conditional branch.
+;
+; CHECK-LABEL: foo6:
+; CHECK: jb
+; CHECK: jae
+; CHECK: jae
+define double @foo6(float %p1, double %p2, double %p3) nounwind {
+entry:
+  %c1 = fcmp oge float %p1, 0.000000e+00
+  %c2 = fcmp oge float %p1, 1.000000e+00
+  %c3 = fcmp oge float %p1, 2.000000e+00
+  %d0 = fadd double %p2, 1.25e0
+  %d1 = fadd double %p3, 1.25e0
+  %d2 = select i1 %c1, double %d0, double %d1, !prof !0
+  %d3 = select i1 %c2, double %d2, double %p2, !prof !0
+  %d4 = select i1 %c3, double %d3, double %p3, !prof !0
+  %d5 = fsub double %d2, %d3
+  %d6 = fadd double %d5, %d4
+  ret double %d6
+}
+
+!0 = !{!"branch_weights", i32 1, i32 2000}