Index: llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -19,10 +19,12 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -68,20 +70,65 @@ UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, cl::desc("The cost threshold for unswitching a loop.")); -static void replaceLoopUsesWithConstant(Loop &L, Value &LIC, - Constant &Replacement) { - assert(!isa(LIC) && "Why are we unswitching on a constant?"); +/// Collect all of the loop invariant input values transitively used by the +/// homogeneous instruction graph from a given root. +/// +/// This essentially walks from a root recursively through loop variant operands +/// which have the exact same opcode and finds all inputs which are loop +/// invariant. For some operations these can be re-associated and unswitched out +/// of the loop entirely. +static SmallVector +collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, + LoopInfo &LI) { + SmallVector Invariants; + assert(!L.isLoopInvariant(&Root) && + "Only need to walk the graph if root itself is not invariant."); + + // Build a worklist and recurse through operators collecting invariants. + SmallVector Worklist; + SmallPtrSet Visited; + Worklist.push_back(&Root); + Visited.insert(&Root); + do { + Instruction &I = *Worklist.pop_back_val(); + for (Value *OpV : I.operand_values()) { + // Skip constants as unswitching isn't interesting for them. + if (isa(OpV)) + continue; + + // Add it to our result if loop invariant. + if (L.isLoopInvariant(OpV)) { + Invariants.push_back(OpV); + continue; + } + + // If not an instruction with the same opcode, nothing we can do. + Instruction *OpI = dyn_cast(OpV); + if (!OpI || OpI->getOpcode() != Root.getOpcode()) + continue; + + // Visit this operand. + if (Visited.insert(OpI).second) + Worklist.push_back(OpI); + } + } while (!Worklist.empty()); + + return Invariants; +} + +static void replaceLoopInvariantUses(Loop &L, Value *Invariant, + Constant &Replacement) { + assert(!isa(Invariant) && "Why are we unswitching on a constant?"); // Replace uses of LIC in the loop with the given constant. - for (auto UI = LIC.use_begin(), UE = LIC.use_end(); UI != UE;) { + for (auto UI = Invariant->use_begin(), UE = Invariant->use_end(); UI != UE;) { // Grab the use and walk past it so we can clobber it in the use list. Use *U = &*UI++; Instruction *UserI = dyn_cast(U->getUser()); - if (!UserI || !L.contains(UserI)) - continue; // Replace this use within the loop body. - *U = &Replacement; + if (UserI && L.contains(UserI)) + U->set(&Replacement); } } @@ -135,7 +182,8 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, BasicBlock &UnswitchedBB, BasicBlock &OldExitingBB, - BasicBlock &OldPH) { + BasicBlock &OldPH, + bool FullUnswitch) { assert(&ExitBB != &UnswitchedBB && "Must have different loop exit and unswitched blocks!"); Instruction *InsertPt = &*UnswitchedBB.begin(); @@ -156,7 +204,11 @@ if (PN.getIncomingBlock(i) != &OldExitingBB) continue; - Value *Incoming = PN.removeIncomingValue(i); + Value *Incoming = PN.getIncomingValue(i); + if (FullUnswitch) + // No more edge from the old exiting block to the exit block. + PN.removeIncomingValue(i); + NewPN->addIncoming(Incoming, &OldPH); } @@ -186,22 +238,30 @@ assert(BI.isConditional() && "Can only unswitch a conditional branch!"); LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); - Value *LoopCond = BI.getCondition(); + // The loop invariant values that we want to unswitch. + SmallVector Invariants; - // Need a trivial loop condition to unswitch. - if (!L.isLoopInvariant(LoopCond)) - return false; + // When true, we're fully unswitching the branch rather than just unswitching + // some input conditions to the branch. + bool FullUnswitch = false; + + if (L.isLoopInvariant(BI.getCondition())) { + Invariants.push_back(BI.getCondition()); + FullUnswitch = true; + } else { + if (auto *CondInst = dyn_cast(BI.getCondition())) + Invariants = collectHomogenousInstGraphLoopInvariants(L, *CondInst, LI); + if (Invariants.empty()) + // Couldn't find invariant inputs! + return false; + } - // Check to see if a successor of the branch is guaranteed to - // exit through a unique exit block without having any - // side-effects. If so, determine the value of Cond that causes - // it to do this. - ConstantInt *CondVal = ConstantInt::getTrue(BI.getContext()); - ConstantInt *Replacement = ConstantInt::getFalse(BI.getContext()); + // Check that one of the branch's successors exits, and which one. + bool ExitDirection = true; int LoopExitSuccIdx = 0; auto *LoopExitBB = BI.getSuccessor(0); if (L.contains(LoopExitBB)) { - std::swap(CondVal, Replacement); + ExitDirection = false; LoopExitSuccIdx = 1; LoopExitBB = BI.getSuccessor(1); if (L.contains(LoopExitBB)) @@ -212,8 +272,31 @@ if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) return false; - LLVM_DEBUG(dbgs() << " unswitching trivial branch when: " << CondVal - << " == " << LoopCond << "\n"); + // When unswitching only part of the branch's condition, we need the exit + // block to be reached directly from the partially unswitched input. This can + // be done when the exit block is along the true edge and the branch condition + // is a graph of `or` operations, or the exit block is along the false edge + // and the condition is a graph of `and` operations. + if (!FullUnswitch) { + if (ExitDirection) { + if (cast(BI.getCondition())->getOpcode() != Instruction::Or) + return false; + } else { + if (cast(BI.getCondition())->getOpcode() != Instruction::And) + return false; + } + } + + LLVM_DEBUG({ + dbgs() << " unswitching trivial invariant conditions for: " << BI + << "\n"; + for (Value *Invariant : Invariants) { + dbgs() << " " << *Invariant << " == true"; + if (Invariant != Invariants.back()) + dbgs() << " ||"; + dbgs() << "\n"; + } + }); // Split the preheader, so that we know that there is a safe place to insert // the conditional branch. We will change the preheader to have a conditional @@ -226,41 +309,79 @@ // unswitching. We need to split this if there are other loop predecessors. // Because the loop is in simplified form, *any* other predecessor is enough. BasicBlock *UnswitchedBB; - if (BasicBlock *PredBB = LoopExitBB->getUniquePredecessor()) { - (void)PredBB; - assert(PredBB == BI.getParent() && + if (FullUnswitch && LoopExitBB->getUniquePredecessor()) { + assert(LoopExitBB->getUniquePredecessor() == BI.getParent() && "A branch's parent isn't a predecessor!"); UnswitchedBB = LoopExitBB; } else { UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI); } - // Now splice the branch to gate reaching the new preheader and re-point its - // successors. - OldPH->getInstList().splice(std::prev(OldPH->end()), - BI.getParent()->getInstList(), BI); + // Actually move the invariant uses into the unswitched position. If possible, + // we do this by moving the instructions, but when doing partial unswitching + // we do it by building a new merge of the values in the unswitched position. OldPH->getTerminator()->eraseFromParent(); - BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); - BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); - - // Create a new unconditional branch that will continue the loop as a new - // terminator. - BranchInst::Create(ContinueBB, ParentBB); + if (FullUnswitch) { + // If fully unswitching, we can use the existing branch instruction. + // Splice it into the old PH to gate reaching the new preheader and re-point + // its successors. + OldPH->getInstList().splice(OldPH->end(), BI.getParent()->getInstList(), + BI); + BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); + BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); + + // Create a new unconditional branch that will continue the loop as a new + // terminator. + BranchInst::Create(ContinueBB, ParentBB); + } else { + // Only unswitching a subset of inputs to the condition, so we will need to + // build a new branch that merges the invariant inputs. + IRBuilder<> IRB(OldPH); + Value *Cond = Invariants.front(); + if (ExitDirection) + assert(cast(BI.getCondition())->getOpcode() == + Instruction::Or && + "Must have an `or` of `i1`s for the condition!"); + else + assert(cast(BI.getCondition())->getOpcode() == + Instruction::And && + "Must have an `and` of `i1`s for the condition!"); + for (Value *Invariant : + make_range(std::next(Invariants.begin()), Invariants.end())) + if (ExitDirection) + Cond = IRB.CreateOr(Cond, Invariant); + else + Cond = IRB.CreateAnd(Cond, Invariant); + + BasicBlock *Succs[2]; + Succs[LoopExitSuccIdx] = UnswitchedBB; + Succs[1 - LoopExitSuccIdx] = NewPH; + IRB.CreateCondBr(Cond, Succs[0], Succs[1]); + } // Rewrite the relevant PHI nodes. if (UnswitchedBB == LoopExitBB) rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH); else rewritePHINodesForExitAndUnswitchedBlocks(*LoopExitBB, *UnswitchedBB, - *ParentBB, *OldPH); + *ParentBB, *OldPH, FullUnswitch); // Now we need to update the dominator tree. - DT.applyUpdates( - {{DT.Delete, ParentBB, UnswitchedBB}, {DT.Insert, OldPH, UnswitchedBB}}); + DT.insertEdge(OldPH, UnswitchedBB); + if (FullUnswitch) + DT.deleteEdge(ParentBB, UnswitchedBB); + + // The constant we can replace all of our invariants with inside the loop + // body. If any of the invariants have a value other than this the loop won't + // be entered. + ConstantInt *Replacement = ExitDirection + ? ConstantInt::getFalse(BI.getContext()) + : ConstantInt::getTrue(BI.getContext()); // Since this is an i1 condition we can also trivially replace uses of it // within the loop with a constant. - replaceLoopUsesWithConstant(L, *LoopCond, *Replacement); + for (Value *Invariant : Invariants) + replaceLoopInvariantUses(L, Invariant, *Replacement); ++NumTrivial; ++NumBranches; @@ -393,8 +514,8 @@ } else { auto *SplitBB = SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); - rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, - *ParentBB, *OldPH); + rewritePHINodesForExitAndUnswitchedBlocks( + *DefaultExitBB, *SplitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; } } @@ -419,8 +540,8 @@ if (!SplitExitBB) { // If this is the first time we see this, do the split and remember it. SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); - rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, - *ParentBB, *OldPH); + rewritePHINodesForExitAndUnswitchedBlocks( + *ExitBB, *SplitExitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); } // Update the case pair to point to the split block. CasePair.second = SplitExitBB; @@ -560,11 +681,13 @@ // Mark that we managed to unswitch something. Changed = true; - // We unswitched the branch. This should always leave us with an - // unconditional branch that we can follow now. + // If we only unswitched some of the conditions feeding the branch, we won't + // have collapsed it to a single successor. BI = cast(CurrentBB->getTerminator()); - assert(!BI->isConditional() && - "Cannot form a conditional branch by unswitching1"); + if (BI->isConditional()) + return Changed; + + // Follow the newly unconditional branch into its successor. CurrentBB = BI->getSuccessor(0); // When continuing, if we exit the loop or reach a previous visited block, @@ -956,8 +1079,7 @@ // matter as we're just trying to build up the map from inside-out; we use // the map in a more stably ordered way below. auto OrderedClonedExitsInLoops = ClonedExitsInLoops; - llvm::sort(OrderedClonedExitsInLoops.begin(), - OrderedClonedExitsInLoops.end(), + llvm::sort(OrderedClonedExitsInLoops.begin(), OrderedClonedExitsInLoops.end(), [&](BasicBlock *LHS, BasicBlock *RHS) { return ExitLoopMap.lookup(LHS)->getLoopDepth() < ExitLoopMap.lookup(RHS)->getLoopDepth(); Index: llvm/trunk/test/Transforms/SimpleLoopUnswitch/LIV-loop-condtion.ll =================================================================== --- llvm/trunk/test/Transforms/SimpleLoopUnswitch/LIV-loop-condtion.ll +++ llvm/trunk/test/Transforms/SimpleLoopUnswitch/LIV-loop-condtion.ll @@ -4,17 +4,25 @@ ; itself is an LIV loop condition (not partial LIV which could occur in and/or). define i32 @test(i1 %cond1, i32 %var1) { +; CHECK-LABEL: define i32 @test( entry: br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %cond1, label %entry.split, label %loop_exit.split +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin loop_begin: %var3 = phi i32 [%var1, %entry], [%var2, %do_something] %cond2 = icmp eq i32 %var3, 10 %cond.and = and i1 %cond1, %cond2 - -; %cond.and only has %cond1 as LIV so no unswitch should happen. -; CHECK: br i1 %cond.and, label %do_something, label %loop_exit - br i1 %cond.and, label %do_something, label %loop_exit + br i1 %cond.and, label %do_something, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: %[[VAR3:.*]] = phi i32 +; CHECK-NEXT: %[[COND2:.*]] = icmp eq i32 %[[VAR3]], 10 +; CHECK-NEXT: %[[COND_AND:.*]] = and i1 true, %[[COND2]] +; CHECK-NEXT: br i1 %[[COND_AND]], label %do_something, label %loop_exit do_something: %var2 = add i32 %var3, 1 Index: llvm/trunk/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll =================================================================== --- llvm/trunk/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll +++ llvm/trunk/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll @@ -443,3 +443,179 @@ ; CHECK: cleanup: ; CHECK-NEXT: ret void } + +define i32 @test_partial_condition_unswitch_and(i32* %var, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @test_partial_condition_unswitch_and( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %cond1, label %entry.split, label %loop_exit.split +; +; CHECK: entry.split: +; CHECK-NEXT: br i1 %cond2, label %entry.split.split, label %loop_exit +; +; CHECK: entry.split.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + br i1 %cond1, label %continue, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: br label %continue + +continue: + %var_val = load i32, i32* %var + %var_cond = trunc i32 %var_val to i1 + %cond_and = and i1 %var_cond, %cond2 + br i1 %cond_and, label %do_something, label %loop_exit +; CHECK: continue: +; CHECK-NEXT: %[[VAR:.*]] = load i32 +; CHECK-NEXT: %[[VAR_COND:.*]] = trunc i32 %[[VAR]] to i1 +; CHECK-NEXT: %[[COND_AND:.*]] = and i1 %[[VAR_COND]], true +; CHECK-NEXT: br i1 %[[COND_AND]], label %do_something, label %loop_exit + +do_something: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: do_something: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + ret i32 0 +; CHECK: loop_exit: +; CHECK-NEXT: br label %loop_exit.split +; +; CHECK: loop_exit.split: +; CHECK-NEXT: ret +} + +define i32 @test_partial_condition_unswitch_or(i32* %var, i1 %cond1, i1 %cond2, i1 %cond3, i1 %cond4, i1 %cond5, i1 %cond6) { +; CHECK-LABEL: @test_partial_condition_unswitch_or( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[INV_OR1:.*]] = or i1 %cond4, %cond2 +; CHECK-NEXT: %[[INV_OR2:.*]] = or i1 %[[INV_OR1]], %cond3 +; CHECK-NEXT: %[[INV_OR3:.*]] = or i1 %[[INV_OR2]], %cond1 +; CHECK-NEXT: br i1 %[[INV_OR3]], label %loop_exit.split, label %entry.split +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + %var_val = load i32, i32* %var + %var_cond = trunc i32 %var_val to i1 + %cond_or1 = or i1 %var_cond, %cond1 + %cond_or2 = or i1 %cond2, %cond3 + %cond_or3 = or i1 %cond_or1, %cond_or2 + %cond_xor1 = xor i1 %cond5, %var_cond + %cond_and1 = and i1 %cond6, %var_cond + %cond_or4 = or i1 %cond_xor1, %cond_and1 + %cond_or5 = or i1 %cond_or3, %cond_or4 + %cond_or6 = or i1 %cond_or5, %cond4 + br i1 %cond_or6, label %loop_exit, label %do_something +; CHECK: loop_begin: +; CHECK-NEXT: %[[VAR:.*]] = load i32 +; CHECK-NEXT: %[[VAR_COND:.*]] = trunc i32 %[[VAR]] to i1 +; CHECK-NEXT: %[[COND_OR1:.*]] = or i1 %[[VAR_COND]], false +; CHECK-NEXT: %[[COND_OR2:.*]] = or i1 false, false +; CHECK-NEXT: %[[COND_OR3:.*]] = or i1 %[[COND_OR1]], %[[COND_OR2]] +; CHECK-NEXT: %[[COND_XOR:.*]] = xor i1 %cond5, %[[VAR_COND]] +; CHECK-NEXT: %[[COND_AND:.*]] = and i1 %cond6, %[[VAR_COND]] +; CHECK-NEXT: %[[COND_OR4:.*]] = or i1 %[[COND_XOR]], %[[COND_AND]] +; CHECK-NEXT: %[[COND_OR5:.*]] = or i1 %[[COND_OR3]], %[[COND_OR4]] +; CHECK-NEXT: %[[COND_OR6:.*]] = or i1 %[[COND_OR5]], false +; CHECK-NEXT: br i1 %[[COND_OR6]], label %loop_exit, label %do_something + +do_something: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: do_something: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + ret i32 0 +; CHECK: loop_exit.split: +; CHECK-NEXT: ret +} + +define i32 @test_partial_condition_unswitch_with_lcssa_phi1(i32* %var, i1 %cond, i32 %x) { +; CHECK-LABEL: @test_partial_condition_unswitch_with_lcssa_phi1( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %cond, label %entry.split, label %loop_exit.split +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + %var_val = load i32, i32* %var + %var_cond = trunc i32 %var_val to i1 + %cond_and = and i1 %var_cond, %cond + br i1 %cond_and, label %do_something, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: %[[VAR:.*]] = load i32 +; CHECK-NEXT: %[[VAR_COND:.*]] = trunc i32 %[[VAR]] to i1 +; CHECK-NEXT: %[[COND_AND:.*]] = and i1 %[[VAR_COND]], true +; CHECK-NEXT: br i1 %[[COND_AND]], label %do_something, label %loop_exit + +do_something: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: do_something: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + %x.lcssa = phi i32 [ %x, %loop_begin ] + ret i32 %x.lcssa +; CHECK: loop_exit: +; CHECK-NEXT: %[[LCSSA:.*]] = phi i32 [ %x, %loop_begin ] +; CHECK-NEXT: br label %loop_exit.split +; +; CHECK: loop_exit.split: +; CHECK-NEXT: %[[LCSSA_SPLIT:.*]] = phi i32 [ %x, %entry ], [ %[[LCSSA]], %loop_exit ] +; CHECK-NEXT: ret i32 %[[LCSSA_SPLIT]] +} + +define i32 @test_partial_condition_unswitch_with_lcssa_phi2(i32* %var, i1 %cond, i32 %x, i32 %y) { +; CHECK-LABEL: @test_partial_condition_unswitch_with_lcssa_phi2( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %cond, label %entry.split, label %loop_exit.split +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + %var_val = load i32, i32* %var + %var_cond = trunc i32 %var_val to i1 + %cond_and = and i1 %var_cond, %cond + br i1 %cond_and, label %do_something, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: %[[VAR:.*]] = load i32 +; CHECK-NEXT: %[[VAR_COND:.*]] = trunc i32 %[[VAR]] to i1 +; CHECK-NEXT: %[[COND_AND:.*]] = and i1 %[[VAR_COND]], true +; CHECK-NEXT: br i1 %[[COND_AND]], label %do_something, label %loop_exit + +do_something: + call void @some_func() noreturn nounwind + br i1 %var_cond, label %loop_begin, label %loop_exit +; CHECK: do_something: +; CHECK-NEXT: call +; CHECK-NEXT: br i1 %[[VAR_COND]], label %loop_begin, label %loop_exit + +loop_exit: + %xy.lcssa = phi i32 [ %x, %loop_begin ], [ %y, %do_something ] + ret i32 %xy.lcssa +; CHECK: loop_exit: +; CHECK-NEXT: %[[LCSSA:.*]] = phi i32 [ %x, %loop_begin ], [ %y, %do_something ] +; CHECK-NEXT: br label %loop_exit.split +; +; CHECK: loop_exit.split: +; CHECK-NEXT: %[[LCSSA_SPLIT:.*]] = phi i32 [ %x, %entry ], [ %[[LCSSA]], %loop_exit ] +; CHECK-NEXT: ret i32 %[[LCSSA_SPLIT]] +}