Index: llvm/include/llvm/Transforms/Utils/LoopUtils.h =================================================================== --- llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -300,11 +300,12 @@ /// to assess the legality of duplicating atomic loads. Generally, this is /// true when moving out of loop and not true when moving into loops. /// If \p ORE is set use it to emit optimization remarks. -bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, - Loop *CurLoop, AliasSetTracker *CurAST, - MemorySSAUpdater *MSSAU, bool TargetExecutesOncePerLoop, - SinkAndHoistLICMFlags *LICMFlags = nullptr, - OptimizationRemarkEmitter *ORE = nullptr); +bool canSinkOrHoistInst( + Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop, + AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, + bool TargetExecutesOncePerLoop, SinkAndHoistLICMFlags *LICMFlags = nullptr, + OptimizationRemarkEmitter *ORE = nullptr, + const Optional > &LoopMemWrites = None); /// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind. Value *createMinMaxOp(IRBuilderBase &Builder, Index: llvm/lib/Transforms/Scalar/LICM.cpp =================================================================== --- llvm/lib/Transforms/Scalar/LICM.cpp +++ llvm/lib/Transforms/Scalar/LICM.cpp @@ -140,10 +140,12 @@ TargetTransformInfo *TTI, bool &FreeInLoop); static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, + Optional > LoopMemWrites, MemorySSAUpdater *MSSAU, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE); static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, + Optional > LoopMemWrites, MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE); static bool isSafeToExecuteUnconditionally(Instruction &Inst, const DominatorTree *DT, @@ -151,9 +153,10 @@ const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, const Instruction *CtxI = nullptr); -static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, - AliasSetTracker *CurAST, Loop *CurLoop, - AliasAnalysis *AA); +static Optional > collectMemWrites(Loop *CurLoop); +static bool pointerInvalidatedByLoop( + MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop, + AliasAnalysis *AA, const Optional > &LoopMemWrites); static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, Loop *CurLoop, SinkAndHoistLICMFlags &Flags); @@ -162,7 +165,9 @@ const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU); static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, - AliasSetTracker *AST, MemorySSAUpdater *MSSAU); + AliasSetTracker *AST, + Optional > LoopMemWrites, + MemorySSAUpdater *MSSAU); static void moveInstructionBefore(Instruction &I, Instruction &Dest, ICFLoopSafetyInfo &SafetyInfo, @@ -468,6 +473,9 @@ // order. SmallVector Worklist = collectChildrenInLoop(N, CurLoop); + // Collect memory-writing instructions. This is needed for AST mode only. + auto LoopMemWrites = CurAST ? collectMemWrites(CurLoop) : None; + bool Changed = false; for (DomTreeNode *DTN : reverse(Worklist)) { BasicBlock *BB = DTN->getBlock(); @@ -486,7 +494,7 @@ salvageKnowledge(&I); salvageDebugInfo(I); ++II; - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(I, *SafetyInfo, CurAST, LoopMemWrites, MSSAU); Changed = true; continue; } @@ -501,11 +509,11 @@ isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, ORE)) { - if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) { + if (sink(I, LI, DT, CurLoop, SafetyInfo, LoopMemWrites, MSSAU, ORE)) { if (!FreeInLoop) { ++II; salvageDebugInfoOrMarkUndef(I); - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(I, *SafetyInfo, CurAST, LoopMemWrites, MSSAU); } Changed = true; } @@ -770,6 +778,9 @@ // re-hoisted if they end up not dominating all of their uses. SmallVector HoistedInstructions; + // Collect memory-writing instructions. This is needed for AST mode only. + auto LoopMemWrites = CurAST ? collectMemWrites(CurLoop) : None; + // For PHI hoisting to work we need to hoist blocks before their successors. // We can do this by iterating through the blocks in the loop in reverse // post-order. @@ -795,8 +806,9 @@ CurAST->copyValue(&I, C); // FIXME MSSA: Such replacements may make accesses unoptimized (D51960). I.replaceAllUsesWith(C); - if (isInstructionTriviallyDead(&I, TLI)) - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); + if (isInstructionTriviallyDead(&I, TLI)) { + eraseInstruction(I, *SafetyInfo, CurAST, LoopMemWrites, MSSAU); + } Changed = true; continue; } @@ -809,12 +821,12 @@ // to that block. if (CurLoop->hasLoopInvariantOperands(&I) && canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, - ORE) && + ORE, LoopMemWrites) && isSafeToExecuteUnconditionally( I, DT, CurLoop, SafetyInfo, ORE, CurLoop->getLoopPreheader()->getTerminator())) { hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, - MSSAU, SE, ORE); + LoopMemWrites, MSSAU, SE, ORE); HoistedInstructions.push_back(&I); Changed = true; continue; @@ -837,10 +849,9 @@ SafetyInfo->insertInstructionTo(Product, I.getParent()); Product->insertAfter(&I); I.replaceAllUsesWith(Product); - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); - + eraseInstruction(I, *SafetyInfo, CurAST, LoopMemWrites, MSSAU); hoist(*ReciprocalDivisor, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), - SafetyInfo, MSSAU, SE, ORE); + SafetyInfo, LoopMemWrites, MSSAU, SE, ORE); HoistedInstructions.push_back(ReciprocalDivisor); Changed = true; continue; @@ -859,7 +870,7 @@ CurLoop->hasLoopInvariantOperands(&I) && MustExecuteWithoutWritesBefore(I)) { hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, - MSSAU, SE, ORE); + LoopMemWrites, MSSAU, SE, ORE); HoistedInstructions.push_back(&I); Changed = true; continue; @@ -873,7 +884,7 @@ PN->setIncomingBlock( i, CFH.getOrCreateHoistedBlock(PN->getIncomingBlock(i))); hoist(*PN, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, - MSSAU, SE, ORE); + LoopMemWrites, MSSAU, SE, ORE); assert(DT->dominates(PN, BB) && "Conditional PHIs not expected"); Changed = true; continue; @@ -1035,12 +1046,12 @@ } } -bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, - Loop *CurLoop, AliasSetTracker *CurAST, - MemorySSAUpdater *MSSAU, - bool TargetExecutesOncePerLoop, - SinkAndHoistLICMFlags *Flags, - OptimizationRemarkEmitter *ORE) { +bool llvm::canSinkOrHoistInst( + Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop, + AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, + bool TargetExecutesOncePerLoop, SinkAndHoistLICMFlags *Flags, + OptimizationRemarkEmitter *ORE, + const Optional > &LoopMemWrites) { // If we don't understand the instruction, bail early. if (!isHoistableAndSinkableInst(I)) return false; @@ -1071,7 +1082,7 @@ bool Invalidated; if (CurAST) Invalidated = pointerInvalidatedByLoop(MemoryLocation::get(LI), CurAST, - CurLoop, AA); + CurLoop, AA, LoopMemWrites); else Invalidated = pointerInvalidatedByLoopWithMSSA( MSSA, cast(MSSA->getMemoryAccess(LI)), CurLoop, *Flags); @@ -1120,7 +1131,7 @@ if (CurAST) Invalidated = pointerInvalidatedByLoop( MemoryLocation(Op, LocationSize::unknown(), AAMDNodes()), - CurAST, CurLoop, AA); + CurAST, CurLoop, AA, LoopMemWrites); else Invalidated = pointerInvalidatedByLoopWithMSSA( MSSA, cast(MSSA->getMemoryAccess(CI)), CurLoop, @@ -1392,9 +1403,14 @@ } static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, - AliasSetTracker *AST, MemorySSAUpdater *MSSAU) { - if (AST) + AliasSetTracker *AST, + Optional > LoopMemWrites, + MemorySSAUpdater *MSSAU) { + if (AST) { + if (LoopMemWrites && I.mayWriteToMemory()) + LoopMemWrites->erase(&I); AST->deleteValue(&I); + } if (MSSAU) MSSAU->removeMemoryAccess(&I); SafetyInfo.removeInstruction(&I); @@ -1528,6 +1544,7 @@ /// static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, + Optional > LoopMemWrites, MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE) { LLVM_DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); ORE->emit([&]() { @@ -1541,6 +1558,10 @@ ++NumMovedCalls; ++NumSunk; + if (LoopMemWrites && I.mayWriteToMemory()) { + LoopMemWrites->erase(&I); + } + // Iterate over users to be ready for actual sinking. Replace users via // unreachable blocks with undef and make all user PHIs trivially replaceable. SmallPtrSet VisitedUsers; @@ -1618,9 +1639,10 @@ Instruction *New = sinkThroughTriviallyReplaceablePHI( PN, &I, LI, SunkCopies, SafetyInfo, CurLoop, MSSAU); PN->replaceAllUsesWith(New); - eraseInstruction(*PN, *SafetyInfo, nullptr, nullptr); + eraseInstruction(*PN, *SafetyInfo, nullptr, None, nullptr); Changed = true; } + return Changed; } @@ -1629,6 +1651,7 @@ /// static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, + Optional > LoopMemWrites, MemorySSAUpdater *MSSAU, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE) { LLVM_DEBUG(dbgs() << "LICM hoisting to " << Dest->getName() << ": " << I @@ -1638,6 +1661,10 @@ << ore::NV("Inst", &I); }); + if (LoopMemWrites && I.mayWriteToMemory()) { + LoopMemWrites->erase(&I); + } + // Metadata can be dependent on conditions we are hoisting above. // Conservatively strip all metadata on the instruction unless we were // guaranteed to execute I if we entered the loop, in which case the metadata @@ -2109,7 +2136,7 @@ MSSAU->getMemorySSA()->verifyMemorySSA(); // If the SSAUpdater didn't use the load in the preheader, just zap it now. if (PreheaderLoad->use_empty()) - eraseInstruction(*PreheaderLoad, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(*PreheaderLoad, *SafetyInfo, CurAST, None, MSSAU); return true; } @@ -2143,9 +2170,25 @@ return CurAST; } -static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, - AliasSetTracker *CurAST, Loop *CurLoop, - AliasAnalysis *AA) { +static Optional > collectMemWrites(Loop *CurLoop) { + DenseSet Dest; + int N = 0; + for (BasicBlock *BB : CurLoop->blocks()) + for (Instruction &I : *BB) { + if (N++ >= LICMN2Theshold) { + LLVM_DEBUG(dbgs() << "Alasing N2 threshold exhausted\n"); + return None; + } + if (I.mayWriteToMemory()) + Dest.insert(&I); + } + return Dest; +} + +static bool pointerInvalidatedByLoop( + MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop, + AliasAnalysis *AA, + const Optional > &LoopMemWrites) { // First check to see if any of the basic blocks in CurLoop invalidate *V. bool isInvalidatedAccordingToAST = CurAST->getAliasSetFor(MemLoc).isMod(); @@ -2170,22 +2213,35 @@ if (CurLoop->begin() != CurLoop->end()) return true; - int N = 0; - for (BasicBlock *BB : CurLoop->getBlocks()) - for (Instruction &I : *BB) { - if (N >= LICMN2Theshold) { - LLVM_DEBUG(dbgs() << "Alasing N2 threshold exhausted for " - << *(MemLoc.Ptr) << "\n"); - return true; - } - N++; - auto Res = AA->getModRefInfo(&I, MemLoc); - if (isModSet(Res)) { - LLVM_DEBUG(dbgs() << "Aliasing failed on " << I << " for " - << *(MemLoc.Ptr) << "\n"); - return true; - } + Optional > WritesToProcess = LoopMemWrites; + + // If nothing was passed, compute it anew. + if (!WritesToProcess) + WritesToProcess = collectMemWrites(CurLoop); + + // Too many writes to process. + if (!WritesToProcess) + return true; + +#ifndef NDEBUG + auto ControlGroup = collectMemWrites(CurLoop); + assert(ControlGroup && "Failed to collect properly!"); + + assert(ControlGroup->size() == WritesToProcess->size() && "Sizes mismatch!"); + for (Instruction *I : *ControlGroup) + assert(WritesToProcess->count(I) && + "Data about memory-writing instructions is wrong!"); +#endif + + for (Instruction *I : *WritesToProcess) { + auto Res = AA->getModRefInfo(I, MemLoc); + if (isModSet(Res)) { + LLVM_DEBUG(dbgs() << "Aliasing failed on " << *I << " for " + << *(MemLoc.Ptr) << "\n"); + return true; } + } + LLVM_DEBUG(dbgs() << "Aliasing okay for " << *(MemLoc.Ptr) << "\n"); return false; } Index: llvm/test/Transforms/LICM/hoisting.ll =================================================================== --- llvm/test/Transforms/LICM/hoisting.ll +++ llvm/test/Transforms/LICM/hoisting.ll @@ -360,3 +360,22 @@ loopexit: ret i32 %sum } + +define i32 @test_store_load(i32* noalias %a, i32* noalias %b, i32* noalias %c, i32* noalias %d) { +entry: + br label %loop + +loop: + %indvar = phi i32 [ %indvar.next, %loop ], [ 0, %entry ] + %x = load i32, i32* %a + store i32 123, i32* %b + %y = load i32, i32* %c + store i32 456, i32* %d + %sum = add i32 %x, %y + %indvar.next = add i32 %indvar, 1 + %cond = icmp slt i32 %indvar.next, 1000 + br i1 %cond, label %loop, label %exit + +exit: + ret i32 %sum +}