diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp --- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -640,6 +640,26 @@ return false; } +namespace { +/// Struct to hold information about a partially invariant condition. +struct IVConditionInfo { + /// Instructions that need to be duplicated and checked for the unswitching + /// condition. + SmallVector InstToDuplicate; + + /// Constant to indicate for which value the condition is invariant. + Constant *KnownValue = nullptr; + + /// True if the partially invariant path is no-op (=does not have any + /// side-effects and no loop value is used outside the loop). + bool PathIsNoop = true; + + /// If the partially invariant path reaches a single exit block, ExitForPath + /// is set to that block. Otherwise it is nullptr. + BasicBlock *ExitForPath = nullptr; +}; +} // namespace + /// Check if the loop header has a conditional branch that is not /// loop-invariant, because it involves load instructions. If all paths from /// either the true or false successor to the header or loop exists do not @@ -651,9 +671,8 @@ -/// If the branch condition of the header is partially invariant, return a pair -/// containing the instructions to duplicate and a boolean Constant to update -/// the condition in the loops created for the true or false successors. +/// If the branch condition of the header is partially invariant, return an +/// IVConditionInfo with the instructions to duplicate, the value for which +/// the condition is invariant and whether that path is a no-op, else None.
-static std::pair, Constant *> -hasPartialIVCondition(Loop *L, MemorySSA &MSSA, AAResults *AA) { - SmallVector ToDuplicate; +static Optional hasPartialIVCondition(Loop *L, MemorySSA &MSSA, + AAResults *AA) { auto *TI = dyn_cast(L->getHeader()->getTerminator()); if (!TI || !TI->isConditional()) @@ -665,7 +684,8 @@ if (!CondI || !L->contains(CondI)) return {}; - ToDuplicate.push_back(CondI); + SmallVector InstToDuplicate; + InstToDuplicate.push_back(CondI); SmallVector WorkList; WorkList.append(CondI->op_begin(), CondI->op_end()); @@ -686,7 +706,7 @@ if (LI->isVolatile() || LI->isAtomic()) return {}; - ToDuplicate.push_back(I); + InstToDuplicate.push_back(I); if (MemoryAccess *MA = MSSA.getMemoryAccess(I)) { if (auto *MemUse = dyn_cast_or_null(MA)) { // Queue the defining access to check for alias checks. @@ -701,80 +721,126 @@ WorkList.append(I->op_begin(), I->op_end()); } - if (ToDuplicate.size() <= 1) + if (InstToDuplicate.size() <= 1) return {}; + SmallVector ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); auto HasNoClobbersOnPath = - [L, AA, &AccessedLocs](BasicBlock *Succ, BasicBlock *Header, - SmallVector AccessesToCheck) { - // First, collect all blocks in the loop that are on a patch from Succ - // to the header. - SmallVector WorkList; - WorkList.push_back(Succ); - WorkList.push_back(Header); - SmallPtrSet Seen; - Seen.insert(Header); - while (!WorkList.empty()) { - BasicBlock *Current = WorkList.pop_back_val(); - if (!L->contains(Current)) - continue; - const auto &SeenIns = Seen.insert(Current); - if (!SeenIns.second) - continue; + [L, AA, &AccessedLocs, &ExitingBlocks, + &InstToDuplicate](BasicBlock *Succ, BasicBlock *Header, + SmallVector AccessesToCheck) + -> Optional { + IVConditionInfo Info; + // First, collect all blocks in the loop that are on a path from Succ + // to the header.
+ SmallVector WorkList; + WorkList.push_back(Succ); + WorkList.push_back(Header); + SmallPtrSet Seen; + Seen.insert(Header); + Info.PathIsNoop &= + all_of(*Header, [](Instruction &I) { return !I.mayHaveSideEffects(); }); + + while (!WorkList.empty()) { + BasicBlock *Current = WorkList.pop_back_val(); + if (!L->contains(Current)) + continue; + const auto &SeenIns = Seen.insert(Current); + if (!SeenIns.second) + continue; - WorkList.append(succ_begin(Current), succ_end(Current)); - } + Info.PathIsNoop &= all_of( + *Current, [](Instruction &I) { return !I.mayHaveSideEffects(); }); + WorkList.append(succ_begin(Current), succ_end(Current)); + } - // Require at least 2 blocks on a path through the loop. This skips - // paths that directly exit the loop. - if (Seen.size() < 2) - return false; + // Require at least 2 blocks on a path through the loop. This skips + // paths that directly exit the loop. + if (Seen.size() < 2) + return {}; - // Next, check if there are any MemoryDefs that are on the path through - // the loop (in the Seen set) and they may-alias any of the locations in - // AccessedLocs. If that is the case, they may modify the condition and - // partial unswitching is not possible. - SmallPtrSet SeenAccesses; - while (!AccessesToCheck.empty()) { - MemoryAccess *Current = AccessesToCheck.pop_back_val(); - auto SeenI = SeenAccesses.insert(Current); - if (!SeenI.second || !Seen.contains(Current->getBlock())) - continue; + // Next, check if there are any MemoryDefs that are on the path through + // the loop (in the Seen set) and they may-alias any of the locations in + // AccessedLocs. If that is the case, they may modify the condition and + // partial unswitching is not possible. + SmallPtrSet SeenAccesses; + while (!AccessesToCheck.empty()) { + MemoryAccess *Current = AccessesToCheck.pop_back_val(); + auto SeenI = SeenAccesses.insert(Current); + if (!SeenI.second || !Seen.contains(Current->getBlock())) + continue; - // Bail out if exceeded the threshold.
- if (SeenAccesses.size() >= MSSAThreshold) - return false; + // Bail out if the threshold is exceeded. + if (SeenAccesses.size() >= MSSAThreshold) + return {}; - // MemoryUse are read-only accesses. - if (isa(Current)) - continue; + // MemoryUses are read-only accesses. + if (isa(Current)) + continue; - // For a MemoryDef, check if is aliases any of the location feeding - // the original condition. - if (auto *CurrentDef = dyn_cast(Current)) { - if (any_of(AccessedLocs, [AA, CurrentDef](MemoryLocation &Loc) { - return isModSet( - AA->getModRefInfo(CurrentDef->getMemoryInst(), Loc)); - })) - return false; - } + // For a MemoryDef, check if it aliases any of the locations feeding + // the original condition. + if (auto *CurrentDef = dyn_cast(Current)) { + if (any_of(AccessedLocs, [AA, CurrentDef](MemoryLocation &Loc) { + return isModSet( + AA->getModRefInfo(CurrentDef->getMemoryInst(), Loc)); + })) + return {}; + } + + for (Use &U : Current->uses()) + AccessesToCheck.push_back(cast(U.getUser())); + } - for (Use &U : Current->uses()) - AccessesToCheck.push_back(cast(U.getUser())); + // We could also allow loops with known trip counts without mustprogress, + // but ScalarEvolution may not be available. + Info.PathIsNoop &= + L->getHeader()->getParent()->mustProgress() || hasMustProgress(L); + + // If the path is considered a no-op so far, check if it reaches a + // single exit block without any phis. This ensures no values from the + // loop are used outside of the loop.
+ if (Info.PathIsNoop) { + for (auto *Exiting : ExitingBlocks) { + if (!Seen.contains(Exiting)) + continue; + for (auto *Succ : successors(Exiting)) { + if (L->contains(Succ)) + continue; + + Info.PathIsNoop &= empty(Succ->phis()) && + (!Info.ExitForPath || Info.ExitForPath == Succ); + if (!Info.PathIsNoop) + break; + assert((!Info.ExitForPath || Info.ExitForPath == Succ) && + "cannot have multiple exit blocks"); + Info.ExitForPath = Succ; } + } + } + if (!Info.ExitForPath) + Info.PathIsNoop = false; - return true; - }; + Info.InstToDuplicate = InstToDuplicate; + return Info; + }; // If we branch to the same successor, partial unswitching will not be // beneficial. if (TI->getSuccessor(0) == TI->getSuccessor(1)) return {}; - if (HasNoClobbersOnPath(TI->getSuccessor(0), L->getHeader(), AccessesToCheck)) - return {ToDuplicate, ConstantInt::getTrue(TI->getContext())}; - if (HasNoClobbersOnPath(TI->getSuccessor(1), L->getHeader(), AccessesToCheck)) - return {ToDuplicate, ConstantInt::getFalse(TI->getContext())}; + if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(0), L->getHeader(), + AccessesToCheck)) { + Info->KnownValue = ConstantInt::getTrue(TI->getContext()); + return Info; + } + if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(1), L->getHeader(), + AccessesToCheck)) { + Info->KnownValue = ConstantInt::getFalse(TI->getContext()); + return Info; + } return {}; } @@ -986,17 +1052,56 @@ // metadata, to avoid unswitching the same loop multiple times.
if (MSSA && !findOptionMDForLoop(CurrentLoop, "llvm.loop.unswitch.partial.disable")) { - auto ToDuplicate = hasPartialIVCondition(CurrentLoop, *MSSA, AA); - if (!ToDuplicate.first.empty()) { + if (auto Info = hasPartialIVCondition(CurrentLoop, *MSSA, AA)) { + assert(!Info->InstToDuplicate.empty() && + "need at least a partially invariant condition"); LLVM_DEBUG(dbgs() << "loop-unswitch: Found partially invariant condition " - << *ToDuplicate.first[0] << "\n"); - ++NumBranches; - unswitchIfProfitable(ToDuplicate.first[0], ToDuplicate.second, - CurrentLoop->getHeader()->getTerminator(), - ToDuplicate.first); + << *Info->InstToDuplicate[0] << "\n"); + + Instruction *TI = CurrentLoop->getHeader()->getTerminator(); + Value *LoopCond = Info->InstToDuplicate[0]; + + // If the partially unswitched path is a no-op and has a single exit + // block, we do not need to do full unswitching. Instead, we can directly + // branch to the exit. + // TODO: Instead of duplicating the checks, we could also just directly + // branch to the exit from the conditional branch in the loop. + if (Info->PathIsNoop) { + if (HasBranchDivergence && + getAnalysis().isDivergent(LoopCond)) { + LLVM_DEBUG(dbgs() << "NOT unswitching loop %" + << CurrentLoop->getHeader()->getName() + << " at non-trivial condition '" + << *Info->KnownValue << "' == " << *LoopCond << "\n" + << ". Condition is divergent.\n"); + return false; + } - RedoLoop = false; - return true; + ++NumBranches; + + BasicBlock *TrueDest = LoopHeader; + BasicBlock *FalseDest = Info->ExitForPath; + if (Info->KnownValue->isOneValue()) + std::swap(TrueDest, FalseDest); + + auto *OldBr = + cast(CurrentLoop->getLoopPreheader()->getTerminator()); + emitPreheaderBranchOnCondition(LoopCond, Info->KnownValue, TrueDest, + FalseDest, OldBr, TI, + Info->InstToDuplicate); + delete OldBr; + RedoLoop = false; + return true; + } + + // Otherwise, the path is not a no-op. Run regular unswitching.
+ if (unswitchIfProfitable(LoopCond, Info->KnownValue, + CurrentLoop->getHeader()->getTerminator(), + Info->InstToDuplicate)) { + ++NumBranches; + RedoLoop = false; + return true; + } } } diff --git a/llvm/test/Transforms/LoopUnswitch/partial-unswitch-cost.ll b/llvm/test/Transforms/LoopUnswitch/partial-unswitch-cost.ll --- a/llvm/test/Transforms/LoopUnswitch/partial-unswitch-cost.ll +++ b/llvm/test/Transforms/LoopUnswitch/partial-unswitch-cost.ll @@ -48,8 +48,16 @@ define i32 @partial_unswitch_shortcut_mustprogress(i32* %ptr, i32 %N) mustprogress { ; CHECK-LABEL: @partial_unswitch_shortcut_mustprogress ; CHECK-LABEL: entry: +; CHECK-NEXT: [[LV:%[0-9]+]] = load i32, i32* %ptr, align 4 +; CHECK-NEXT: [[C:%[0-9]+]] = icmp eq i32 [[LV]], 100 +; CHECK-NEXT: br i1 [[C]], label %[[CRIT_TO_EXIT:[a-z._]+]], label %[[CRIT_TO_HEADER:[a-z._]+]] +; +; CHECK: [[CRIT_TO_HEADER]]: ; CHECK-NEXT: br label %loop.header ; +; CHECK: [[CRIT_TO_EXIT]]: +; CHECK-NEXT: br label %exit +; entry: br label %loop.header @@ -86,8 +94,16 @@ define i32 @partial_unswitch_shortcut_mustprogress_single_exit_on_path(i32* %ptr, i32 %N) mustprogress { ; CHECK-LABEL: @partial_unswitch_shortcut_mustprogress_single_exit_on_path ; CHECK-LABEL: entry: +; CHECK-NEXT: [[LV:%[0-9]+]] = load i32, i32* %ptr, align 4 +; CHECK-NEXT: [[C:%[0-9]+]] = icmp eq i32 [[LV]], 100 +; CHECK-NEXT: br i1 [[C]], label %[[CRIT_TO_EXIT:.+]], label %[[CRIT_TO_HEADER:[a-z._]+]] +; +; CHECK: [[CRIT_TO_HEADER]]: ; CHECK-NEXT: br label %loop.header ; +; CHECK: [[CRIT_TO_EXIT]]: +; CHECK-NEXT: br label %exit +; entry: br label %loop.header @@ -208,8 +224,16 @@ define i32 @partial_unswitch_shortcut_multiple_exiting_blocks(i32* %ptr, i32 %N, i1 %ec.1) mustprogress { ; CHECK-LABEL: @partial_unswitch_shortcut_multiple_exiting_blocks ; CHECK-LABEL: entry: +; CHECK-NEXT: [[LV:%[0-9]+]] = load i32, i32* %ptr, align 4 +; CHECK-NEXT: [[C:%[0-9]+]] = icmp eq i32 [[LV]], 100 +; CHECK-NEXT: br i1 [[C]], label %[[CRIT_TO_EXIT:.+]], label
%[[CRIT_TO_HEADER:[a-z._]+]] +; +; CHECK: [[CRIT_TO_HEADER]]: ; CHECK-NEXT: br label %loop.header ; +; CHECK: [[CRIT_TO_EXIT]]: +; CHECK-NEXT: br label %exit +; entry: br label %loop.header