Index: llvm/include/llvm/Analysis/Utils/Local.h =================================================================== --- llvm/include/llvm/Analysis/Utils/Local.h +++ llvm/include/llvm/Analysis/Utils/Local.h @@ -141,6 +141,18 @@ bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI = nullptr); +/// Delete all of the instructions in the provided vector, and all other +/// instructions that deleting these in turn causes to be trivially dead. +/// +/// The initial instructions in the provided vector must all have empty use +/// lists and satisfy `isInstructionTriviallyDead`. +/// +/// `DeadInsts` will be used as scratch storage for this routine and will be +/// empty afterward. +void RecursivelyDeleteTriviallyDeadInstructions( + SmallVectorImpl &DeadInsts, + const TargetLibraryInfo *TLI = nullptr); + /// If the specified value is an effectively dead PHI node, due to being a /// def-use chain of single-use nodes that either forms a cycle or is terminated /// by a trivially dead instruction, delete it. 
If that makes any of its Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp =================================================================== --- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -19,10 +19,12 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -68,21 +70,175 @@ UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, cl::desc("The cost threshold for unswitching a loop.")); -static void replaceLoopUsesWithConstant(Loop &L, Value &LIC, - Constant &Replacement) { - assert(!isa(LIC) && "Why are we unswitching on a constant?"); +/// Collect all of the loop invariant input values transitively used by the +/// homogeneous instruction graph from a given root. +/// +/// This essentially walks from a root recursively through loop variant operands +/// which have the exact same opcode and finds all inputs which are loop +/// invariant. For some operations these can be re-associated and unswitched out +/// of the loop entirely. +/// +/// NOTE(review): `LI` is not referenced anywhere in this function's body. +static SmallVector +collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, + LoopInfo &LI) { + SmallVector Invariants; - // Replace uses of LIC in the loop with the given constant. - for (auto UI = LIC.use_begin(), UE = LIC.use_end(); UI != UE;) { - // Grab the use and walk past it so we can clobber it in the use list. - Use *U = &*UI++; - Instruction *UserI = dyn_cast(U->getUser()); - if (!UserI || !L.contains(UserI)) - continue; + // Build a worklist and recurse through operators collecting invariants. 
+ SmallVector Worklist; + SmallPtrSet Visited; + Worklist.push_back(&Root); + Visited.insert(&Root); + do { + Instruction &I = *Worklist.pop_back_val(); + for (Value *OpV : I.operand_values()) { + // Skip constants as unswitching isn't interesting for them. + if (isa(OpV)) + continue; + + // Add it to our result if loop invariant. + if (L.isLoopInvariant(OpV)) { + Invariants.push_back(OpV); + continue; + } + + // If not an instruction with the same opcode, nothing we can do. + Instruction *OpI = dyn_cast(OpV); + if (!OpI || OpI->getOpcode() != Root.getOpcode()) + continue; + + // Visit this operand. + if (Visited.insert(OpI).second) + Worklist.push_back(OpI); + } + } while (!Worklist.empty()); + + return Invariants; +} - // Replace this use within the loop body. - *U = &Replacement; +/// Replace every in-loop use of each value in `Invariants` with +/// `Replacement`, then try to simplify the rewritten users and delete any +/// instructions that become trivially dead. +/// +/// Uses of the invariants outside of `L` are left untouched. +static void replaceLoopInvariantUsesAndSimplify(Loop &L, + ArrayRef Invariants, + Constant &Replacement, + LoopInfo &LI) { + SmallPtrSet SimplifyInstSet; + SmallPtrSet SimplifyBlockSet; + + for (Value *Invariant : Invariants) { + assert(!isa(Invariant) && + "Why are we unswitching on a constant?"); + + // Replace uses of this invariant in the loop with the given constant. + for (auto UI = Invariant->use_begin(), UE = Invariant->use_end(); + UI != UE;) { + // Grab the use and walk past it so we can clobber it in the use list. + Use *U = &*UI++; + Instruction *UserI = dyn_cast(U->getUser()); + if (!UserI || !L.contains(UserI)) + continue; + + // Replace this use within the loop body. + *U = &Replacement; + + // Skip some instruction kinds that aren't interesting to simplify to + // avoid the cost of tracking them. + if (isa(UserI) || isa(UserI) || + isa(UserI)) + continue; + + // Track that this is something we may need to simplify. + if (SimplifyInstSet.insert(UserI).second) + SimplifyBlockSet.insert(UserI->getParent()); + } } + + if (SimplifyInstSet.empty()) + return; + + // We'll need to simplify some instructions. Wire up the necessary + // infrastructure. 
+ + Module &M = *L.getHeader()->getParent()->getParent(); + auto &DL = M.getDataLayout(); + + // Remember the instructions that are now dead. + SmallVector DeadInsts; + + auto Simplify = [&](Instruction &I) { + // First check if this already has no uses; queue it for deletion if it is + // trivially dead and skip simplifying it either way. + if (I.use_empty()) { + if (isInstructionTriviallyDead(&I)) + DeadInsts.push_back(&I); + return; + } + + // Try to do very basic simplifications of these instructions. + Value *SimpleV = SimplifyInstruction(&I, SimplifyQuery(DL, &I)); + if (!SimpleV) + return; + + // Manually do RAUW so that we can intercept users within the loop and + // update them. + for (auto UI = I.use_begin(), UE = I.use_end(); UI != UE;) { + // Grab the use and walk past it so we can clobber it in the use list. + Use *U = &*UI++; + + // Replace the use with the simplified value. + *U = SimpleV; + + // If this is an instruction within the loop, recurse through it. The + // reason we restrict this to in-loop uses is that we don't want to + // simplify the LCSSA phi node away. We could solve this, but the only + // real goal of this is to do very basic cleanup of unswitched conditions. + // We don't need powerful tools here. A proper pass can be scheduled to do + // more comprehensive cleanup. + Instruction *UserI = dyn_cast(U->getUser()); + if (!UserI || !L.contains(UserI)) + continue; + + // This is an in-loop use, so consider simplifying it as well. + if (SimplifyInstSet.insert(UserI).second) + SimplifyBlockSet.insert(UserI->getParent()); + } + + // We expect the instruction to no longer have uses, so check if it can be + // deleted. + assert(I.use_empty() && "Didn't rewrite a use after simplification?"); + if (isInstructionTriviallyDead(&I)) + DeadInsts.push_back(&I); + }; + + // While we have a single block in need of simplifying, we can just walk + // the instructions in that block and simplify the instructions in the order + // we find them. This ensures that we simplify defs before uses. 
+ // + // FIXME: This could be really slow for huge basic blocks, but we shouldn't be + // rewriting invariants *that* many times in the same huge block. + while (SimplifyBlockSet.size() == 1) { + BasicBlock &BB = **SimplifyBlockSet.begin(); + + for (Instruction &I : BB) + if (SimplifyInstSet.count(&I)) + Simplify(I); + + // Erase the current block from the set as we've simplified it. This makes + // it easier to detect if we've introduced new interesting simplification + // challenges. + SimplifyBlockSet.erase(&BB); + } + + if (!SimplifyBlockSet.empty()) { + // If we still have blocks to process it is because we have instructions + // across multiple blocks. We want to ensure we simplify defs before their + // uses and so we use an RPO over the loop blocks to order our visit. + LoopBlocksRPO RPOT(&L); + RPOT.perform(&LI); + for (BasicBlock *BB : RPOT) + if (SimplifyBlockSet.count(BB)) + for (Instruction &I : *BB) + if (SimplifyInstSet.count(&I)) + Simplify(I); + } + + // Finally, delete any dead instructions. + // + // NOTE(review): `SimplifyInstSet` is never pruned and a block can be + // re-added to `SimplifyBlockSet` after it has been walked, so the same + // instruction could presumably be pushed onto `DeadInsts` twice -- confirm + // this cannot lead to a double erase. + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts); } /// Check that all the LCSSA PHI nodes in the loop exit block have trivial @@ -186,22 +342,29 @@ assert(BI.isConditional() && "Can only unswitch a conditional branch!"); LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); - Value *LoopCond = BI.getCondition(); + // The loop invariant values that we want to unswitch. + SmallVector Invariants; - // Need a trivial loop condition to unswitch. - if (!L.isLoopInvariant(LoopCond)) + // When true, we're fully unswitching the branch rather than just unswitching + // some input conditions to the branch. + bool FullUnswitch = false; + + if (L.isLoopInvariant(BI.getCondition())) { + Invariants.push_back(BI.getCondition()); + FullUnswitch = true; + } else if (auto *CondInst = dyn_cast(BI.getCondition())) { + Invariants = collectHomogenousInstGraphLoopInvariants(L, *CondInst, LI); + } else { + // Couldn't find invariant inputs! 
return false; + } - // Check to see if a successor of the branch is guaranteed to - // exit through a unique exit block without having any - // side-effects. If so, determine the value of Cond that causes - // it to do this. - ConstantInt *CondVal = ConstantInt::getTrue(BI.getContext()); - ConstantInt *Replacement = ConstantInt::getFalse(BI.getContext()); + // Check that one of the branch's successors exits, and which one. + bool ExitDirection = true; int LoopExitSuccIdx = 0; auto *LoopExitBB = BI.getSuccessor(0); if (L.contains(LoopExitBB)) { - std::swap(CondVal, Replacement); + ExitDirection = false; LoopExitSuccIdx = 1; LoopExitBB = BI.getSuccessor(1); if (L.contains(LoopExitBB)) @@ -212,8 +375,31 @@ if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) return false; - LLVM_DEBUG(dbgs() << " unswitching trivial branch when: " << CondVal - << " == " << LoopCond << "\n"); + // When unswitching only part of the branch's condition, we need the exit + // block to be reached directly from the partially unswitched input. This can + // be done when the exit block is along the true edge and the branch condition + // is a graph of `or` operations, or the exit block is along the false edge + // and the condition is a graph of `and` operations. + if (!FullUnswitch) { + if (ExitDirection) { + if (cast(BI.getCondition())->getOpcode() != Instruction::Or) + return false; + } else { + if (cast(BI.getCondition())->getOpcode() != Instruction::And) + return false; + } + } + + LLVM_DEBUG({ + dbgs() << " unswitching trivial invariant conditions for: " << BI + << "\n"; + for (Value *Invariant : Invariants) { + dbgs() << " " << *Invariant << " == true"; + if (Invariant != Invariants.back()) + dbgs() << " ||"; + dbgs() << "\n"; + } + }); // Split the preheader, so that we know that there is a safe place to insert // the conditional branch. We will change the preheader to have a conditional @@ -226,26 +412,55 @@ // unswitching. 
We need to split this if there are other loop predecessors. // Because the loop is in simplified form, *any* other predecessor is enough. BasicBlock *UnswitchedBB; - if (BasicBlock *PredBB = LoopExitBB->getUniquePredecessor()) { - (void)PredBB; - assert(PredBB == BI.getParent() && + if (FullUnswitch && LoopExitBB->getUniquePredecessor()) { + assert(LoopExitBB->getUniquePredecessor() == BI.getParent() && "A branch's parent isn't a predecessor!"); UnswitchedBB = LoopExitBB; } else { UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI); } - // Now splice the branch to gate reaching the new preheader and re-point its - // successors. - OldPH->getInstList().splice(std::prev(OldPH->end()), - BI.getParent()->getInstList(), BI); + // Actually move the invariant uses into the unswitched position. If possible, + // we do this by moving the instructions, but when doing partial unswitching + // we do it by building a new merge of the values in the unswitched position. OldPH->getTerminator()->eraseFromParent(); - BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); - BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); - - // Create a new unconditional branch that will continue the loop as a new - // terminator. - BranchInst::Create(ContinueBB, ParentBB); + if (FullUnswitch) { + // If fully unswitching, we can use the existing branch instruction. + // Splice it into the old PH to gate reaching the new preheader and re-point + // its successors. + OldPH->getInstList().splice(OldPH->end(), BI.getParent()->getInstList(), + BI); + BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); + BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); + + // Create a new unconditional branch that will continue the loop as a new + // terminator. + BranchInst::Create(ContinueBB, ParentBB); + } else { + // Only unswitching a subset of inputs to the condition, so we will need to + // build a new branch that merges the invariant inputs. 
+ IRBuilder<> IRB(OldPH); + Value *Cond = Invariants.front(); + if (ExitDirection) + assert(cast(BI.getCondition())->getOpcode() == + Instruction::Or && + "Must have an `or` of `i1`s for the condition!"); + else + assert(cast(BI.getCondition())->getOpcode() == + Instruction::And && + "Must have an `and` of `i1`s for the condition!"); + for (Value *Invariant : + make_range(std::next(Invariants.begin()), Invariants.end())) + if (ExitDirection) + Cond = IRB.CreateOr(Cond, Invariant); + else + Cond = IRB.CreateAnd(Cond, Invariant); + + BasicBlock *Succs[2]; + Succs[LoopExitSuccIdx] = UnswitchedBB; + Succs[1 - LoopExitSuccIdx] = NewPH; + IRB.CreateCondBr(Cond, Succs[0], Succs[1]); + } // Rewrite the relevant PHI nodes. if (UnswitchedBB == LoopExitBB) @@ -255,12 +470,20 @@ *ParentBB, *OldPH); // Now we need to update the dominator tree. - DT.applyUpdates( - {{DT.Delete, ParentBB, UnswitchedBB}, {DT.Insert, OldPH, UnswitchedBB}}); + DT.insertEdge(OldPH, UnswitchedBB); + if (FullUnswitch) + DT.deleteEdge(ParentBB, UnswitchedBB); + + // The constant we can replace all of our invariants with inside the loop + // body. If any of the invariants have a value other than this the loop won't + // be entered. + ConstantInt *Replacement = ExitDirection + ? ConstantInt::getFalse(BI.getContext()) + : ConstantInt::getTrue(BI.getContext()); // Since this is an i1 condition we can also trivially replace uses of it // within the loop with a constant. - replaceLoopUsesWithConstant(L, *LoopCond, *Replacement); + replaceLoopInvariantUsesAndSimplify(L, Invariants, *Replacement, LI); ++NumTrivial; ++NumBranches; @@ -560,11 +783,13 @@ // Mark that we managed to unswitch something. Changed = true; - // We unswitched the branch. This should always leave us with an - // unconditional branch that we can follow now. + // If we only unswitched some of the conditions feeding the branch, we won't + // have collapsed it to a single successor. 
BI = cast(CurrentBB->getTerminator()); - assert(!BI->isConditional() && - "Cannot form a conditional branch by unswitching1"); + if (BI->isConditional()) + return Changed; + + // Follow the newly unconditional branch into its successor. CurrentBB = BI->getSuccessor(0); // When continuing, if we exit the loop or reach a previous visited block, @@ -955,8 +1180,7 @@ // matter as we're just trying to build up the map from inside-out; we use // the map in a more stably ordered way below. auto OrderedClonedExitsInLoops = ClonedExitsInLoops; - llvm::sort(OrderedClonedExitsInLoops.begin(), - OrderedClonedExitsInLoops.end(), + llvm::sort(OrderedClonedExitsInLoops.begin(), OrderedClonedExitsInLoops.end(), [&](BasicBlock *LHS, BasicBlock *RHS) { return ExitLoopMap.lookup(LHS)->getLoopDepth() < ExitLoopMap.lookup(RHS)->getLoopDepth(); Index: llvm/lib/Transforms/Utils/Local.cpp =================================================================== --- llvm/lib/Transforms/Utils/Local.cpp +++ llvm/lib/Transforms/Utils/Local.cpp @@ -434,18 +434,26 @@ SmallVector DeadInsts; DeadInsts.push_back(I); + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, TLI); - do { - I = DeadInsts.pop_back_val(); - salvageDebugInfo(*I); + return true; +} + +void llvm::RecursivelyDeleteTriviallyDeadInstructions( + SmallVectorImpl &DeadInsts, const TargetLibraryInfo *TLI) { + // Process the dead instruction list until empty. + while (!DeadInsts.empty()) { + Instruction &I = *DeadInsts.pop_back_val(); + salvageDebugInfo(I); // Null out all of the instruction's operands to see if any operand becomes // dead as we go. 
- for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - Value *OpV = I->getOperand(i); - I->setOperand(i, nullptr); + for (Use &OpU : I.operands()) { + Value *OpV = OpU.get(); + OpU = nullptr; - if (!OpV->use_empty()) continue; + if (!OpV->use_empty()) + continue; // If the operand is an instruction that became dead as we nulled out the // operand, and if it is 'trivially' dead, delete it in a future loop @@ -455,10 +463,8 @@ DeadInsts.push_back(OpI); } - I->eraseFromParent(); - } while (!DeadInsts.empty()); - - return true; + I.eraseFromParent(); + } } /// areAllUsesEqual - Check whether the uses of a value are all the same. Index: llvm/test/Transforms/SimpleLoopUnswitch/LIV-loop-condtion.ll =================================================================== --- llvm/test/Transforms/SimpleLoopUnswitch/LIV-loop-condtion.ll +++ llvm/test/Transforms/SimpleLoopUnswitch/LIV-loop-condtion.ll @@ -4,17 +4,24 @@ ; itself is an LIV loop condition (not partial LIV which could occur in and/or). define i32 @test(i1 %cond1, i32 %var1) { +; CHECK-LABEL: define i32 @test( entry: br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %cond1, label %entry.split, label %loop_exit.split +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin loop_begin: %var3 = phi i32 [%var1, %entry], [%var2, %do_something] %cond2 = icmp eq i32 %var3, 10 %cond.and = and i1 %cond1, %cond2 - -; %cond.and only has %cond1 as LIV so no unswitch should happen. 
-; CHECK: br i1 %cond.and, label %do_something, label %loop_exit - br i1 %cond.and, label %do_something, label %loop_exit + br i1 %cond.and, label %do_something, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: %[[VAR3:.*]] = phi i32 +; CHECK-NEXT: %[[COND2:.*]] = icmp eq i32 %[[VAR3]], 10 +; CHECK-NEXT: br i1 %[[COND2]], label %do_something, label %loop_exit do_something: %var2 = add i32 %var3, 1 Index: llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll =================================================================== --- llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll +++ llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll @@ -443,3 +443,94 @@ ; CHECK: cleanup: ; CHECK-NEXT: ret void } + +; Verify that a loop-invariant input (%cond2) of a loop-variant `and` condition +; is unswitched when the loop exit is on the branch's false edge. +define i32 @test_partial_condition_unswitch_and(i32* %var, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @test_partial_condition_unswitch_and( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %cond1, label %entry.split, label %loop_exit.split +; +; CHECK: entry.split: +; CHECK-NEXT: br i1 %cond2, label %entry.split.split, label %loop_exit +; +; CHECK: entry.split.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + br i1 %cond1, label %continue, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: br label %continue + +continue: + %var_val = load i32, i32* %var + %var_cond = trunc i32 %var_val to i1 + %cond_and = and i1 %var_cond, %cond2 + br i1 %cond_and, label %do_something, label %loop_exit +; CHECK: continue: +; CHECK-NEXT: %[[VAL:.*]] = load i32 +; CHECK-NEXT: %[[VAL_COND:.*]] = trunc i32 %[[VAL]] to i1 +; CHECK-NEXT: br i1 %[[VAL_COND]], label %do_something, label %loop_exit + +do_something: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: do_something: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + ret i32 0 +; CHECK: loop_exit: +; CHECK-NEXT: br label %loop_exit.split +; +; CHECK: loop_exit.split: +; CHECK-NEXT: ret +} + +; Verify that all loop-invariant inputs (%cond1 through %cond4) reachable +; through a graph of `or` instructions are unswitched together when the loop +; exit is on the true edge, while the `xor` and `and` that use %cond5 and %cond6 +; stay in the loop. +define i32 
@test_partial_condition_unswitch_or(i32* %var, i1 %cond1, i1 %cond2, i1 %cond3, i1 %cond4, i1 %cond5, i1 %cond6) { +; CHECK-LABEL: @test_partial_condition_unswitch_or( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[INV_OR1:.*]] = or i1 %cond4, %cond2 +; CHECK-NEXT: %[[INV_OR2:.*]] = or i1 %[[INV_OR1]], %cond3 +; CHECK-NEXT: %[[INV_OR3:.*]] = or i1 %[[INV_OR2]], %cond1 +; CHECK-NEXT: br i1 %[[INV_OR3]], label %loop_exit.split, label %entry.split +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + %var_val = load i32, i32* %var + %var_cond = trunc i32 %var_val to i1 + %cond_or1 = or i1 %var_cond, %cond1 + %cond_or2 = or i1 %cond2, %cond3 + %cond_or3 = or i1 %cond_or1, %cond_or2 + %cond_xor1 = xor i1 %cond5, %var_cond + %cond_and1 = and i1 %cond6, %var_cond + %cond_or4 = or i1 %cond_xor1, %cond_and1 + %cond_or5 = or i1 %cond_or3, %cond_or4 + %cond_or6 = or i1 %cond_or5, %cond4 + br i1 %cond_or6, label %loop_exit, label %do_something +; CHECK: loop_begin: +; CHECK-NEXT: %[[VAR:.*]] = load i32 +; CHECK-NEXT: %[[VAR_COND:.*]] = trunc i32 %[[VAR]] to i1 +; CHECK-NEXT: %[[COND_XOR:.*]] = xor i1 %cond5, %[[VAR_COND]] +; CHECK-NEXT: %[[COND_AND:.*]] = and i1 %cond6, %[[VAR_COND]] +; CHECK-NEXT: %[[COND_OR1:.*]] = or i1 %[[COND_XOR]], %[[COND_AND]] +; CHECK-NEXT: %[[COND_OR2:.*]] = or i1 %[[VAR_COND]], %[[COND_OR1]] +; CHECK-NEXT: br i1 %[[COND_OR2]], label %loop_exit, label %do_something + +do_something: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: do_something: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + ret i32 0 +; CHECK: loop_exit.split: +; CHECK-NEXT: ret +} \ No newline at end of file