Index: lib/Transforms/Scalar/JumpThreading.cpp
===================================================================
--- lib/Transforms/Scalar/JumpThreading.cpp
+++ lib/Transforms/Scalar/JumpThreading.cpp
@@ -163,6 +163,8 @@

     bool SimplifyPartiallyRedundantLoad(LoadInst *LI);
     bool TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB);
+    bool TryToUnfoldSelectInCurrBB(BasicBlock *BB);
+    SelectInst *ShouldUnfoldSelect(BasicBlock *BB);

   private:
     BasicBlock *SplitBlockPreds(BasicBlock *BB, ArrayRef<BasicBlock *> Preds,
@@ -730,6 +732,9 @@
     }
   }

+  if (TryToUnfoldSelectInCurrBB(BB))
+    return true;
+
   // What kind of constant we're looking for.
   ConstantPreference Preference = WantInteger;

@@ -1884,3 +1889,85 @@
   }
   return false;
 }
+
+/// TryToUnfoldSelectInCurrBB - Look for a PHI/Select pair in the same BB of
+/// the form
+///   bb:
+///     %p = phi [false, %bb1], [true, %bb2], [false, %bb3], [true, %bb4], ...
+///     %s = select %p, trueval, falseval
+///
+/// and expand the select into a conditional branch on %p plus a PHI that
+/// merges trueval and falseval, which later enables threading over bb.
+/// Returns true if the block was changed.
+bool JumpThreading::TryToUnfoldSelectInCurrBB(BasicBlock *BB) {
+  SelectInst *SI = ShouldUnfoldSelect(BB);
+  if (!SI)
+    return false;
+
+  // Everything from the select onwards moves into BottomBB.
+  BasicBlock *BottomBB = SplitBlock(BB, SI);
+  BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "select.unfold",
+                                         BB->getParent(), BottomBB);
+
+  // Branch on the select condition: the true edge goes straight to BottomBB,
+  // the false edge goes through NewBB.
+  BB->getTerminator()->eraseFromParent();
+  BranchInst::Create(BottomBB, NewBB, SI->getCondition(), BB);
+  BranchInst::Create(BottomBB, NewBB);
+
+  // The select becomes a PHI merging the two edges into BottomBB.
+  PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI);
+  NewPN->addIncoming(SI->getTrueValue(), BB);
+  NewPN->addIncoming(SI->getFalseValue(), NewBB);
+  SI->replaceAllUsesWith(NewPN);
+  SI->eraseFromParent();
+  return true;
+}
+
+/// ShouldUnfoldSelect - Find a select in BB that can be unfolded by
+/// TryToUnfoldSelectInCurrBB.  Returns nullptr if there is no candidate.
+SelectInst *JumpThreading::ShouldUnfoldSelect(BasicBlock *BB) {
+  // Only search blocks that can be threaded.
+  if (LoopHeaders.count(BB))
+    return nullptr;
+
+  // Look for a PHI/Select pair in the same basic block, where the PHI is the
+  // condition of the select and has at least one constant incoming value.
+  for (BasicBlock::iterator BI = BB->begin();
+       PHINode *PN = dyn_cast<PHINode>(BI); ++BI) {
+    // The select condition becomes a branch condition, which must be a scalar
+    // i1; a vector condition cannot be branched on.
+    if (!PN->getType()->isIntegerTy(1))
+      continue;
+
+    unsigned NumPHIValues = PN->getNumIncomingValues();
+    if (NumPHIValues == 0 || !PN->hasOneUse())
+      continue;
+
+    SelectInst *SI = dyn_cast<SelectInst>(PN->use_begin()->getUser());
+    if (!SI || SI->getParent() != BB || SI->getCondition() != PN)
+      continue;
+
+    bool HasConst = false;
+    for (unsigned i = 0; i != NumPHIValues; ++i) {
+      // Give up on a block that is its own predecessor.
+      if (PN->getIncomingBlock(i) == BB)
+        return nullptr;
+      if (isa<ConstantInt>(PN->getIncomingValue(i)))
+        HasConst = true;
+    }
+
+    // No constant incoming value here; a later PHI may still qualify.
+    if (!HasConst)
+      continue;
+
+    // The cost depends only on BB, so exceeding it rules out every PHI.
+    unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold);
+    if (JumpThreadCost > BBDupThreshold)
+      return nullptr;
+
+    return SI;
+  }
+
+  return nullptr;
+}
Index: test/Transforms/JumpThreading/select.ll
===================================================================
--- test/Transforms/JumpThreading/select.ll
+++ test/Transforms/JumpThreading/select.ll
@@ -220,3 +220,40 @@
 ; CHECK: br i1 %cmp6, label %if.then, label %if.end
 ; CHECK: br label %if.end
 }
+
+
+define i32 @unfold3(i32 %u, i32 %v, i32 %w, i32 %x, i32 %y, i32 %z, i32 %j) nounwind {
+entry:
+  %add3 = add nsw i32 %j, 2
+  %cmp.i = icmp slt i32 %u, %v
+  br i1 %cmp.i, label %.exit, label %cond.false.i
+
+cond.false.i:                                     ; preds = %entry
+  %cmp4.i = icmp sgt i32 %u, %v
+  br i1 %cmp4.i, label %.exit, label %cond.false.6.i
+
+cond.false.6.i:                                   ; preds = %cond.false.i
+  %cmp8.i = icmp slt i32 %w, %x
+  br i1 %cmp8.i, label %.exit, label %cond.false.10.i
+
+cond.false.10.i:                                  ; preds = %cond.false.6.i
+  %cmp13.i = icmp sgt i32 %w, %x
+  br i1 %cmp13.i, label %.exit, label %cond.false.15.i
+
+cond.false.15.i:                                  ; preds = %cond.false.10.i
+  %phitmp = icmp sge i32 %y, %z
+  br label %.exit
+
+.exit:                                            ; preds = %entry, %cond.false.i, %cond.false.6.i, %cond.false.10.i, %cond.false.15.i
+  %cond23.i = phi i1 [ false, %entry ], [ true, %cond.false.i ], [ false, %cond.false.6.i ], [ %phitmp, %cond.false.15.i ], [ true, %cond.false.10.i ]
+  %j.add3 = select i1 %cond23.i, i32 %j, i32 %add3
+  ret i32 %j.add3
+
+; CHECK-LABEL: @unfold3
+; CHECK: br i1 %cmp.i, label %select.unfold, label %cond.false.i
+; CHECK: br i1 %cmp4.i, label %.exit.split, label %cond.false.6.i
+; CHECK: br i1 %cmp8.i, label %select.unfold, label %cond.false.10.i
+; CHECK: br i1 %cmp13.i, label %.exit.split, label %.exit
+; CHECK: br i1 %phitmp, label %.exit.split, label %select.unfold
+; CHECK: br label %.exit.split
+}