diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -12,6 +12,7 @@ #include "InstCombineInternal.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -65,7 +66,6 @@ const auto [Value, IsOffset] = Elem; for (auto &U : Value->uses()) { auto *I = cast(U.getUser()); - if (auto *LI = dyn_cast(I)) { // Ignore non-volatile loads, they are always ok. if (!LI->isSimple()) return false; @@ -261,29 +261,66 @@ public: PointerReplacer(InstCombinerImpl &IC) : IC(IC) {} - bool collectUsers(Instruction &I); + bool collectUsers(Instruction &I, MemTransferInst *Copy); void replacePointer(Instruction &I, Value *V); private: + bool collectUsersRecursive(Instruction &I, MemTransferInst *Copy); void replace(Instruction *I); Value *getReplacement(Value *I); + SmallPtrSet ValuesToRevisit; SmallSetVector Worklist; MapVector WorkMap; InstCombinerImpl &IC; }; } // end anonymous namespace -bool PointerReplacer::collectUsers(Instruction &I) { +bool PointerReplacer::collectUsers(Instruction &I, MemTransferInst *Copy) { + if (!collectUsersRecursive(I, Copy)) + return false; + + // After the full walk, every PHI deferred by collectUsersRecursive must + // have been added to the Worklist once all of its incoming values were + // collected; bail out if any is still outstanding.
+ for (auto *V : ValuesToRevisit) + if (!Worklist.contains(cast(V))) + return false; + return true; +} + +bool PointerReplacer::collectUsersRecursive(Instruction &I, + MemTransferInst *Copy) { for (auto *U : I.users()) { auto *Inst = cast(&*U); if (auto *Load = dyn_cast(Inst)) { if (Load->isVolatile()) return false; Worklist.insert(Load); - } else if (isa(Inst) || isa(Inst)) { + } else if (auto *PHI = dyn_cast(Inst)) { + // Collect the PHI only once all incoming values are in the Worklist; + // otherwise defer it. It is seen again while walking a later operand's + // users, and collectUsers checks that nothing remains deferred. + bool ValueInserted = false; + for (unsigned Idx = 0; Idx < PHI->getNumIncomingValues(); ++Idx) { + auto *V = PHI->getIncomingValue(Idx); + if (!isa(V)) + return false; + if (!Worklist.contains(cast(V))) { + ValuesToRevisit.insert(PHI); + ValueInserted = true; + break; + } + } + if (ValueInserted) + continue; + + Worklist.insert(PHI); + if (!collectUsersRecursive(*PHI, Copy)) + return false; + } else if (isa(Inst)) { Worklist.insert(Inst); - if (!collectUsers(*Inst)) + if (!collectUsersRecursive(*Inst, Copy)) return false; } else if (auto *MI = dyn_cast(Inst)) { if (MI->isVolatile()) @@ -296,7 +333,6 @@ return false; } } - return true; } @@ -318,6 +354,14 @@ IC.InsertNewInstWith(NewI, *LT); IC.replaceInstUsesWith(*LT, NewI); WorkMap[LT] = NewI; + } else if (auto *PHI = dyn_cast(I)) { + Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType(); + auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(), + PHI->getName(), PHI); + for (unsigned Idx = 0; Idx < PHI->getNumIncomingValues(); ++Idx) + NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(Idx)), + PHI->getIncomingBlock(Idx)); + WorkMap[PHI] = NewPHI; } else if (auto *GEP = dyn_cast(I)) { auto *V = getReplacement(GEP->getPointerOperand()); assert(V && "Operand not replaced"); @@ -452,10 +496,9 @@ } PointerReplacer PtrReplacer(*this); - if (PtrReplacer.collectUsers(AI)) { + if (PtrReplacer.collectUsers(AI, Copy)) { for (Instruction *Delete : ToDelete)
eraseInstFromFunction(*Delete); - Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); PtrReplacer.replacePointer(AI, Cast); ++NumGlobalCopies;