Index: llvm/include/llvm/Transforms/Utils/LoopUtils.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -300,11 +300,15 @@
 /// to assess the legality of duplicating atomic loads. Generally, this is
 /// true when moving out of loop and not true when moving into loops.
 /// If \p ORE is set use it to emit optimization remarks.
-bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
-                        Loop *CurLoop, AliasSetTracker *CurAST,
-                        MemorySSAUpdater *MSSAU, bool TargetExecutesOncePerLoop,
-                        SinkAndHoistLICMFlags *LICMFlags = nullptr,
-                        OptimizationRemarkEmitter *ORE = nullptr);
+/// If \p LoopMemWrites is set, it must hold exactly the instructions of
+/// \p CurLoop that may write to memory and is used instead of rescanning the
+/// loop; if it is None, that set is recomputed whenever it is needed.
+bool canSinkOrHoistInst(
+    Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop,
+    AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU,
+    bool TargetExecutesOncePerLoop, SinkAndHoistLICMFlags *LICMFlags = nullptr,
+    OptimizationRemarkEmitter *ORE = nullptr,
+    const Optional<DenseSet<Instruction *>> &LoopMemWrites = None);
 
 /// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
 Value *createMinMaxOp(IRBuilderBase &Builder,
Index: llvm/lib/Transforms/Scalar/LICM.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LICM.cpp
+++ llvm/lib/Transforms/Scalar/LICM.cpp
@@ -151,9 +151,10 @@
                                            const LoopSafetyInfo *SafetyInfo,
                                            OptimizationRemarkEmitter *ORE,
                                            const Instruction *CtxI = nullptr);
-static bool pointerInvalidatedByLoop(MemoryLocation MemLoc,
-                                     AliasSetTracker *CurAST, Loop *CurLoop,
-                                     AliasAnalysis *AA);
+static Optional<DenseSet<Instruction *>> collectMemWrites(Loop *CurLoop);
+static bool pointerInvalidatedByLoop(
+    MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop,
+    AliasAnalysis *AA, const Optional<DenseSet<Instruction *>> &LoopMemWrites);
 static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU,
                                              Loop *CurLoop,
                                              SinkAndHoistLICMFlags &Flags);
@@ -770,6 +771,9 @@
   // re-hoisted if they end up not dominating all of their uses.
   SmallVector<Instruction *, 16> HoistedInstructions;
 
+  // Loop memory writes for canSinkOrHoistInst; kept in sync below.
+  auto LoopMemWrites = collectMemWrites(CurLoop);
+
   // For PHI hoisting to work we need to hoist blocks before their successors.
   // We can do this by iterating through the blocks in the loop in reverse
   // post-order.
@@ -795,8 +799,11 @@
           CurAST->copyValue(&I, C);
         // FIXME MSSA: Such replacements may make accesses unoptimized (D51960).
         I.replaceAllUsesWith(C);
-        if (isInstructionTriviallyDead(&I, TLI))
+        if (isInstructionTriviallyDead(&I, TLI)) {
+          if (LoopMemWrites)
+            LoopMemWrites->erase(&I);
           eraseInstruction(I, *SafetyInfo, CurAST, MSSAU);
+        }
         Changed = true;
         continue;
       }
@@ -809,13 +816,15 @@
       // to that block.
       if (CurLoop->hasLoopInvariantOperands(&I) &&
           canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags,
-                             ORE) &&
+                             ORE, LoopMemWrites) &&
           isSafeToExecuteUnconditionally(
               I, DT, CurLoop, SafetyInfo, ORE,
               CurLoop->getLoopPreheader()->getTerminator())) {
         hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo,
               MSSAU, SE, ORE);
         HoistedInstructions.push_back(&I);
+        if (LoopMemWrites)
+          LoopMemWrites->erase(&I);
         Changed = true;
         continue;
       }
@@ -837,8 +846,11 @@
         SafetyInfo->insertInstructionTo(Product, I.getParent());
         Product->insertAfter(&I);
         I.replaceAllUsesWith(Product);
+        assert((!LoopMemWrites || !LoopMemWrites->count(&I)) &&
+               "Should not write to memory!");
         eraseInstruction(I, *SafetyInfo, CurAST, MSSAU);
-
+        assert((!LoopMemWrites || !LoopMemWrites->count(ReciprocalDivisor)) &&
+               "Should not write to memory!");
         hoist(*ReciprocalDivisor, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB),
               SafetyInfo, MSSAU, SE, ORE);
         HoistedInstructions.push_back(ReciprocalDivisor);
@@ -861,6 +873,8 @@
         hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo,
               MSSAU, SE, ORE);
         HoistedInstructions.push_back(&I);
+        if (LoopMemWrites)
+          LoopMemWrites->erase(&I);
         Changed = true;
         continue;
       }
@@ -872,6 +886,8 @@
           for (unsigned int i = 0; i < PN->getNumIncomingValues(); ++i)
             PN->setIncomingBlock(
                 i, CFH.getOrCreateHoistedBlock(PN->getIncomingBlock(i)));
+          assert((!LoopMemWrites || !LoopMemWrites->count(PN)) &&
+                 "Writing Phi?");
           hoist(*PN, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo,
                 MSSAU, SE, ORE);
           assert(DT->dominates(PN, BB) && "Conditional PHIs not expected");
@@ -1035,12 +1051,12 @@
 }
 }
 
-bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
-                              Loop *CurLoop, AliasSetTracker *CurAST,
-                              MemorySSAUpdater *MSSAU,
-                              bool TargetExecutesOncePerLoop,
-                              SinkAndHoistLICMFlags *Flags,
-                              OptimizationRemarkEmitter *ORE) {
+bool llvm::canSinkOrHoistInst(
+    Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop,
+    AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU,
+    bool TargetExecutesOncePerLoop, SinkAndHoistLICMFlags *Flags,
+    OptimizationRemarkEmitter *ORE,
+    const Optional<DenseSet<Instruction *>> &LoopMemWrites) {
   // If we don't understand the instruction, bail early.
   if (!isHoistableAndSinkableInst(I))
     return false;
@@ -1071,7 +1087,7 @@
     bool Invalidated;
     if (CurAST)
       Invalidated = pointerInvalidatedByLoop(MemoryLocation::get(LI), CurAST,
-                                             CurLoop, AA);
+                                             CurLoop, AA, LoopMemWrites);
     else
       Invalidated = pointerInvalidatedByLoopWithMSSA(
           MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop, *Flags);
@@ -1120,7 +1136,7 @@
             if (CurAST)
               Invalidated = pointerInvalidatedByLoop(
                   MemoryLocation(Op, LocationSize::unknown(), AAMDNodes()),
-                  CurAST, CurLoop, AA);
+                  CurAST, CurLoop, AA, LoopMemWrites);
             else
               Invalidated = pointerInvalidatedByLoopWithMSSA(
                   MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(CI)), CurLoop,
@@ -2143,9 +2159,28 @@
   return CurAST;
 }
 
-static bool pointerInvalidatedByLoop(MemoryLocation MemLoc,
-                                     AliasSetTracker *CurAST, Loop *CurLoop,
-                                     AliasAnalysis *AA) {
+/// Collect the instructions of \p CurLoop that may write to memory. Returns
+/// None if the loop has more than LICMN2Theshold instructions, in which case
+/// callers should treat every location as possibly clobbered.
+static Optional<DenseSet<Instruction *>> collectMemWrites(Loop *CurLoop) {
+  DenseSet<Instruction *> Dest;
+  int N = 0;
+  for (BasicBlock *BB : CurLoop->blocks())
+    for (Instruction &I : *BB) {
+      if (N++ >= LICMN2Theshold) {
+        LLVM_DEBUG(dbgs() << "Aliasing N2 threshold exhausted\n");
+        return None;
+      }
+      if (I.mayWriteToMemory())
+        Dest.insert(&I);
+    }
+  return Dest;
+}
+
+static bool pointerInvalidatedByLoop(
+    MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop,
+    AliasAnalysis *AA,
+    const Optional<DenseSet<Instruction *>> &LoopMemWrites) {
   // First check to see if any of the basic blocks in CurLoop invalidate *V.
   bool isInvalidatedAccordingToAST = CurAST->getAliasSetFor(MemLoc).isMod();
 
@@ -2170,22 +2205,45 @@
   if (CurLoop->begin() != CurLoop->end())
     return true;
 
-  int N = 0;
-  for (BasicBlock *BB : CurLoop->getBlocks())
-    for (Instruction &I : *BB) {
-      if (N >= LICMN2Theshold) {
-        LLVM_DEBUG(dbgs() << "Alasing N2 threshold exhausted for "
-                          << *(MemLoc.Ptr) << "\n");
-        return true;
-      }
-      N++;
-      auto Res = AA->getModRefInfo(&I, MemLoc);
-      if (isModSet(Res)) {
-        LLVM_DEBUG(dbgs() << "Aliasing failed on " << I << " for "
-                          << *(MemLoc.Ptr) << "\n");
-        return true;
-      }
-    }
+  // Use the caller's set of memory-writing instructions when one is provided,
+  // otherwise collect it here. The set is referenced, not copied: this runs
+  // once per queried location.
+  Optional<DenseSet<Instruction *>> LocalWrites;
+  const DenseSet<Instruction *> *WritesToProcess = nullptr;
+  if (LoopMemWrites) {
+    WritesToProcess = &*LoopMemWrites;
+  } else {
+    LocalWrites = collectMemWrites(CurLoop);
+    // The loop has more instructions than LICMN2Theshold allows us to scan.
+    if (!LocalWrites)
+      return true;
+    WritesToProcess = &*LocalWrites;
+  }
+
+#ifndef NDEBUG
+  // A caller-provided set must match the current contents of the loop.
+  if (LoopMemWrites) {
+    auto ControlGroup = collectMemWrites(CurLoop);
+    assert(ControlGroup && "Failed to collect properly!");
+    assert(ControlGroup->size() == WritesToProcess->size() &&
+           "Sizes mismatch!");
+    for (Instruction *I : *ControlGroup)
+      assert(WritesToProcess->count(I) &&
+             "Data about memory-writing instructions is wrong!");
+  }
+#endif
+
+  // DenseSet iterates in hash order, so which clobber the debug output names
+  // may vary between runs; the returned answer does not.
+  for (Instruction *I : *WritesToProcess) {
+    auto Res = AA->getModRefInfo(I, MemLoc);
+    if (isModSet(Res)) {
+      LLVM_DEBUG(dbgs() << "Aliasing failed on " << *I << " for "
+                        << *(MemLoc.Ptr) << "\n");
+      return true;
+    }
+  }
+
   LLVM_DEBUG(dbgs() << "Aliasing okay for " << *(MemLoc.Ptr) << "\n");
   return false;
 }
Index: llvm/test/Transforms/LICM/hoisting.ll
===================================================================
--- llvm/test/Transforms/LICM/hoisting.ll
+++ llvm/test/Transforms/LICM/hoisting.ll
@@ -360,3 +360,27 @@
 loopexit:
   ret i32 %sum
 }
+
+; Loop with loads from and stores to distinct noalias pointers; accompanies
+; LICM's cached set of memory-writing instructions (collectMemWrites).
+; NOTE(review): that set is only built when -licm-n2-threshold is non-zero;
+; confirm a RUN line sets it, otherwise this only checks that LICM runs.
+; CHECK-LABEL: @test_store_load(
+define i32 @test_store_load(i32* noalias %a, i32* noalias %b, i32* noalias %c, i32* noalias %d) {
+entry:
+  br label %loop
+
+loop:
+  %indvar = phi i32 [ %indvar.next, %loop ], [ 0, %entry ]
+  %x = load i32, i32* %a
+  store i32 123, i32* %b
+  %y = load i32, i32* %c
+  store i32 456, i32* %d
+  %sum = add i32 %x, %y
+  %indvar.next = add i32 %indvar, 1
+  %cond = icmp slt i32 %indvar.next, 1000
+  br i1 %cond, label %loop, label %exit
+
+exit:
+  ret i32 %sum
+}