Index: include/llvm/Transforms/Scalar/JumpThreading.h =================================================================== --- include/llvm/Transforms/Scalar/JumpThreading.h +++ include/llvm/Transforms/Scalar/JumpThreading.h @@ -110,7 +110,7 @@ void FindLoopHeaders(Function &F); bool ProcessBlock(BasicBlock *BB); bool ThreadEdge(BasicBlock *BB, const SmallVectorImpl &PredBBs, - BasicBlock *SuccBB); + BasicBlock *SuccBB, Value *PredVal = nullptr); bool DuplicateCondBranchOnPHIIntoPred( BasicBlock *BB, const SmallVectorImpl &PredBBs); Index: lib/Transforms/Scalar/JumpThreading.cpp =================================================================== --- lib/Transforms/Scalar/JumpThreading.cpp +++ lib/Transforms/Scalar/JumpThreading.cpp @@ -1071,6 +1071,14 @@ if (IB->getNumSuccessors() == 0) return false; Condition = IB->getAddress()->stripPointerCasts(); Preference = WantBlockAddress; + } else if (ReturnInst *RetInst = dyn_cast(Terminator)) { + auto *RV = RetInst->getReturnValue(); + if (!RV) + return false; + CmpInst *Cmp = dyn_cast(RV); + if (!Cmp) + return false; + Condition = Cmp; } else { return false; // Must be an invoke or callbr. } @@ -1090,7 +1098,7 @@ // If the terminator is branching on an undef, we can pick any of the // successors to branch to. Let GetBestDestForJumpOnUndef decide. - if (isa(Condition)) { + if (isa(Condition) && BB->getTerminator()->getNumSuccessors()) { unsigned BestSucc = GetBestDestForJumpOnUndef(BB); std::vector Updates; @@ -1631,17 +1639,21 @@ Constant *Val = PredValue.first; BasicBlock *DestBB; - if (isa(Val)) - DestBB = nullptr; - else if (BranchInst *BI = dyn_cast(BB->getTerminator())) { + auto *TI = BB->getTerminator(); + if (isa(Val) && TI->getNumSuccessors() != 0) + DestBB = TI->getSuccessor(GetBestDestForJumpOnUndef(BB)); + else if (BranchInst *BI = dyn_cast(TI)) { assert(isa(Val) && "Expecting a constant integer"); DestBB = BI->getSuccessor(cast(Val)->isZero()); - } else if (SwitchInst *SI = dyn_cast(BB->getTerminator())) { + } else if (SwitchInst *SI = dyn_cast(TI)) { assert(isa(Val) && "Expecting a constant integer"); DestBB = SI->findCaseValue(cast(Val))->getCaseSuccessor(); + } else if (ReturnInst *RI = dyn_cast(TI)) { + assert(RI->getReturnValue() == Cond && + "Expecting comparison result returned"); + DestBB = nullptr; } else { - assert(isa(BB->getTerminator()) - && "Unexpected terminator"); + assert(isa(TI) && "Unexpected terminator"); assert(isa(Val) && "Expecting a constant blockaddress"); DestBB = cast(Val)->getBasicBlock(); } @@ -1711,6 +1723,7 @@ CondInst->getParent() == BB) ReplaceFoldableUses(CondInst, OnlyVal); } + SimplifyInstructionsInBlock(BB, TLI); return true; } } @@ -1739,26 +1752,36 @@ // Now that we know what the most popular destination is, factor all // predecessors that will jump to it into a single predecessor. SmallVector PredsToFactor; - for (const auto &PredToDest : PredToDestList) - if (PredToDest.second == MostPopularDest) { - BasicBlock *Pred = PredToDest.first; + Value *PredVal = nullptr; + if (MostPopularDest) { + for (const auto &PredToDest : PredToDestList) + if (PredToDest.second == MostPopularDest) { + BasicBlock *Pred = PredToDest.first; - // This predecessor may be a switch or something else that has multiple - // edges to the block. Factor each of these edges by listing them - // according to # occurrences in PredsToFactor. - for (BasicBlock *Succ : successors(Pred)) - if (Succ == BB) - PredsToFactor.push_back(Pred); + // This predecessor may be a switch or something else that has multiple + // edges to the block. Factor each of these edges by listing them + // according to # occurrences in PredsToFactor. + for (BasicBlock *Succ : successors(Pred)) + if (Succ == BB) + PredsToFactor.push_back(Pred); + } + } else { + LLVM_DEBUG(auto *TI = BB->getTerminator(); + assert(isa(TI));); + auto *Pred = PredToDestList.begin()->first; + PredsToFactor.push_back(Pred); + // We need to find the pred val associated with the selected pred: + for (const auto &PredValue : PredValues) { + if (PredValue.second == Pred) { + PredVal = PredValue.first; + break; + } } + assert(PredVal && "non null pred val expected!"); + } - // If the threadable edges are branching on an undefined value, we get to pick - // the destination that these predecessors should get to. - if (!MostPopularDest) - MostPopularDest = BB->getTerminator()-> - getSuccessor(GetBestDestForJumpOnUndef(BB)); - // Ok, try to thread it! - return ThreadEdge(BB, PredsToFactor, MostPopularDest); + return ThreadEdge(BB, PredsToFactor, MostPopularDest, PredVal); } /// ProcessBranchOnPHI - We have an otherwise unthreadable conditional branch on @@ -1925,7 +1948,7 @@ /// across BB. Transform the IR to reflect this change. bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, const SmallVectorImpl &PredBBs, - BasicBlock *SuccBB) { + BasicBlock *SuccBB, Value *PredVal) { // If threading to the same block as we come from, we would infinite loop. if (SuccBB == BB) { LLVM_DEBUG(dbgs() << " Not threading across BB '" << BB->getName() @@ -1966,10 +1989,14 @@ } // And finally, do it! - LLVM_DEBUG(dbgs() << " Threading edge from '" << PredBB->getName() - << "' to '" << SuccBB->getName() - << "' with cost: " << JumpThreadCost - << ", across block:\n " << *BB << "\n"); + if (SuccBB) + LLVM_DEBUG(dbgs() << " Threading edge from '" << PredBB->getName() + << "' to '" << SuccBB->getName() + << "' with cost: " << JumpThreadCost << "\n"); + else + LLVM_DEBUG(dbgs() << " Cloning BB '" << BB->getName() << "' into '" + << PredBB->getName() << "' with cost: " << JumpThreadCost + << ", across block:\n " << *BB << "\n"); if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); @@ -2006,8 +2033,11 @@ // Clone the non-phi instructions of BB into NewBB, keeping track of the // mapping and using it to remap operands in the cloned instructions. - for (; !BI->isTerminator(); ++BI) { + for (; BI != BB->end(); ++BI) { + if (SuccBB && BI->isTerminator()) + break; Instruction *New = BI->clone(); + assert(!BI->isTerminator() || isa(New)); New->setName(BI->getName()); NewBB->getInstList().push_back(New); ValueMapping[&*BI] = New; @@ -2019,17 +2049,22 @@ if (I != ValueMapping.end()) New->setOperand(i, I->second); } + if (ReturnInst *RetInst = dyn_cast(New)) { + assert(PredVal && !SuccBB); + New->replaceUsesOfWith(RetInst->getReturnValue(), PredVal); + } } // We didn't copy the terminator from BB over to NewBB, because there is now // an unconditional jump to SuccBB. Insert the unconditional jump. - BranchInst *NewBI = BranchInst::Create(SuccBB, NewBB); - NewBI->setDebugLoc(BB->getTerminator()->getDebugLoc()); + BranchInst *NewBI = SuccBB ? BranchInst::Create(SuccBB, NewBB) : nullptr; + if (NewBI) { + NewBI->setDebugLoc(BB->getTerminator()->getDebugLoc()); + // Check to see if SuccBB has PHI nodes. If so, we need to add entries to + // the PHI nodes for NewBB now. + AddPHINodeEntriesForMappedBlock(SuccBB, BB, NewBB, ValueMapping); + } - // Check to see if SuccBB has PHI nodes. If so, we need to add entries to the - // PHI nodes for NewBB now. - AddPHINodeEntriesForMappedBlock(SuccBB, BB, NewBB, ValueMapping); - // Update the terminator of PredBB to jump to NewBB instead of BB. This // eliminates predecessors from BB, which requires us to simplify any PHI // nodes in BB. @@ -2090,7 +2125,8 @@ SimplifyInstructionsInBlock(NewBB, TLI); // Update the edge weight from BB to SuccBB, which should be less than before. - UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB); + if (NewBI) + UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB); // Threaded an edge! ++NumThreads; @@ -2565,11 +2601,19 @@ if (LoopHeaders.count(BB)) return false; + auto IsConstOrAddr = [](Value *V) { + if (isa(V)) + return true; + V = V->stripInBoundsConstantOffsets(); + if (isa(V) || isa(V) || isa(V)) + return true; + return false; + }; + for (BasicBlock::iterator BI = BB->begin(); PHINode *PN = dyn_cast(BI); ++BI) { // Look for a Phi having at least one constant incoming value. - if (llvm::all_of(PN->incoming_values(), - [](Value *V) { return !isa(V); })) + if (llvm::none_of(PN->incoming_values(), IsConstOrAddr)) continue; auto isUnfoldCandidate = [BB](SelectInst *SI, Value *V) { @@ -2586,7 +2630,7 @@ // Look for a ICmp in BB that compares PN with a constant and is the // condition of a Select. if (Cmp->getParent() == BB && Cmp->hasOneUse() && - isa(Cmp->getOperand(1 - U.getOperandNo()))) + IsConstOrAddr(Cmp->getOperand(1 - U.getOperandNo()))) if (SelectInst *SelectI = dyn_cast(Cmp->user_back())) if (isUnfoldCandidate(SelectI, Cmp->use_begin()->get())) { SI = SelectI; Index: test/Transforms/JumpThreading/addr.ll =================================================================== --- test/Transforms/JumpThreading/addr.ll +++ test/Transforms/JumpThreading/addr.ll @@ -0,0 +1,78 @@ +; RUN: opt -jump-threading -S < %s | FileCheck %s +%"struct.std::array" = type { [3 x i32] } + +@r = dso_local local_unnamed_addr global i32 0, align 4 +@_ZL1x = internal constant %"struct.std::array" { [3 x i32] [i32 1, i32 7, i32 17] }, align 4 + +; Function Attrs: nofree norecurse nounwind uwtable +define dso_local void @foo1(i32 %0) local_unnamed_addr #0 { + switch i32 %0, label %2 [ + i32 1, label %6 + i32 7, label %5 + ] + +2: ; preds = %1 + %3 = icmp eq i32 %0, 17 + %4 = select i1 %3, i32* getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 2), i32* getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 1, i32 0, i64 0) + br label %6 + +5: ; preds = %1 + br label %6 + +6: ; preds = %1, %5, %2 + %7 = phi i32* [ getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 0), %1 ], [ %4, %2 ], [ getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 1), %5 ] +; CHECK: [[VAL:%.*]] = phi i32 [ 10,{{.*}}], [ 20, {{.*}}] + %8 = icmp eq i32* %7, getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 1, i32 0, i64 0) +; CHECK-NOT: icmp +; CHECK-NOT: select + %9 = select i1 %8, i32 20, i32 10 + store i32 %9, i32* @r, align 4 +; CHECK: store i32 [[VAL]], i32* @r + ret void +} + +define dso_local i32 @foo2(i32 %0) local_unnamed_addr #0 { + %2 = alloca [100 x i32], align 16 + %3 = bitcast [100 x i32]* %2 to i8* + call void @llvm.lifetime.start.p0i8(i64 400, i8* nonnull %3) #2 + switch i32 %0, label %4 [ + i32 1, label %8 + i32 7, label %7 + ] + +4: ; preds = %1 + %5 = icmp eq i32 %0, 17 + %6 = select i1 %5, i32* getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 2), i32* getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 1, i32 0, i64 0) + br label %8 + +7: ; preds = %1 + br label %8 + +8: ; preds = %1, %7, %4 + %9 = phi i32* [ getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 0), %1 ], [ %6, %4 ], [ getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 1), %7 ] +; CHECK: [[VAL:%.*]] = phi i32 [ 10,{{.*}}], [ 20, {{.*}}] + %10 = icmp eq i32* %9, getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 1, i32 0, i64 0) + %11 = select i1 %10, i32 20, i32 10 +; CHECK-NOT: select + store i32 %11, i32* @r, align 4 +; CHECK: store i32 [[VAL]] + %12 = zext i32 %11 to i64 + %13 = getelementptr inbounds [100 x i32], [100 x i32]* %2, i64 0, i64 %12 + store i32 10, i32* %13, align 8 + %14 = getelementptr inbounds [100 x i32], [100 x i32]* %2, i64 0, i64 10 + %15 = load i32, i32* %14, align 8 + call void @llvm.lifetime.end.p0i8(i64 400, i8* nonnull %3) #2 + ret i32 %15 +} + +; Function Attrs: argmemonly nounwind willreturn +declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1 + +; Function Attrs: argmemonly nounwind willreturn +declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1 + +attributes #0 = { nounwind uwtable } +attributes #1 = { argmemonly nounwind willreturn } +attributes #2 = { nounwind } + + Index: test/Transforms/JumpThreading/return.ll =================================================================== --- test/Transforms/JumpThreading/return.ll +++ test/Transforms/JumpThreading/return.ll @@ -0,0 +1,101 @@ +; RUN: opt -jump-threading -S < %s | FileCheck %s +%"struct.std::array" = type { [3 x i32] } + +@_ZL1x = internal constant %"struct.std::array" { [3 x i32] [i32 1, i32 7, i32 17] }, align 4 +@g = dso_local global [10 x i32] zeroinitializer, align 16 + +; Function Attrs: norecurse nounwind readonly uwtable +define dso_local zeroext i1 @foo1(i32 %0) local_unnamed_addr #0 { + switch i32 %0, label %2 [ + i32 1, label %6 + i32 7, label %5 + ] + +2: ; preds = %1 + %3 = icmp eq i32 %0, 17 + %4 = select i1 %3, i32* getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 2), i32* getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 1, i32 0, i64 0) + br label %6 +; CHECK: ret i1 true + +5: ; preds = %1 + br label %6 +; CHECK: ret i1 true + +6: ; preds = %1, %5, %2 + %7 = phi i32* [ getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 0), %1 ], [ %4, %2 ], [ getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 0, i32 0, i64 1), %5 ] + %8 = icmp ne i32* %7, getelementptr inbounds (%"struct.std::array", %"struct.std::array"* @_ZL1x, i64 1, i32 0, i64 0) + ret i1 %8 +} + +; Function Attrs: norecurse nounwind readnone uwtable +define dso_local zeroext i1 @foo2(i32 %0, i32* readnone %1) local_unnamed_addr #0 { + %3 = icmp sgt i32 %0, 5 + br i1 %3, label %10, label %4 +; CHECK: ret i1 true + +4: ; preds = %2 + %5 = icmp sgt i32 %0, 1 + br i1 %5, label %10, label %6 +; CHECK: ret i1 false + +6: ; preds = %4 + %7 = icmp eq i32 %0, 1 + %8 = getelementptr inbounds i32, i32* %1, i64 1 + %9 = select i1 %7, i32* %1, i32* %8 + br label %10 + +10: ; preds = %6, %4, %2 + %11 = phi i32* [ getelementptr inbounds ([10 x i32], [10 x i32]* @g, i64 0, i64 0), %2 ], [ getelementptr inbounds ([10 x i32], [10 x i32]* @g, i64 0, i64 1), %4 ], [ %9, %6 ] + %12 = icmp eq i32* %11, getelementptr inbounds ([10 x i32], [10 x i32]* @g, i64 0, i64 0) + ret i1 %12 +} + +define linkonce_odr hidden i1 @foo3() { +entry: + br label %land.rhs7 + +land.rhs7: ; preds = %entry + br i1 undef, label %land.rhs22, label %lor.lhs.false17 + +lor.lhs.false17: ; preds = %land.rhs7 + unreachable + +land.rhs22: ; preds = %land.rhs7 + %tobool30 = icmp ne i32 undef, 0 + ret i1 %tobool30 +; CHECK: ret i1 undef +} + +define linkonce_odr hidden i1 @foo4() { +entry: + %neg = and i32 undef, 2097152 + %tobool = icmp eq i32 %neg, 0 + br label %land.rhs7 + +land.rhs7: ; preds = %entry + br i1 %tobool, label %land.rhs22, label %lor.lhs.false17 +; CHECK: ret i1 false + +lor.lhs.false17: ; preds = %land.rhs7 + unreachable + +land.rhs22: ; preds = %land.rhs7 + %tobool30 = icmp ne i32 %neg, 0 + ret i1 %tobool30 +} + +define i1 @foo5() { + %l = load i8, i8* undef, align 8 + %i = icmp ne i8 %l, 0 + br i1 %i, label %t1, label %t2 +; CHECK: ret i1 false +t1: ; preds = %2 + unreachable +t2: ; preds = %2 + ret i1 %i +} + + + +attributes #0 = { norecurse nounwind readonly uwtable } +