Index: llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
===================================================================
--- llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
+++ llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h
@@ -39,6 +39,7 @@
 class InsertValueInst;
 class Instruction;
 class LoopInfo;
+class MemorySSA;
 class OptimizationRemarkEmitter;
 class PHINode;
 class ScalarEvolution;
@@ -69,6 +70,7 @@
   DominatorTree *DT = nullptr;
   AssumptionCache *AC = nullptr;
   DemandedBits *DB = nullptr;
+  MemorySSA *MSSA = nullptr; // Nullable; currently preserved but not used.
   const DataLayout *DL = nullptr;
 
 public:
@@ -78,7 +80,7 @@
   bool runImpl(Function &F, ScalarEvolution *SE_, TargetTransformInfo *TTI_,
                TargetLibraryInfo *TLI_, AAResults *AA_, LoopInfo *LI_,
                DominatorTree *DT_, AssumptionCache *AC_, DemandedBits *DB_,
-               OptimizationRemarkEmitter *ORE_);
+               MemorySSA *MSSA_, OptimizationRemarkEmitter *ORE_);
 
 private:
   /// Collect store and getelementptr instructions and organize them
Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -41,6 +41,8 @@
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
@@ -176,6 +178,10 @@
     ViewSLPTree("view-slp-tree", cl::Hidden,
                 cl::desc("Display the SLP trees with Graphviz"));
 
+static cl::opt<bool> EnableMSSAInSLPVectorizer(
+    "enable-mssa-in-slp-vectorizer", cl::Hidden, cl::init(false),
+    cl::desc("Enable MemorySSA for SLPVectorizer in new pass manager"));
+
 // Limit the number of alias checks. The limit is chosen so that
 // it has no negative effect on the llvm benchmarks.
 static const unsigned AliasedCheckLimit = 10;
@@ -734,9 +740,9 @@
   BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti,
           TargetLibraryInfo *TLi, AAResults *Aa, LoopInfo *Li,
           DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB,
-          const DataLayout *DL, OptimizationRemarkEmitter *ORE)
+          MemorySSA *MSSA, const DataLayout *DL, OptimizationRemarkEmitter *ORE)
       : F(Func), SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC),
-        DB(DB), DL(DL), ORE(ORE), Builder(Se->getContext()) {
+        DB(DB), MSSA(MSSA), DL(DL), ORE(ORE), Builder(Se->getContext()) {
     CodeMetrics::collectEphemeralValues(F, AC, EphValues);
     // Use the vector register size specified by the target unless overridden
     // by a command-line option.
@@ -2754,6 +2760,7 @@
   DominatorTree *DT;
   AssumptionCache *AC;
   DemandedBits *DB;
+  MemorySSA *MSSA;
   const DataLayout *DL;
   OptimizationRemarkEmitter *ORE;
 
@@ -2866,6 +2873,13 @@
 } // end namespace llvm
 
 BoUpSLP::~BoUpSLP() {
+  // Remove the MemorySSA accesses of all instructions marked for deletion;
+  // instructions without an access are skipped by the updater.
+  if (MSSA) {
+    MemorySSAUpdater MSSAU(MSSA);
+    for (const auto &Pair : DeletedInstructions)
+      MSSAU.removeMemoryAccess(Pair.first);
+  }
   for (const auto &Pair : DeletedInstructions) {
     // Replace operands of ignored instructions with Undefs in case if they were
     // marked for deletion.
@@ -6435,6 +6449,15 @@
       auto *PtrTy = PointerType::get(VecTy, LI->getPointerAddressSpace());
       Value *Ptr = Builder.CreateBitCast(LI->getOperand(0), PtrTy);
       LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlign());
+      if (MSSA) {
+        // Mirror the scalar load's MemorySSA use onto the new vector load.
+        MemorySSAUpdater MSSAU(MSSA);
+        MemoryUseOrDef *Access = MSSA->getMemoryAccess(LI);
+        assert(Access && "Scalar load has no MemorySSA access");
+        MemoryUseOrDef *NewAccess = MSSAU.createMemoryAccessBefore(
+            V, Access->getDefiningAccess(), Access);
+        MSSAU.insertUse(cast<MemoryUse>(NewAccess), /*RenameUses=*/true);
+      }
       Value *NewV = propagateMetadata(V, E->Scalars);
       ShuffleBuilder.addInversedMask(E->ReorderIndices);
       ShuffleBuilder.addMask(E->ReuseShuffleIndices);
@@ -6687,6 +6710,17 @@
               commonAlignment(CommonAlignment, cast<LoadInst>(V)->getAlign());
         NewLI = Builder.CreateMaskedGather(VecTy, VecPtr, CommonAlignment);
       }
+
+      if (MSSA) {
+        // Mirror the scalar load's MemorySSA use onto the new vector load.
+        MemorySSAUpdater MSSAU(MSSA);
+        MemoryUseOrDef *Access = MSSA->getMemoryAccess(LI);
+        assert(Access && "Scalar load has no MemorySSA access");
+        MemoryUseOrDef *NewAccess = MSSAU.createMemoryAccessAfter(
+            NewLI, Access->getDefiningAccess(), Access);
+        MSSAU.insertUse(cast<MemoryUse>(NewAccess), /*RenameUses=*/true);
+      }
+
       Value *V = propagateMetadata(NewLI, E->Scalars);
 
       ShuffleBuilder.addInversedMask(E->ReorderIndices);
@@ -6712,6 +6746,16 @@
       StoreInst *ST =
           Builder.CreateAlignedStore(VecValue, VecPtr, SI->getAlign());
 
+      if (MSSA) {
+        // Mirror the scalar store's MemorySSA def onto the new vector store.
+        MemorySSAUpdater MSSAU(MSSA);
+        MemoryUseOrDef *Access = MSSA->getMemoryAccess(SI);
+        assert(Access && "Scalar store has no MemorySSA access");
+        MemoryUseOrDef *NewAccess = MSSAU.createMemoryAccessAfter(
+            ST, Access->getDefiningAccess(), Access);
+        MSSAU.insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
+      }
+
       // The pointer operand uses an in-tree scalar, so add the new BitCast to
       // ExternalUses to make sure that an extract will be generated in the
       // future.
@@ -7665,6 +7709,15 @@
   BS->initialFillReadyList(ReadyInsts);
 
   Instruction *LastScheduledInst = BS->ScheduleEnd;
+  // First MemorySSA access at or after the end of the scheduling region, if
+  // any. Accesses of rescheduled instructions are moved in front of it.
+  MemoryUseOrDef *MemInsertPt = nullptr;
+  if (MSSA) {
+    for (auto I = LastScheduledInst->getIterator(); I != BS->BB->end(); ++I) {
+      if ((MemInsertPt = MSSA->getMemoryAccess(&*I)))
+        break;
+    }
+  }
 
   // Do the "real" scheduling.
   while (!ReadyInsts.empty()) {
@@ -7677,11 +7730,22 @@
          BundleMember = BundleMember->NextInBundle) {
       Instruction *pickedInst = BundleMember->Inst;
       if (pickedInst->getNextNode() != LastScheduledInst) {
-        BS->BB->getInstList().remove(pickedInst);
-        BS->BB->getInstList().insert(LastScheduledInst->getIterator(),
-                                     pickedInst);
+        pickedInst->moveBefore(LastScheduledInst);
+        if (MSSA) {
+          // Keep the MemorySSA access list in instruction order.
+          MemorySSAUpdater MSSAU(MSSA);
+          if (MemoryUseOrDef *Access = MSSA->getMemoryAccess(pickedInst)) {
+            if (MemInsertPt)
+              MSSAU.moveBefore(Access, MemInsertPt);
+            else
+              MSSAU.moveToPlace(Access, BS->BB, MemorySSA::End);
+          }
+        }
       }
       LastScheduledInst = pickedInst;
+      if (MSSA)
+        if (MemoryUseOrDef *Access = MSSA->getMemoryAccess(LastScheduledInst))
+          MemInsertPt = Access;
     }
 
     BS->schedule(picked, ReadyInsts);
@@ -8011,7 +8075,8 @@
     auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits();
     auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
 
-    return Impl.runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, ORE);
+    return Impl.runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, /*MSSA_=*/nullptr,
+                        ORE);
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -8045,13 +8110,23 @@
   auto *AC = &AM.getResult<AssumptionAnalysis>(F);
   auto *DB = &AM.getResult<DemandedBitsAnalysis>(F);
   auto *ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+  // Only request MemorySSA when it is explicitly enabled via the hidden flag.
+  MemorySSA *MSSA = nullptr;
+  if (EnableMSSAInSLPVectorizer)
+    MSSA = &AM.getResult<MemorySSAAnalysis>(F).getMSSA();
 
-  bool Changed = runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, ORE);
+  bool Changed = runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, MSSA, ORE);
   if (!Changed)
     return PreservedAnalyses::all();
 
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
+  if (MSSA) {
+#ifdef EXPENSIVE_CHECKS
+    MSSA->verifyMemorySSA();
+#endif
+    PA.preserve<MemorySSAAnalysis>();
+  }
   return PA;
 }
 
@@ -8060,6 +8135,7 @@
                                 TargetLibraryInfo *TLI_, AAResults *AA_,
                                 LoopInfo *LI_, DominatorTree *DT_,
                                 AssumptionCache *AC_, DemandedBits *DB_,
+                                MemorySSA *MSSA_,
                                 OptimizationRemarkEmitter *ORE_) {
   if (!RunSLPVectorization)
     return false;
@@ -8090,7 +8166,7 @@
 
   // Use the bottom up slp vectorizer to construct chains that start with
   // store instructions.
-  BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB, DL, ORE_);
+  BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB, MSSA_, DL, ORE_);
 
   // A general note: the vectorizer must use BoUpSLP::eraseInstruction() to
   // delete instructions.