Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -103,6 +103,11 @@
     cl::init(false), cl::Hidden,
     cl::desc("If enabled, drop make.implicit metadata in unswitched implicit "
              "null checks to save time analyzing if we can keep it."));
+static cl::opt<unsigned>
+    MSSAThreshold("simple-loop-unswitch-memoryssa-threshold",
+                  cl::desc("Max number of memory uses to explore during "
+                           "partial unswitching analysis"),
+                  cl::init(100), cl::Hidden);
 
 /// Collect all of the loop invariant input values transitively used by the
 /// homogeneous instruction graph from a given root.
@@ -202,6 +207,49 @@
                    Direction ? &NormalSucc : &UnswitchedSucc);
 }
 
+/// Clone the given instructions into BB and branch on the clone of the first.
+static void buildPartialInvariantUnswitchConditionalBranch(
+    BasicBlock &BB, ArrayRef<Value *> Invariants, bool Direction,
+    BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, Loop &L,
+    MemorySSAUpdater *MSSAU) {
+  ValueToValueMapTy VMap;
+  for (auto *Val : reverse(Invariants)) {
+    Instruction *Inst = cast<Instruction>(Val);
+    Instruction *NewInst = Inst->clone();
+    BB.getInstList().insert(BB.end(), NewInst);
+    RemapInstruction(NewInst, VMap,
+                     RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+    VMap[Val] = NewInst;
+
+    if (!MSSAU)
+      continue;
+
+    MemorySSA *MSSA = MSSAU->getMemorySSA();
+    if (auto *MemUse =
+            dyn_cast_or_null<MemoryUse>(MSSA->getMemoryAccess(Inst))) {
+      auto *DefiningAccess = MemUse->getDefiningAccess();
+      // Get the first defining access before the loop.
+      while (L.contains(DefiningAccess->getBlock())) {
+        // If the defining access is a MemoryPhi, get the incoming
+        // value for the pre-header as defining access.
+        if (auto *MemPhi = dyn_cast<MemoryPhi>(DefiningAccess))
+          DefiningAccess =
+              MemPhi->getIncomingValueForBlock(L.getLoopPreheader());
+        else
+          DefiningAccess = cast<MemoryDef>(DefiningAccess)->getDefiningAccess();
+      }
+      MSSAU->createMemoryAccessInBB(NewInst, DefiningAccess,
+                                    NewInst->getParent(),
+                                    MemorySSA::BeforeTerminator);
+    }
+  }
+
+  IRBuilder<> IRB(&BB);
+  Value *Cond = VMap[Invariants[0]];
+  IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc,
+                   Direction ? &NormalSucc : &UnswitchedSucc);
+}
+
 /// Rewrite the PHI nodes in an unswitched loop exit basic block.
 ///
 /// Requires that the loop exit and unswitched basic block are the same, and
@@ -1964,18 +2012,23 @@
 
 static void unswitchNontrivialInvariants(
     Loop &L, Instruction &TI, ArrayRef<Value *> Invariants,
-    SmallVectorImpl<BasicBlock *> &ExitBlocks, DominatorTree &DT, LoopInfo &LI,
-    AssumptionCache &AC, function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+    SmallVectorImpl<BasicBlock *> &ExitBlocks,
+    const IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI,
+    AssumptionCache &AC,
+    function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
     ScalarEvolution *SE, MemorySSAUpdater *MSSAU) {
   auto *ParentBB = TI.getParent();
   BranchInst *BI = dyn_cast<BranchInst>(&TI);
   SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
 
   // We can only unswitch switches, conditional branches with an invariant
-  // condition, or combining invariant conditions with an instruction.
+  // condition, or combining invariant conditions with an instruction or
+  // partially invariant instructions.
   assert((SI || (BI && BI->isConditional())) &&
          "Can only unswitch switches and conditional branch!");
-  bool FullUnswitch = SI || BI->getCondition() == Invariants[0];
+  bool PartiallyInvariant = !PartialIVInfo.InstToDuplicate.empty();
+  bool FullUnswitch =
+      SI || (BI->getCondition() == Invariants[0] && !PartiallyInvariant);
   if (FullUnswitch)
     assert(Invariants.size() == 1 &&
            "Cannot have other invariants with full unswitching!");
@@ -1989,18 +2042,23 @@
   // Constant and BBs tracking the cloned and continuing successor. When we are
   // unswitching the entire condition, this can just be trivially chosen to
   // unswitch towards `true`. However, when we are unswitching a set of
-  // invariants combined with `and` or `or`, the combining operation determines
-  // the best direction to unswitch: we want to unswitch the direction that will
-  // collapse the branch.
+  // invariants combined with `and` or `or` or partially invariant instructions,
+  // the combining operation determines the best direction to unswitch: we want
+  // to unswitch the direction that will collapse the branch.
   bool Direction = true;
   int ClonedSucc = 0;
   if (!FullUnswitch) {
     if (!match(BI->getCondition(), m_LogicalOr())) {
-      assert(match(BI->getCondition(), m_LogicalAnd()) &&
-             "Only `or`, `and`, an `select` instructions can combine "
-             "invariants being unswitched.");
-      Direction = false;
-      ClonedSucc = 1;
+      assert(
+          (match(BI->getCondition(), m_LogicalAnd()) || PartiallyInvariant) &&
+          "Only `or`, `and`, an `select` instructions can combine invariants "
+          "being unswitched. Partially invariant instructions can also be "
+          "unswitched.");
+      if (match(BI->getCondition(), m_LogicalAnd()) ||
+          (PartiallyInvariant && !PartialIVInfo.KnownValue->isOneValue())) {
+        Direction = false;
+        ClonedSucc = 1;
+      }
     }
   }
 
@@ -2088,10 +2146,18 @@
   VMaps.reserve(UnswitchedSuccBBs.size());
   SmallDenseMap<BasicBlock *, BasicBlock *, 4> ClonedPHs;
   for (auto *SuccBB : UnswitchedSuccBBs) {
-    VMaps.emplace_back(new ValueToValueMapTy());
-    ClonedPHs[SuccBB] = buildClonedLoopBlocks(
-        L, LoopPH, SplitBB, ExitBlocks, ParentBB, SuccBB, RetainedSuccBB,
-        DominatingSucc, *VMaps.back(), DTUpdates, AC, DT, LI, MSSAU);
+    // In the partially invariant case, an unswitched successor that is a loop
+    // exit block is not cloned; it is mapped to itself in ClonedPHs.
+    if (PartiallyInvariant && llvm::is_contained(ExitBlocks, SuccBB)) {
+      // An exit successor is used directly as the unswitched destination:
+      // no loop blocks are cloned and no VMap is created for it.
+      ClonedPHs[SuccBB] = SuccBB;
+    } else {
+      VMaps.emplace_back(new ValueToValueMapTy());
+      ClonedPHs[SuccBB] = buildClonedLoopBlocks(
+          L, LoopPH, SplitBB, ExitBlocks, ParentBB, SuccBB, RetainedSuccBB,
+          DominatingSucc, *VMaps.back(), DTUpdates, AC, DT, LI, MSSAU);
+    }
   }
 
   // Drop metadata if we may break its semantics by moving this instr into the
@@ -2218,8 +2284,12 @@
     BasicBlock *ClonedPH = ClonedPHs.begin()->second;
     // When doing a partial unswitch, we have to do a bit more work to build up
     // the branch in the split block.
-    buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction,
-                                          *ClonedPH, *LoopPH);
+    if (PartiallyInvariant)
+      buildPartialInvariantUnswitchConditionalBranch(
+          *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, L, MSSAU);
+    else
+      buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction,
+                                            *ClonedPH, *LoopPH);
     DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH});
 
     if (MSSAU) {
@@ -2289,7 +2359,8 @@
     // for each invariant operand.
     // So it happens that for multiple-partial case we dont replace
     // in the unswitched branch.
-    bool ReplaceUnswitched = FullUnswitch || (Invariants.size() == 1);
+    bool ReplaceUnswitched =
+        FullUnswitch || (Invariants.size() == 1) || PartiallyInvariant;
 
     ConstantInt *UnswitchedReplacement =
         Direction ? ConstantInt::getTrue(BI->getContext())
@@ -2301,7 +2372,7 @@
       // Use make_early_inc_range here as set invalidates the iterator.
       for (Use &U : llvm::make_early_inc_range(Invariant->uses())) {
         Instruction *UserI = dyn_cast<Instruction>(U.getUser());
-        if (!UserI)
+        if (!UserI || PartiallyInvariant)
           continue;
 
         // Replace it with the 'continue' side if in the main loop body, and the
@@ -2384,7 +2455,7 @@
   for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
     if (UpdatedL->getParentLoop() == ParentL)
       SibLoops.push_back(UpdatedL);
-  UnswitchCB(IsStillLoop, SibLoops);
+  UnswitchCB(IsStillLoop, PartiallyInvariant, SibLoops);
 
   if (MSSAU && VerifyMemorySSA)
     MSSAU->getMemorySSA()->verifyMemorySSA();
@@ -2599,11 +2670,11 @@
   return CostMultiplier;
 }
 
-static bool
-unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
-                      AssumptionCache &AC, TargetTransformInfo &TTI,
-                      function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
-                      ScalarEvolution *SE, MemorySSAUpdater *MSSAU) {
+static bool unswitchBestCondition(
+    Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
+    AAResults &AA, TargetTransformInfo &TTI,
+    function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB,
+    ScalarEvolution *SE, MemorySSAUpdater *MSSAU) {
   // Collect all invariant conditions within this loop (as opposed to an inner
   // loop which would be handled when visiting that inner loop).
SmallVector>, 4> @@ -2618,6 +2689,7 @@ CollectGuards = true; } + struct IVConditionInfo PartialIVInfo; for (auto *BB : L.blocks()) { if (LI.getLoopFor(BB) != &L) continue; @@ -2651,15 +2723,33 @@ } Instruction &CondI = *cast(BI->getCondition()); - if (!match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) - continue; + if (match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) { + TinyPtrVector Invariants = + collectHomogenousInstGraphLoopInvariants(L, CondI, LI); + if (Invariants.empty()) + continue; - TinyPtrVector Invariants = - collectHomogenousInstGraphLoopInvariants(L, CondI, LI); - if (Invariants.empty()) + UnswitchCandidates.push_back({BI, std::move(Invariants)}); continue; + } + } - UnswitchCandidates.push_back({BI, std::move(Invariants)}); + if (MSSAU && + !llvm::any_of(UnswitchCandidates, [&](auto &TerminatorAndInvariants) { + return TerminatorAndInvariants.first == L.getHeader()->getTerminator(); + })) { + MemorySSA *MSSA = MSSAU->getMemorySSA(); + if (auto Info = llvm::hasPartialIVCondition(L, MSSAThreshold, *MSSA, AA)) { + LLVM_DEBUG( + dbgs() << "simple-loop-unswitch: Found partially invariant condition " + << *Info->InstToDuplicate[0] << "\n"); + PartialIVInfo = *Info; + TinyPtrVector ValsToDuplicate; + for (auto *Inst : Info->InstToDuplicate) + ValsToDuplicate.push_back(Inst); + UnswitchCandidates.push_back( + {L.getHeader()->getTerminator(), std::move(ValsToDuplicate)}); + } } // If we didn't find any candidates, we're done. @@ -2765,20 +2855,25 @@ continue; // If this is a partial unswitch candidate, then it must be a conditional - // branch with a condition of either `or`, `and`, or their corresponding - // select forms. In that case, one of the successors is necessarily - // duplicated, so don't even try to remove its cost. + // branch with a condition of either `or`, `and`, their corresponding + // select forms or partially invariant instructions. 
In that case, one of + // the successors is necessarily duplicated, so don't even try to remove + // its cost. if (!FullUnswitch) { auto &BI = cast(TI); if (match(BI.getCondition(), m_LogicalAnd())) { if (SuccBB == BI.getSuccessor(1)) continue; - } else { - assert(match(BI.getCondition(), m_LogicalOr()) && - "Only `and` and `or` conditions can result in a partial " - "unswitch!"); + } else if (match(BI.getCondition(), m_LogicalOr())) { if (SuccBB == BI.getSuccessor(0)) continue; + } else if (!PartialIVInfo.InstToDuplicate.empty()) { + if (PartialIVInfo.KnownValue->isOneValue() && + SuccBB == BI.getSuccessor(1)) + continue; + else if (!PartialIVInfo.KnownValue->isOneValue() && + SuccBB == BI.getSuccessor(0)) + continue; } } @@ -2852,11 +2947,11 @@ BestUnswitchTI = turnGuardIntoBranch(cast(BestUnswitchTI), L, ExitBlocks, DT, LI, MSSAU); - LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " - << BestUnswitchCost << ") terminator: " << *BestUnswitchTI - << "\n"); + LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << BestUnswitchCost + << ") terminator: " << *BestUnswitchTI << "\n"); unswitchNontrivialInvariants(L, *BestUnswitchTI, BestUnswitchInvariants, - ExitBlocks, DT, LI, AC, UnswitchCB, SE, MSSAU); + ExitBlocks, PartialIVInfo, DT, LI, AC, + UnswitchCB, SE, MSSAU); return true; } @@ -2867,9 +2962,9 @@ /// looks at other loop invariant control flows and tries to unswitch those as /// well by cloning the loop if the result is small enough. /// -/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also -/// updated based on the unswitch. -/// The `MSSA` analysis is also updated if valid (i.e. its use is enabled). +/// The `DT`, `LI`, `AC`, `AA`, `TTI` parameters are required analyses that are +/// also updated based on the unswitch. The `MSSA` analysis is also updated if +/// valid (i.e. its use is enabled). 
/// /// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is /// true, we will attempt to do non-trivial unswitching as well as trivial @@ -2881,11 +2976,11 @@ /// /// If `SE` is non-null, we will update that analysis based on the unswitching /// done. -static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, TargetTransformInfo &TTI, - bool NonTrivial, - function_ref)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { +static bool +unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + AAResults &AA, TargetTransformInfo &TTI, bool NonTrivial, + function_ref)> UnswitchCB, + ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); @@ -2897,7 +2992,7 @@ if (unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. - UnswitchCB(/*CurrentLoopValid*/ true, {}); + UnswitchCB(/*CurrentLoopValid*/ true, false, {}); return true; } @@ -2933,7 +3028,7 @@ // Try to unswitch the best invariant condition. We prefer this full unswitch to // a partial unswitch when possible below the threshold. - if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE, MSSAU)) + if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU)) return true; // No other opportunities to unswitch. @@ -2954,6 +3049,7 @@ std::string LoopName = std::string(L.getName()); auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, + bool PartiallyInvariant, ArrayRef NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. if (!NewLoops.empty()) @@ -2961,9 +3057,10 @@ // If the current loop remains valid, we should revisit it to catch any // other unswitch opportunities. Otherwise, we need to mark it as deleted. 
- if (CurrentLoopValid) - U.revisitCurrentLoop(); - else + if (CurrentLoopValid) { + if (!PartiallyInvariant) + U.revisitCurrentLoop(); + } else U.markLoopAsDeleted(L, LoopName); }; @@ -2973,8 +3070,9 @@ if (VerifyMemorySSA) AR.MSSA->verifyMemorySSA(); } - if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB, - &AR.SE, MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, NonTrivial, + UnswitchCB, &AR.SE, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) return PreservedAnalyses::all(); if (AR.MSSA && VerifyMemorySSA) @@ -3031,6 +3129,7 @@ auto &DT = getAnalysis().getDomTree(); auto &LI = getAnalysis().getLoopInfo(); auto &AC = getAnalysis().getAssumptionCache(F); + auto &AA = getAnalysis().getAAResults(); auto &TTI = getAnalysis().getTTI(F); MemorySSA *MSSA = nullptr; Optional MSSAU; @@ -3042,7 +3141,7 @@ auto *SEWP = getAnalysisIfAvailable(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; - auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, + auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, bool PartiallyInvariant, ArrayRef NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. for (auto *NewL : NewLoops) @@ -3051,17 +3150,19 @@ // If the current loop remains valid, re-add it to the queue. This is // a little wasteful as we'll finish processing the current loop as well, // but it is the best we can do in the old PM. - if (CurrentLoopValid) - LPM.addLoop(*L); - else + if (CurrentLoopValid) { + if (!PartiallyInvariant) + LPM.addLoop(*L); + } else LPM.markLoopAsDeleted(*L); }; if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA(); - bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); + bool Changed = + unswitchLoop(*L, DT, LI, AC, AA, TTI, NonTrivial, UnswitchCB, SE, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA();