diff --git a/llvm/include/llvm/Analysis/MemorySSA.h b/llvm/include/llvm/Analysis/MemorySSA.h
--- a/llvm/include/llvm/Analysis/MemorySSA.h
+++ b/llvm/include/llvm/Analysis/MemorySSA.h
@@ -909,7 +909,7 @@

 // Internal MemorySSA utils, for use by MemorySSA classes and walkers
 class MemorySSAUtil {
-protected:
+public:
   friend class GVNHoist;
   friend class MemorySSAWalker;

diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -2855,6 +2855,17 @@
     copy(BBRange, const_cast(block_begin()) + ToIdx);
   }

+  /// Set the incoming block that pairs with the incoming-value operand \p U.
+  void setIncomingBlock(const Use &U, BasicBlock *BB) {
+    assert(this == U.getUser() && "Iterator doesn't point to PHI's Uses?");
+    setIncomingBlock(unsigned(&U - op_begin()), BB);
+  }
+
+  /// Set the incoming block for the use referenced by user iterator \p I.
+  void setIncomingBlock(Value::const_user_iterator I, BasicBlock *BB) {
+    setIncomingBlock(I.getUse(), BB);
+  }
+
   /// Replace every incoming basic block \p Old to basic block \p New.
   void replaceIncomingBlockWith(const BasicBlock *Old, BasicBlock *New) {
     assert(New && Old && "PHI node got a null basic block!");
diff --git a/llvm/include/llvm/Transforms/Vectorize/MultiExitLoopVectorize.h b/llvm/include/llvm/Transforms/Vectorize/MultiExitLoopVectorize.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/MultiExitLoopVectorize.h
@@ -0,0 +1,132 @@
+//===- MultiExitLoopVectorize.h ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares the multi-exit loop vectorizer: a function pass that targets
+// innermost loops which, besides the latch exit, have further (presumably
+// data-dependent) exits, and vectorizes a clone of each such loop.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_MULTIEXITLOOPVECTORIZE_H
+#define LLVM_TRANSFORMS_VECTORIZE_MULTIEXITLOOPVECTORIZE_H
+
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ValueMap.h"
+#include "llvm/Support/CommandLine.h"
+
+// Containers and updaters used by MultiExitLoopVectorizer's members.
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+
+namespace llvm {
+
+class AAResults;
+class AssumptionCache;
+class BlockFrequencyInfo;
+class DemandedBits;
+class DominatorTree;
+class Function;
+class Loop;
+class LoopAccessInfoManager;
+class LoopInfo;
+class OptimizationRemarkEmitter;
+class ProfileSummaryInfo;
+class ScalarEvolution;
+class TargetLibraryInfo;
+class TargetTransformInfo;
+
+// Further forward declarations used by MultiExitLoopVectorizer.
+class LoopBlocksRPO;
+class BasicBlockEdge;
+
+/// Implements the multi-exit loop vectorization transform for one function.
+/// runImpl() selects candidate loops, clones them, vectorizes the clones and
+/// post-processes the result, returning the analyses that stayed valid.
+class MultiExitLoopVectorizer {
+  /// An exit edge as (exiting block, exit block).
+  using BlockEdgeT = std::pair;
+  using ValueToValueMapTy = ValueMap;
+
+  Function &F;
+  FunctionAnalysisManager &FAM;
+  LoopInfo *LI;
+  ScalarEvolution *SE;
+  AliasAnalysis *AA;
+  AssumptionCache *AC = nullptr;
+
+  // Keep the dominator tree and MemorySSA in sync with the CFG edits.
+  std::unique_ptr DTU = nullptr;
+  std::unique_ptr MSSAU = nullptr;
+
+  // Analyses still valid after the transformations performed so far; this is
+  // what runImpl() returns.
+  PreservedAnalyses GlobalPA = PreservedAnalyses::all();
+
+  // Per-block cache filled by getHoistBarrier(); a null entry means the block
+  // has no hoist barrier.
+  DenseMap HoistBarriers;
+
+public:
+  MultiExitLoopVectorizer(Function &F, FunctionAnalysisManager &FAM,
+                          LoopInfo &LI, ScalarEvolution &SE, AliasAnalysis &AA,
+                          DominatorTree &DT, MemorySSAAnalysis::Result &MSSA,
+                          AssumptionCache *AC);
+
+  /// Runs the transform on F and returns the preserved analyses.
+  PreservedAnalyses runImpl();
+
+private:
+  /// Returns a shared, lazily built set with a fixed list of analyses marked
+  /// as preserved.
+  PreservedAnalyses &getAllLocalAnalysis() const;
+  /// Returns the (cached) hoist barrier of \p BB, or nullptr if it has none.
+  const Instruction *getHoistBarrier(const BasicBlock *BB);
+  bool hasBarrierOnPath(const Instruction *HoistPoint, const Instruction *I);
+  /// True if \p BB holds a use clobbered by \p Def between \p HoistPoint and
+  /// Def's instruction that cannot itself be hoisted.
+  bool hasUseInBB(const Instruction *HoistPoint, MemoryDef *Def,
+                  const BasicBlock *BB, SetVector &HoistInsts);
+  /// hasUseInBB() applied to every block between HoistPoint's and Def's.
+  bool hasUseOnPath(const Instruction *HoistPoint, MemoryDef *Def,
+                    SetVector &HoistInsts);
+  /// True (and \p I recorded in \p HoistInsts) if no hoist barrier lies
+  /// between \p HoistPoint and \p I.
+  bool safeToHoistScalar(const Instruction *HoistPoint, Instruction *I,
+                         SetVector &HoistInsts);
+  /// safeToHoistScalar() plus MemorySSA availability and, for writes, the
+  /// hasUseOnPath() check.
+  bool safeToHoistReadOrWrite(const Instruction *HoistPoint, Instruction *I,
+                              bool IsWrite,
+                              SetVector &HoistInsts);
+  /// Dispatches to the scalar or memory-aware check according to \p I's kind.
+  bool isSafeToHoistInst(const Instruction *HoistPoint, Instruction *I,
+                         SetVector &HoistInsts);
+  /// Recursively checks that a value and its operands can be made available
+  /// at \p HoistPoint, collecting what must move in \p HoistInsts.
+  bool isSafeToHoistReq(const Instruction *HoistPoint, Value *I,
+                        SetVector &HoistInsts);
+  /// Hoists the clones of \p HoistInsts into \p ClonedLoop's header and builds
+  /// the "early.exit.guard" flag from \p Conditions; returns its LCSSA phi.
+  PHINode *hoist(Loop &ClonedLoop, SetVector &HoistInsts,
+                 SmallVectorImpl > &Conditions,
+                 ValueToValueMapTy &VMap);
+
+  /// Collects the out-of-loop instructions \p V depends on that do not
+  /// dominate \p SplitPoint.
+  void collectDeps(Value &V, Loop &VecLoop, Instruction *&SplitPoint,
+                   SmallVectorImpl &EarlyExitGuardDeps);
+
+  /// Replaces uses of \p OrigValue that come after \p NewValue (by block RPO
+  /// number or in-block order) with \p NewValue.
+  void replaceAllReachableUsesWith(Value *OrigValue, Instruction *NewValue);
+  /// Returns the unique incoming use of \p PHI whose block is (\p MatchBB) or
+  /// is not (!\p MatchBB) \p FromBB; nullptr if there is none.
+  Use *FindIncomingValueFrom(PHINode &PHI, const BasicBlock *FromBB,
+                             bool MatchBB);
+  /// Returns a PHI in \p BB that uses \p InVal, or nullptr.
+  PHINode *FindPhiWithIncomingValueInBB(Value *InVal, BasicBlock *BB);
+
+  void postProcess(Loop &OrigLoop, Loop &ClonedLoop, Loop &VecLoop,
+                   PHINode &EarlyExitGuardLcssa);
+
+  /// Clones \p L in front of itself, drops \p RemoveExitEdges from the clone
+  /// and routes the clone's latch exit into the original loop; returns the
+  /// clone.
+  Loop *cloneLoopAndRemoveExits(Loop &L, LoopBlocksRPO &RPO,
+                                SmallVectorImpl &AllExitEdges,
+                                SmallVectorImpl &RemoveExitEdges,
+                                BasicBlock *InsertAfter, const Twine &Suffix,
+                                ValueToValueMapTy &VMap);
+  void updateAnalysis(PreservedAnalyses &Required);
+  bool lcssaLoop(Loop &L);
+  bool lcssaLoops(SmallVectorImpl &Loops);
+  bool cloneLoops(const SmallVectorImpl &OrigLoops,
+                  SmallVectorImpl &ClonedLoops,
+                  SmallVectorImpl &ExitConds);
+  bool vectorizeLoops(const SmallVectorImpl &Loops,
+                      SmallVectorImpl &VecLoops);
+  bool postProcessLoops(const SmallVectorImpl &OrigLoops,
+                        const SmallVectorImpl &ClonedLoops,
+                        const SmallVectorImpl &VecLoops,
+                        const SmallVectorImpl &ExitConds);
+};
+
+/// New-pass-manager function pass running MultiExitLoopVectorizer.
+struct MultiExitLoopVectorizePass
+    : public PassInfoMixin {
+  MultiExitLoopVectorizePass() = default;
+
+  /// Runs multi-exit loop vectorization on \p F. Returns
+  /// PreservedAnalyses::all() when the function is skipped.
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_MULTIEXITLOOPVECTORIZE_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -258,6 +258,7 @@
 #include "llvm/Transforms/Utils/UnifyLoopExits.h"
 #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
 #include "llvm/Transforms/Vectorize/LoopVectorize.h"
+#include "llvm/Transforms/Vectorize/MultiExitLoopVectorize.h"
 #include "llvm/Transforms/Vectorize/SLPVectorizer.h"
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
 #include
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -430,6 +430,7 @@
 FUNCTION_PASS("tsan", ThreadSanitizerPass())
 FUNCTION_PASS("memprof", MemProfilerPass())
 FUNCTION_PASS("declare-to-assign", llvm::AssignmentTrackingPass())
+FUNCTION_PASS("multi-exit-loop-vectorize", MultiExitLoopVectorizePass())
 #undef FUNCTION_PASS

 #ifndef FUNCTION_PASS_WITH_PARAMS
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -2,6 +2,7 @@
   LoadStoreVectorizer.cpp
   LoopVectorizationLegality.cpp
   LoopVectorize.cpp
+  MultiExitLoopVectorize.cpp
   SLPVectorizer.cpp
   Vectorize.cpp
   VectorCombine.cpp
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -172,6 +172,10 @@
 STATISTIC(LoopsAnalyzed, "Number of loops analyzed for vectorization");
 STATISTIC(LoopsEpilogueVectorized, "Number of epilogues vectorized");

+static cl::opt ForceScalarEpilogue(
+    "require-scalar-epilogue", cl::init(false), cl::Hidden,
+    cl::desc("Force the loop vectorizer to always emit a scalar epilogue."));
+
 static cl::opt EnableEpilogueVectorization(
     "enable-epilogue-vectorization", cl::init(true), cl::Hidden,
     cl::desc("Enable vectorization of epilogue loops."));
@@ -1556,7 +1560,8 @@
     return false;
   // If we might exit from anywhere but the latch, must run the exiting
   // iteration in scalar form.
-  if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch())
+  if (ForceScalarEpilogue ||
+      TheLoop->getExitingBlock() != TheLoop->getLoopLatch())
     return true;
   return IsVectorizing && InterleaveInfo.requiresScalarEpilogue();
 }
@@ -9802,6 +9807,11 @@
     Function *F, Loop *L, LoopVectorizeHints &Hints, ProfileSummaryInfo *PSI,
     BlockFrequencyInfo *BFI, TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
     LoopVectorizationLegality &LVL, InterleavedAccessInfo *IAI) {
+  // With -require-scalar-epilogue a scalar epilogue is always allowed. This
+  // check deliberately runs before the OptSize rule below.
+  if (ForceScalarEpilogue)
+    return CM_ScalarEpilogueAllowed;
+
   // 1) OptSize takes precedence over all other options, i.e. if this is set,
   // don't look at hints or options, and don't request a scalar epilogue.
   // (For PGSO, as shouldOptimizeForSize isn't currently accessible from
diff --git a/llvm/lib/Transforms/Vectorize/MultiExitLoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/MultiExitLoopVectorize.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/MultiExitLoopVectorize.cpp
@@ -0,0 +1,1245 @@
+//===- MultiExitLoopVectorize.cpp - Multi-exit loop vectorizer ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "llvm/Transforms/Vectorize/MultiExitLoopVectorize.h" + +// My list +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/DominanceFrontier.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/RegionInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include "llvm/Transforms/Vectorize/LoopVectorize.h" + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" 
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InstructionCost.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/InjectTLIMappings.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/LoopVersioning.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+#include "llvm/Transforms/Utils/SizeOpts.h"
+#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace llvm;
+
+#define LV_NAME "multi-exit-loop-vectorize"
+#define DEBUG_TYPE LV_NAME
+
+STATISTIC(MultiExitLoopsDetected, "Number of multi-exit loops found");
+STATISTIC(MultiExitLoopsVectorized, "Number of multi-exit loops vectorized");
+
+namespace {
+/// Marks a fixed set of analyses as preserved in \p CfgPA.
+void addCfgLocalAnalysis(PreservedAnalyses &CfgPA) {
+  // TODO: Workaround for how preserved analysis sets are implemented.
+  CfgPA.preserve();
+  CfgPA.preserve();
+  CfgPA.preserve();
+  CfgPA.preserve();
+  CfgPA.preserve();
+  CfgPA.preserve();
+  CfgPA.preserve();
+}
+} // namespace
+
+namespace llvm {
+
+extern cl::opt EnableVPlanNativePath;
+
+MultiExitLoopVectorizer::MultiExitLoopVectorizer(
+    Function &F, FunctionAnalysisManager &FAM, LoopInfo &LI,
+    ScalarEvolution &SE, AliasAnalysis &AA, DominatorTree &DT,
+    MemorySSAAnalysis::Result &MSSA, AssumptionCache *AC)
+    : F(F), FAM(FAM), LI(&LI), SE(&SE), AA(&AA), AC(AC) {
+  // The updaters wrap the caller's DT and MemorySSA. Eager DT updates keep the
+  // tree queryable right after every CFG edit.
+  DTU = std::make_unique(DT, DomTreeUpdater::UpdateStrategy::Eager);
+  MSSAU = std::make_unique(&MSSA.getMSSA());
+}
+
+/// New-PM entry point. Functions marked optsize, functions without loops, and
+/// targets with neither vector registers nor interleaving support are left
+/// untouched; everything else is handed to MultiExitLoopVectorizer.
+PreservedAnalyses
+MultiExitLoopVectorizePass::run(Function &F, FunctionAnalysisManager &FAM) {
+  if (F.hasOptSize())
+    return PreservedAnalyses::all();
+
+  auto &LI = FAM.getResult(F);
+  if (LI.empty())
+    return PreservedAnalyses::all();
+
+  // TODO: Copy-pasted from LV. Can be fixed?
+  auto &TTI = FAM.getResult(F);
+  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true)) &&
+      TTI.getMaxInterleaveFactor(ElementCount::getFixed(1)) < 2)
+    return PreservedAnalyses::all();
+
+  auto &SE = FAM.getResult(F);
+  auto &DT = FAM.getResult(F);
+  auto &AA = FAM.getResult(F);
+  auto &MSSA = FAM.getResult(F);
+  auto *AC = FAM.getCachedResult(F);
+
+  MultiExitLoopVectorizer MELV(F, FAM, LI, SE, AA, DT, MSSA, AC);
+  return MELV.runImpl();
+}
+
+/// Collects candidate loops, then clones, vectorizes and post-processes them.
+/// A candidate is a top-level loop that is innermost, in loop-simplify and
+/// rotated form, safe to clone and reducible, whose latch exits with a
+/// computable exit count while the loop's backedge-taken count is not loop
+/// invariant (presumably because another exit is data dependent).
+PreservedAnalyses MultiExitLoopVectorizer::runImpl() {
+  // TODO: Bring loops that are not in loop-simplify form into it (e.g. with
+  // simplifyLoop) instead of skipping them below.
+
+  // Build up a worklist of candidate loops first: vectorizing can invalidate
+  // iterators across the loops.
+  SmallVector OrigLoops;
+  for (auto *L : *LI) {
+    LoopBlocksRPO RPO(L);
+    RPO.perform(LI);
+    if (!(L->isInnermost() && L->isLoopSimplifyForm() && L->isRotatedForm() &&
+          L->isSafeToClone() &&
+          !containsIrreducibleCFG(RPO, *LI))) {
+      LLVM_DEBUG(dbgs() << "MELV: Skip loop " << L->getHeader()->getName()
+                        << " due to unsupported CFG\n");
+      continue;
+    }
+
+    auto *Latch = L->getLoopLatch();
+    if (!(Latch && L->isLoopExiting(Latch) &&
+          !isa(SE->getExitCount(L, Latch)))) {
+      LLVM_DEBUG(dbgs() << "MELV: Skip loop " << L->getHeader()->getName()
+                        << " due to unsupported Latch\n");
+      continue;
+    }
+
+    if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
+      LLVM_DEBUG(dbgs() << "MELV: Skip loop " << L->getHeader()->getName()
+                        << " due to computable trip count\n");
+      continue;
+    }
+
+    OrigLoops.emplace_back(L);
+    ++MultiExitLoopsDetected;
+  }
+
+  if (OrigLoops.empty())
+    return GlobalPA;
+
+  lcssaLoops(OrigLoops);
+
+  SmallVector ClonedLoops;
+  SmallVector ExitConds;
+  if (!cloneLoops(OrigLoops, ClonedLoops, ExitConds))
+    return GlobalPA;
+
+  SmallVector VecLoops;
+  if (!vectorizeLoops(ClonedLoops, VecLoops))
+    return GlobalPA;
+
+  lcssaLoops(VecLoops);
+
+  postProcessLoops(OrigLoops, ClonedLoops, VecLoops, ExitConds);
+
+  return GlobalPA;
+}
+
+/// Returns a process-wide set with a fixed list of analyses marked preserved.
+/// The set is built exactly once: initialization of a function-local static
+/// is thread-safe, unlike filling a static std::optional on first use.
+PreservedAnalyses &MultiExitLoopVectorizer::getAllLocalAnalysis() const {
+  static PreservedAnalyses AllPA = [] {
+    PreservedAnalyses PA;
+    PA.preserve();
+    PA.preserve();
+    PA.preserve();
+    PA.preserve();
+    PA.preserve();
+    PA.preserve();
+    return PA;
+  }();
+  return AllPA;
+}
+
+/// Returns the positionally first "hoist barrier" of \p BB, or nullptr if the
+/// block has none. A barrier is an EH pad, or an instruction that is not
+/// guaranteed to transfer execution to its successor (this includes a
+/// throwing terminator). An address-taken block without such an instruction
+/// reports its terminator. Results are cached in HoistBarriers.
+const Instruction *
+MultiExitLoopVectorizer::getHoistBarrier(const BasicBlock *BB) {
+  auto FindBarrierInst = [](const BasicBlock *BB) -> const Instruction * {
+    // An EH pad is the first non-PHI instruction and PHIs never block
+    // execution, so when present it is the first barrier.
+    const Instruction *FirstNonPHI = BB->getFirstNonPHI();
+    if (FirstNonPHI->isEHPad())
+      return FirstNonPHI;
+
+    for (const Instruction &I : *BB)
+      if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+        return &I;
+
+    // Treat an address-taken block as ending in a barrier. The terminator is
+    // the last instruction, so every other instruction "comes before" it.
+    if (BB->hasAddressTaken())
+      return BB->getTerminator();
+    return nullptr;
+  };
+
+  auto [It, Inserted] = HoistBarriers.try_emplace(BB, nullptr);
+  if (Inserted)
+    It->second = FindBarrierInst(BB);
+  return It->second;
+}
+
+/// Returns true if \p BB holds a MemoryUse that \p Def may clobber, that lies
+/// between \p HoistPoint and Def's instruction, and that cannot itself be
+/// hoisted to \p HoistPoint.
+bool MultiExitLoopVectorizer::hasUseInBB(const Instruction *HoistPoint,
+                                         MemoryDef *Def, const BasicBlock *BB,
+                                         SetVector &HoistInsts) {
+  const MemorySSA::AccessList *Acc =
+      MSSAU->getMemorySSA()->getBlockAccesses(BB);
+  if (!Acc)
+    return false;
+
+  Instruction *SrcInst = Def->getMemoryInst();
+  const BasicBlock *SrcBB = SrcInst->getParent();
+  const BasicBlock *TargetBB = HoistPoint->getParent();
+  for (const MemoryAccess &MA : *Acc)
+    if (const MemoryUse *MU = dyn_cast(&MA)) {
+      Instruction *I = MU->getMemoryInst();
+
+      // Do not check whether MU aliases Def when MU occurs after SrcInst.
+      if (BB == SrcBB && SrcInst->comesBefore(I))
+        continue;
+
+      // Do not check whether MU aliases Def when MU occurs before HoistPoint.
+      if (BB == TargetBB && I->comesBefore(HoistPoint))
+        continue;
+
+      if (MemorySSAUtil::defClobbersUseOrDef(Def, MU, *AA) &&
+          !isSafeToHoistReq(HoistPoint, MU->getMemoryInst(), HoistInsts))
+        return true;
+    }
+
+  return false;
+}
+
+/// Returns true if hoisting the write \p Def to \p HoistPoint would move it
+/// above a use it clobbers: hasUseInBB() on every block between HoistPoint's
+/// block and Def's block.
+bool MultiExitLoopVectorizer::hasUseOnPath(
+    const Instruction *HoistPoint, MemoryDef *Def,
+    SetVector &HoistInsts) {
+  const BasicBlock *TargetBB = HoistPoint->getParent();
+  const BasicBlock *SrcBB = Def->getBlock();
+  assert(DTU->getDomTree().dominates(TargetBB, SrcBB) && "invalid path");
+  assert(DTU->getDomTree().dominates(Def->getDefiningAccess()->getBlock(),
+                                     TargetBB) &&
+         "Def does not dominate new hoisting point");
+
+  // Walk all basic blocks reachable in depth-first iteration on the inverse
+  // CFG from SrcBB to TargetBB. These are all the blocks that may execute
+  // between TargetBB and SrcBB; hoisting from SrcBB into TargetBB has to be
+  // safe on all of these paths.
+  for (auto BBIT = idf_begin(SrcBB), E = idf_end(SrcBB); BBIT != E;) {
+    const BasicBlock *BB = *BBIT;
+
+    // Check that we do not move a store before load.
+    if (hasUseInBB(HoistPoint, Def, BB, HoistInsts))
+      return true;
+
+    if (BB == TargetBB) {
+      // Stop traversal when reaching TargetBB.
+      BBIT.skipChildren();
+      continue;
+    }
+    ++BBIT;
+  }
+
+  return false;
+}
+
+/// Returns true, and records \p I in \p HoistInsts, if no hoist barrier lies
+/// on any path between \p HoistPoint and \p I.
+bool MultiExitLoopVectorizer::safeToHoistScalar(
+    const Instruction *HoistPoint, Instruction *I,
+    SetVector &HoistInsts) {
+  auto *TargetBB = HoistPoint->getParent();
+  auto *SrcBB = I->getParent();
+
+  assert(DTU->getDomTree().dominates(TargetBB, SrcBB) && "Invalid path");
+
+  // Returns true if [From, To) contains an instruction that may not transfer
+  // execution to its successor. A null To means the end of From's block.
+  auto HasBarrierInRange = [](const Instruction *From, const Instruction *To) {
+    for (const Instruction *Cur = From; Cur && Cur != To;
+         Cur = Cur->getNextNode())
+      if (!isGuaranteedToTransferExecutionToSuccessor(Cur))
+        return true;
+    return false;
+  };
+
+  // Walk all basic blocks reachable in depth-first iteration on the inverse
+  // CFG from SrcBB to TargetBB. These are all the blocks that may execute
+  // between TargetBB and SrcBB; hoisting I into TargetBB has to be safe on
+  // all of these paths.
+  for (auto BBIT = idf_begin(SrcBB), E = idf_end(SrcBB); BBIT != E;) {
+    const BasicBlock *BB = *BBIT;
+
+    if (const Instruction *BarrierInst = getHoistBarrier(BB)) {
+      // BarrierInst is the first barrier of BB. In SrcBB it is harmless when
+      // it comes after I. In TargetBB it is harmless when it comes before
+      // HoistPoint and nothing from HoistPoint onwards (up to I if I is in the
+      // same block) is a barrier.
+      bool AfterI = BB == SrcBB && I->comesBefore(BarrierInst);
+      bool BeforeHoistPoint =
+          BB == TargetBB && BarrierInst->comesBefore(HoistPoint) &&
+          !BB->hasAddressTaken() &&
+          !HasBarrierInRange(HoistPoint, BB == SrcBB ? I : nullptr);
+      if (!AfterI && !BeforeHoistPoint)
+        return false;
+    }
+
+    // Always advance the iterator; a bare 'continue' here would loop forever.
+    if (BB == TargetBB) {
+      // Stop traversal when reaching TargetBB.
+      BBIT.skipChildren();
+      continue;
+    }
+    ++BBIT;
+  }
+
+  HoistInsts.insert(I);
+  return true;
+}
+
+/// Memory-aware variant of safeToHoistScalar(). It also requires I's defining
+/// access to be available at \p HoistPoint. For writes, it requires that no
+/// clobbered use lies on the way (hasUseOnPath()).
+bool MultiExitLoopVectorizer::safeToHoistReadOrWrite(
+    const Instruction *HoistPoint, Instruction *I, bool IsWrite,
+    SetVector &HoistInsts) {
+  // In place hoisting is safe.
+  if (HoistPoint == I)
+    return true;
+
+  MemorySSA *MSSA = MSSAU->getMemorySSA();
+  MemoryUseOrDef *MemAccess = MSSA->getMemoryAccess(I);
+  if (!MemAccess)
+    return false;
+
+  const BasicBlock *TargetBB = HoistPoint->getParent();
+  MemoryAccess *DefiningAccess = MemAccess->getDefiningAccess();
+  BasicBlock *DefiningAccessBB = DefiningAccess->getBlock();
+
+  // The defining access must be available at HoistPoint. Either it precedes
+  // HoistPoint in TargetBB, or its block properly dominates TargetBB. A def in
+  // a block merely not dominated by TargetBB (e.g. a sibling branch) is not
+  // available, and hasUseOnPath() asserts this property.
+  if (!MSSA->isLiveOnEntryDef(DefiningAccess)) {
+    if (TargetBB == DefiningAccessBB) {
+      if (auto *UD = dyn_cast(DefiningAccess))
+        if (!UD->getMemoryInst()->comesBefore(HoistPoint))
+          return false;
+    } else if (!DTU->getDomTree().properlyDominates(DefiningAccessBB,
+                                                    TargetBB)) {
+      return false;
+    }
+  }
+
+  // Check for unsafe hoistings due to side effects.
+  if (!safeToHoistScalar(HoistPoint, I, HoistInsts))
+    return false;
+  if (IsWrite &&
+      hasUseOnPath(HoistPoint, cast(MemAccess), HoistInsts)) {
+    // safeToHoistScalar() has already recorded I as hoistable. Withdraw it so
+    // that a later isSafeToHoistInst() query does not accept I from the set.
+    HoistInsts.remove(I);
+    return false;
+  }
+
+  HoistInsts.insert(I);
+  return true;
+}
+
+/// Dispatches to the scalar or memory-aware safety check for \p I. Calls with
+/// side effects, convergent calls, and some intrinsics are never hoisted.
+bool MultiExitLoopVectorizer::isSafeToHoistInst(
+    const Instruction *HoistPoint, Instruction *I,
+    SetVector &HoistInsts) {
+  if (HoistInsts.count(I))
+    return true;
+
+  if (isa(I) || isa(I))
+    return safeToHoistReadOrWrite(HoistPoint, I, isa(I), HoistInsts);
+  else if (auto *Call = dyn_cast(I)) {
+    if (auto *Intr = dyn_cast(Call)) {
+      if (isa(Intr) ||
+          Intr->getIntrinsicID() == Intrinsic::assume ||
+          Intr->getIntrinsicID() == Intrinsic::sideeffect)
+        return false;
+    }
+    if (Call->mayHaveSideEffects() || Call->isConvergent())
+      return false;
+
+    if (Call->doesNotAccessMemory())
+      return safeToHoistScalar(HoistPoint, Call, HoistInsts);
+    else
+      return safeToHoistReadOrWrite(HoistPoint, Call, !Call->onlyReadsMemory(),
+                                    HoistInsts);
+  }
+  return safeToHoistScalar(HoistPoint, I, HoistInsts);
+}
+
+/// Recursively checks that \p V and its operands can be made available at
+/// \p HoistPoint, collecting the instructions that must move in
+/// \p HoistInsts. Values already dominating HoistPoint need no move; PHIs and
+/// non-instruction values that do not dominate it are rejected.
+bool MultiExitLoopVectorizer::isSafeToHoistReq(
+    const Instruction *HoistPoint, Value *V,
+    SetVector &HoistInsts) {
+  if (V == HoistPoint) {
+    HoistInsts.insert(cast(V));
+    return true;
+  }
+
+  if (DTU->getDomTree().dominates(V, HoistPoint))
+    return true;
+
+  if (auto *I = dyn_cast(V)) {
+    if (isa(I))
+      return false;
+
+    for (Use &U : I->operands()) {
+      if (!isSafeToHoistReq(HoistPoint, U.get(), HoistInsts)) {
+        return false;
+      }
+    }
+    return isSafeToHoistInst(HoistPoint, I, HoistInsts);
+  }
+  return false;
+}
+
+/// Moves the clones (through \p VMap) of \p HoistInsts to the front of
+/// \p ClonedLoop's header, after any of them already there. It then builds an
+/// i1 "early.exit.guard" recurrence in the header that ORs in every condition
+/// of \p Conditions (negated when its flag is false) on each iteration.
+/// Returns the guard's LCSSA phi in the loop's unique exit block.
+PHINode *MultiExitLoopVectorizer::hoist(
+    Loop &ClonedLoop, SetVector &HoistInsts,
+    SmallVectorImpl> &Conditions,
+    ValueToValueMapTy &VMap) {
+  BasicBlock *ClonedLoopPreheader = ClonedLoop.getLoopPreheader();
+  BasicBlock *ClonedLoopHeader = ClonedLoop.getHeader();
+  BasicBlock *ClonedLoopLatch = ClonedLoop.getLoopLatch();
+  BasicBlock *ClonedLoopExitBB = ClonedLoop.getUniqueExitBlock();
+  assert(ClonedLoopPreheader && ClonedLoopLatch && ClonedLoopExitBB &&
+         "Expected a preheader, a single latch and a unique exit block");
+  LLVMContext &Context = ClonedLoopHeader->getContext();
+
+  SetVector ClonedHoistInsts;
+  for (Instruction *I : HoistInsts)
+    ClonedHoistInsts.insert(cast(VMap[I]));
+
+  SmallVector, 16> ClonedConditions;
+  ClonedConditions.reserve(Conditions.size());
+  for (auto &C : Conditions) {
+    ClonedConditions.emplace_back(cast(VMap[C.first]), C.second);
+  }
+
+  // Instructions already at the start of the header stay where they are. The
+  // hoist point is the first instruction that is not to be hoisted.
+  auto ClonedHoistPointIt = ClonedLoopHeader->getFirstInsertionPt();
+  BasicBlock *HoistBB = ClonedHoistPointIt->getParent();
+  for (auto It = ClonedHoistPointIt; It != HoistBB->end(); ++It) {
+    Instruction *BBInst = &*It;
+    if (!ClonedHoistInsts.contains(BBInst)) {
+      ClonedHoistPointIt = It;
+      break;
+    }
+    ClonedHoistInsts.remove(BBInst);
+  }
+
+  // In MemorySSA order, hoisted accesses must precede the first access at or
+  // after the hoist point that is not itself being hoisted. If there is no
+  // such access, every remaining access of HoistBB precedes the hoist point,
+  // so the hoisted ones go to the end of the block's access list.
+  const Instruction *HoistPointInst = &*ClonedHoistPointIt;
+  auto *HoistBBAccesses = MSSAU->getMemorySSA()->getBlockAccesses(HoistBB);
+  MemoryUseOrDef *FirstUseOrDefAccess = nullptr;
+  if (HoistBBAccesses)
+    for (const MemoryAccess &Access : *HoistBBAccesses)
+      if (const MemoryUseOrDef *UD = dyn_cast(&Access)) {
+        Instruction *MemInst = UD->getMemoryInst();
+        if (!ClonedHoistInsts.contains(MemInst) &&
+            !MemInst->comesBefore(HoistPointInst)) {
+          FirstUseOrDefAccess = const_cast(UD);
+          break;
+        }
+      }
+
+  for (Instruction *I : ClonedHoistInsts) {
+    I->moveBefore(*HoistBB, ClonedHoistPointIt);
+    if (auto *Access = MSSAU->getMemorySSA()->getMemoryAccess(I)) {
+      if (FirstUseOrDefAccess)
+        MSSAU->moveBefore(Access, FirstUseOrDefAccess);
+      else
+        MSSAU->moveToPlace(Access, HoistBB, MemorySSA::InsertionPlace::End);
+    }
+  }
+
+  IRBuilder<> Builder(Context);
+
+  auto IP = ClonedLoopHeader->begin();
+  Builder.SetInsertPoint(IP->getParent(), IP);
+
+  auto *EarlyExitGuardInit =
+      Builder.CreatePHI(Type::getInt1Ty(Context), 2, "early.exit.guard");
+  EarlyExitGuardInit->addIncoming(ConstantInt::getFalse(Context),
+                                  ClonedLoopPreheader);
+
+  Builder.SetInsertPoint(ClonedHoistPointIt->getParent(), ClonedHoistPointIt);
+  Value *EarlyExitGuard = EarlyExitGuardInit;
+  for (auto &C : ClonedConditions) {
+    Value *V = C.first;
+    if (!C.second) {
+      V = Builder.CreateNot(V, Twine(V->getName()) + ".neg");
+    }
+    EarlyExitGuard = Builder.CreateOr(EarlyExitGuard, V);
+  }
+
+  EarlyExitGuardInit->addIncoming(EarlyExitGuard, ClonedLoopLatch);
+
+  IP = ClonedLoopExitBB->getFirstInsertionPt();
+  Builder.SetInsertPoint(IP->getParent(), IP);
+
+  auto *EarlyExitGuardLcssa =
+      Builder.CreatePHI(EarlyExitGuard->getType(), 1,
+                        Twine(EarlyExitGuard->getName()) + ".lcssa");
+  EarlyExitGuardLcssa->addIncoming(EarlyExitGuard, ClonedLoopLatch);
+  return EarlyExitGuardLcssa;
+}
+
+/// Appends to \p EarlyExitGuardDeps the instructions outside \p VecLoop that
+/// \p V depends on and that do not dominate \p SplitPoint. Reaching an
+/// instruction inside VecLoop moves SplitPoint to it instead. PHI operands
+/// are rewritten in place to their incoming value from VecLoop's latch.
+void MultiExitLoopVectorizer::collectDeps(
+    Value &V, Loop &VecLoop, Instruction *&SplitPoint,
+    SmallVectorImpl &EarlyExitGuardDeps) {
+  if (DTU->getDomTree().dominates(&V, SplitPoint))
+    return;
+
+  Instruction *I = cast(&V);
+  if (VecLoop.contains(I->getParent())) {
+    SplitPoint = I;
+    return;
+  }
+
+  EarlyExitGuardDeps.emplace_back(I);
+  for (auto &Op : I->operands()) {
+    Value *OpVal = Op.get();
+    // Bypass PHIs: use the value flowing in from VecLoop's latch instead.
+    if (auto *Phi = dyn_cast(OpVal)) {
+      Use *Incoming = FindIncomingValueFrom(*Phi, VecLoop.getLoopLatch(), true);
+      assert(Incoming && "PHI has no incoming value from the loop latch");
+      OpVal = Incoming->get();
+      Op.set(OpVal);
+    }
+    if (auto *OpInst = dyn_cast(OpVal))
+      collectDeps(*OpInst, VecLoop, SplitPoint, EarlyExitGuardDeps);
+  }
+}
+
+/// Replaces uses of \p OrigValue with \p NewValue in instructions that come
+/// after NewValue in its own block, or that live in a block with a higher
+/// reverse-post-order number than NewValue's block.
+/// NOTE(review): a higher RPO number does not imply reachability from
+/// NewValue's block; confirm this approximation is sufficient for callers.
+/// The RPO is recomputed on every call.
+void MultiExitLoopVectorizer::replaceAllReachableUsesWith(
+    Value *OrigValue, Instruction *NewValue) {
+  // Number the blocks of F in reverse post-order, starting at 1. Blocks
+  // missing from the map (unreachable) read back as 0.
+  DenseMap, uint32_t> BlockRPONumber;
+  uint32_t NextBlockNumber = 1;
+  ReversePostOrderTraversal RPOT(&F);
+  for (BasicBlock *BB : RPOT)
+    BlockRPONumber[BB] = NextBlockNumber++;
+
+  BasicBlock *NewValueBB = NewValue->getParent();
+  uint32_t NewValRPON = BlockRPONumber.lookup(NewValueBB);
+  for (Use &U : make_early_inc_range(OrigValue->uses())) {
+    Instruction *UserInst = dyn_cast(U.getUser());
+    if (!UserInst)
+      continue;
+    BasicBlock *UserBB = UserInst->getParent();
+    if ((UserBB == NewValueBB && NewValue->comesBefore(UserInst)) ||
+        NewValRPON < BlockRPONumber.lookup(UserBB))
+      U.set(NewValue);
+  }
+}
+
+/// Returns the incoming use of \p PHI whose incoming block is \p FromBB (when
+/// \p MatchBB is true) or is any block other than \p FromBB (when false).
+/// Asserts that at most one use qualifies; returns nullptr if none does.
+Use *MultiExitLoopVectorizer::FindIncomingValueFrom(PHINode &PHI,
+                                                    const BasicBlock *FromBB,
+                                                    bool MatchBB) {
+  Use *ResUse = nullptr;
+  for (auto &IncomingUse : PHI.incoming_values()) {
+    BasicBlock *IncomingBB = PHI.getIncomingBlock(IncomingUse);
+    if (MatchBB == (IncomingBB == FromBB)) {
+      assert(!ResUse && "More than one incoming value outside of cloned loop");
+      ResUse = &IncomingUse;
+    }
+  }
+  return ResUse;
+}
+
+/// Returns some PHI in \p BB that has \p InVal as an operand, or nullptr.
+PHINode *MultiExitLoopVectorizer::FindPhiWithIncomingValueInBB(Value *InVal,
+                                                               BasicBlock *BB) {
+  for (User *U : InVal->users()) {
+    PHINode *PHIUser = dyn_cast(U);
+    if (PHIUser && PHIUser->getParent() == BB)
+      return PHIUser;
+  }
+  return nullptr;
+}
+
+/// Clones \p OrigLoop in front of itself and returns the clone. The steps are:
+/// - The clone gets a new preheader split off OrigLoop's preheader.
+/// - The exits in \p RemoveExitEdges are dropped from the clone.
+/// - The clone's latch exit is routed through a new "<preheader><Suffix>.exit"
+///   block. That block re-evaluates the latch condition on the clone's exit
+///   values and either enters OrigLoop (resuming from the clone's values) or
+///   leaves to the original latch exit.
+/// - The dominator tree, MemorySSA and noalias scopes are updated.
+Loop *MultiExitLoopVectorizer::cloneLoopAndRemoveExits(
+    Loop &OrigLoop, LoopBlocksRPO &RPO,
+    SmallVectorImpl &AllExitEdges,
+    SmallVectorImpl &RemoveExitEdges, BasicBlock *InsertAfter,
+    const Twine &Suffix, ValueToValueMapTy &VMap) {
+  DominatorTree &DT = DTU->getDomTree();
+  SmallVector DTUpdates;
+
+  Function *F = InsertAfter->getParent();
+  LLVMContext &Context = F->getContext();
+  IRBuilder<> Builder(Context);
+
+  BasicBlock *OrigPreHeader = OrigLoop.getLoopPreheader();
+  BasicBlock *OrigHeader = OrigLoop.getHeader();
+  BasicBlock *OrigLatch = OrigLoop.getLoopLatch();
+
+  BasicBlock *ClonedPreHeader = splitBlockBefore(
+      OrigPreHeader, OrigPreHeader->getTerminator(), DTU.get(), LI, MSSAU.get(),
+      Twine(OrigPreHeader->getName()) + Suffix);
+  VMap[OrigPreHeader] = ClonedPreHeader;
+
+  // For each block in the original loop, create a new copy,
+  // and update the value map with the newly created values.
+  SmallVector NewBlocks;
+  NewBlocks.reserve(OrigLoop.getNumBlocks());
+  for (BasicBlock *OrigBB : RPO) {
+    BasicBlock *NewBB = CloneBasicBlock(OrigBB, VMap, Suffix, F);
+    NewBlocks.push_back(NewBB);
+    VMap[OrigBB] = NewBB;
+    NewBB->moveBefore(OrigPreHeader);
+
+    // Update DT. Copy information from the original loop to the cloned loop.
+    BasicBlock *IDomBB = DT.getNode(OrigBB)->getIDom()->getBlock();
+    DT.addNewBlock(NewBB, cast(VMap[IDomBB]));
+  }
+
+  remapInstructionsInBlocks(NewBlocks, VMap);
+
+  // Set of unique exit blocks from the cloned loop.
+  SmallSetVector ClonedUniqueExits;
+  llvm::for_each(AllExitEdges, [&](BlockEdgeT &ExitEdge) {
+    ClonedUniqueExits.insert(ExitEdge.second);
+  });
+
+  MSSAU->updateForClonedLoop(RPO, ClonedUniqueExits.getArrayRef(), VMap);
+
+  Loop *ClonedLoop =
+      cloneLoop(&OrigLoop, OrigLoop.getParentLoop(), VMap, LI, nullptr);
+  BasicBlock *ClonedLatch = cast(VMap[OrigLatch]);
+
+  // Make 'ClonedPreHeader' preheader of cloned loop.
+  BasicBlock *OldSucc = ClonedPreHeader->getSingleSuccessor();
+  auto *BI = cast(ClonedPreHeader->getTerminator());
+  BI->setSuccessor(0, ClonedLoop->getHeader());
+  DTUpdates.emplace_back(DominatorTree::Delete, ClonedPreHeader, OldSucc);
+  DTUpdates.emplace_back(DominatorTree::Insert, ClonedPreHeader,
+                         ClonedPreHeader->getSingleSuccessor());
+
+  // Remove all dedicated exiting edges (these are the exiting edges of
+  // data-dependent exits).
+  for (BlockEdgeT &ExitEdge : RemoveExitEdges) {
+    BasicBlock *OrigExitingBB = ExitEdge.first;
+    BasicBlock *ClonedExitingBB = cast(VMap[OrigExitingBB]);
+    auto *ClonedBranch = dyn_cast(ClonedExitingBB->getTerminator());
+    assert((ClonedBranch && ClonedBranch->isConditional()) &&
+           "Unsupported terminator");
+    unsigned ExitingIndex =
+        (ClonedBranch->getSuccessor(0) == ExitEdge.second) ? 0 : 1;
+    // Read the surviving successor before the branch is erased; ClonedBranch
+    // must not be touched after eraseFromParent().
+    BasicBlock *StaySucc = ClonedBranch->getSuccessor(1 - ExitingIndex);
+    ClonedBranch->eraseFromParent();
+    BranchInst::Create(StaySucc, ClonedExitingBB);
+    ClonedUniqueExits.remove(ExitEdge.second);
+    DTUpdates.emplace_back(DominatorTree::Delete, ClonedExitingBB,
+                           ExitEdge.second);
+  }
+
+  // Fixup PHIs in "OrigHeader". Need to update current incoming value from
+  // "OrigPreHeader" with the corresponding outgoing value from "ClonedLatch".
+  Builder.SetInsertPoint(OrigPreHeader, OrigPreHeader->getFirstInsertionPt());
+  for (PHINode &OrigPhi : OrigHeader->phis()) {
+    int Index = OrigPhi.getBasicBlockIndex(OrigPreHeader);
+    if (Index == -1)
+      continue;
+    PHINode *ClonedPHI = cast(VMap[&OrigPhi]);
+    auto *ResumePHI = Builder.CreatePHI(
+        OrigPhi.getType(), 1, Twine(ClonedPHI->getName()) + ".resume");
+    auto *ExitVal = ClonedPHI->getIncomingValueForBlock(ClonedLatch);
+    ResumePHI->addIncoming(ExitVal, ClonedLatch);
+    OrigPhi.setIncomingValue(Index, ResumePHI);
+  }
+
+  {
+    // Reconnect "ClonedLatch" exiting edge to "OrigPreHeader".
+    BranchInst *OrigTerminator =
+        dyn_cast(OrigLatch->getTerminator());
+    assert((OrigTerminator && OrigTerminator->isConditional()) &&
+           "Unsupported terminator");
+    BranchInst *ClonedTerminator = cast(VMap[OrigTerminator]);
+    unsigned ExitingIndex =
+        ClonedLoop->contains(ClonedTerminator->getSuccessor(0)) ? 1 : 0;
+    BasicBlock *OrigExit = ClonedTerminator->getSuccessor(ExitingIndex);
+    BasicBlock *ClonedExit = ClonedTerminator->getSuccessor(ExitingIndex);
+    ClonedTerminator->setSuccessor(ExitingIndex, OrigPreHeader);
+
+    ClonedUniqueExits.remove(ClonedExit);
+    ClonedUniqueExits.insert(OrigPreHeader);
+
+    DTUpdates.emplace_back(DominatorTree::Delete, ClonedLatch, ClonedExit);
+    DTUpdates.emplace_back(DominatorTree::Insert, ClonedLatch, OrigPreHeader);
+
+    // FIXME: DT updater chokes if we apply all updates at once at the end. Do
+    // an intermediate flush.
+    DTU->applyUpdates(DTUpdates);
+    MSSAU->applyUpdates(DTUpdates, DTU->getDomTree());
+    DTUpdates.clear();
+
+    // Insert iteration guard.
+    BasicBlock *NewClonedExit = splitBlockBefore(
+        OrigPreHeader, OrigPreHeader->getFirstNonPHI(), DTU.get(), LI,
+        MSSAU.get(), Twine(OrigPreHeader->getName()) + Suffix + ".exit");
+    NewClonedExit->getTerminator()->eraseFromParent();
+
+    Builder.SetInsertPoint(NewClonedExit);
+    ICmpInst *ClonedCmpInstCopy =
+        cast(ClonedLoop->getLatchCmpInst()->clone());
+    Builder.Insert(ClonedCmpInstCopy);
+    for (Use &Op : ClonedCmpInstCopy->operands()) {
+      if (Instruction *OpInst = dyn_cast(Op.get()))
+        if (ClonedLoop->contains(OpInst->getParent())) {
+          PHINode *ResumePhi =
+              FindPhiWithIncomingValueInBB(OpInst, NewClonedExit);
+          assert(ResumePhi && "Should generate resuming PHI?");
+          Op.set(ResumePhi);
+        }
+    }
+
+    BranchInst *ClonedTerminatorCopy =
+        dyn_cast(ClonedLatch->getTerminator()->clone());
+    Builder.Insert(ClonedTerminatorCopy);
+    ClonedTerminatorCopy->setCondition(ClonedCmpInstCopy);
+    ClonedTerminatorCopy->setSuccessor(1 - ExitingIndex, OrigPreHeader);
+    ClonedTerminatorCopy->setSuccessor(ExitingIndex, OrigExit);
+
+    DTUpdates.emplace_back(DominatorTree::Insert, NewClonedExit, OrigExit);
+
+    for (PHINode &ExitPhi : OrigExit->phis()) {
+      assert(ExitPhi.getNumIncomingValues() == 1 &&
+             "Expected exactly one incoming value");
+      Value *IncVal = ExitPhi.getIncomingValue(0);
+      Instruction *IncInst = dyn_cast(IncVal);
+      if (IncInst && OrigLoop.contains(IncInst->getParent())) {
+        Value *ClonedIncVal = VMap[IncInst];
+        PHINode *ResumePhi =
+            FindPhiWithIncomingValueInBB(ClonedIncVal, NewClonedExit);
+        if (!ResumePhi) {
+          Builder.SetInsertPoint(NewClonedExit,
+                                 NewClonedExit->getFirstInsertionPt());
+          ResumePhi =
+              Builder.CreatePHI(ClonedIncVal->getType(), 1,
+                                Twine(ClonedIncVal->getName()) + ".resume");
+          ResumePhi->addIncoming(ClonedIncVal, ClonedLatch);
+        }
+        IncVal = ResumePhi;
+      }
+      ExitPhi.addIncoming(IncVal, NewClonedExit);
+    }
+  }
+
+  DTU->applyUpdates(DTUpdates);
+
+#ifndef NDEBUG
+  // TODO: Remove once implementation stabilizes.
+  DTU->flush();
+  assert(DT.verify(DominatorTree::VerificationLevel::Fast));
+#endif
+
+  // MSSAU->updateExitBlocksForClonedLoop(ClonedUniqueExits.getArrayRef(),
+  //                                      VMap, DT);
+  MSSAU->applyUpdates(DTUpdates, DTU->getDomTree());
+  MSSAU->getMemorySSA()->verifyMemorySSA();
+
+  // Identify what noalias metadata is inside the loop: if it is inside the
+  // loop, the associated metadata must be cloned for each iteration.
+  SmallVector LoopLocalNoAliasDeclScopes;
+  identifyNoAliasScopesToClone(OrigLoop.getBlocks(),
+                               LoopLocalNoAliasDeclScopes);
+  // Identify what other metadata depends on the cloned version. After
+  // cloning, replace the metadata with the corrected version for both
+  // memory instructions and noalias intrinsics. Suffix may be a concatenated
+  // Twine, for which getSingleStringRef() would assert, so render it.
+  cloneAndAdaptNoAliasScopes(LoopLocalNoAliasDeclScopes, NewBlocks, Context,
+                             Suffix.str());
+
+  return ClonedLoop;
+}
+
+void MultiExitLoopVectorizer::updateAnalysis(PreservedAnalyses &Required) {
+  // DT should be flushed before running any analysis if we claim to preserve
+  // it.
+  if (GlobalPA.getChecker().preserved())
+    // Update DT/PDT.
+    DTU->flush();
+
+  FAM.invalidate(F, GlobalPA);
+
+  if (!GlobalPA.getChecker().preserved() &&
+      Required.getChecker().preserved()) {
+    llvm_unreachable("We don't update cached pointers to Loops at the moment");
+    LI = &FAM.getResult(F);
+    GlobalPA.preserve();
+  }
+  if (!GlobalPA.getChecker().preserved() &&
+      Required.getChecker().preserved()) {
+    SE = &FAM.getResult(F);
+    GlobalPA.preserve();
+  }
+  if (!GlobalPA.getChecker().preserved() &&
+      Required.getChecker().preserved()) {
+    AA = &FAM.getResult(F);
+    GlobalPA.preserve();
+  }
+  if (!GlobalPA.getChecker().preserved() &&
+      Required.getChecker().preserved()) {
+    auto &DT = FAM.getResult(F);
+    auto UpdateStrategy = DTU->isLazy() ?
DomTreeUpdater::UpdateStrategy::Lazy + : DomTreeUpdater::UpdateStrategy::Eager; + DTU = std::make_unique(DT, UpdateStrategy); + GlobalPA.preserve(); + } + if (!GlobalPA.getChecker().preserved() && + Required.getChecker().preserved()) { + auto &MSSA = FAM.getResult(F); + MSSAU = std::make_unique(&MSSA.getMSSA()); + } + if (!GlobalPA.getChecker().preserved() && + Required.getChecker().preserved()) { + AC = FAM.getCachedResult(F); + } +} + +bool MultiExitLoopVectorizer::lcssaLoop(Loop &L) { + SmallVector Loops{&L}; + return lcssaLoops(Loops); +} + +bool MultiExitLoopVectorizer::lcssaLoops(SmallVectorImpl &Loops) { + + PreservedAnalyses RequiredAnalysis; + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + + updateAnalysis(RequiredAnalysis); + + bool IsAnyLCSSAFormed = false; + for (auto *L : Loops) { + if (!L) + continue; + + IsAnyLCSSAFormed |= formLCSSA(*L, DTU->getDomTree(), LI, SE); + } + + if (!IsAnyLCSSAFormed) + return false; + + PreservedAnalyses LCSSAPA; + addCfgLocalAnalysis(LCSSAPA); + LCSSAPA.preserve(); + LCSSAPA.preserve(); + LCSSAPA.preserve(); + GlobalPA.intersect(LCSSAPA); + + updateAnalysis(getAllLocalAnalysis()); + + LLVM_DEBUG(dbgs() << "*** MELV intermediate IR After LoopLCSSA ***\n" << F); + assert(!verifyFunction(F, &dbgs()) && + "MELV: Function verification failed after LoopLCSSA"); + + return true; +} + +// Clone all loops. +// NOTE: Analysis invalidated only once. 
+bool MultiExitLoopVectorizer::cloneLoops( + const SmallVectorImpl &OrigLoops, + SmallVectorImpl &ClonedLoops, + SmallVectorImpl &ExitConds) { + + PreservedAnalyses RequiredAnalysis; + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + + updateAnalysis(RequiredAnalysis); + + bool AnyCloned = false; + for (Loop *L : OrigLoops) { + LoopBlocksRPO RPO(L); + RPO.perform(LI); + + SmallVector, 16> Conditions; + SmallVector AllExitEdges; + SmallVector NotComputableExitEdges; + L->getExitEdges(AllExitEdges); + + for (BlockEdgeT &ExitEdge : AllExitEdges) { + BasicBlock *ExitingBB = ExitEdge.first; + auto Branch = dyn_cast(ExitingBB->getTerminator()); + assert((Branch && Branch->isConditional()) && "Unsupported terminator"); + if (isa(SE->getExitCount(L, ExitingBB))) { + NotComputableExitEdges.push_back(std::move(ExitEdge)); + Conditions.emplace_back(Branch->getCondition(), + Branch->getSuccessor(0) == ExitEdge.second); + } + } + + assert(!Conditions.empty() && "No exit removed?"); + + // SmallSetVector HoistInsts; + SetVector HoistInsts; + Instruction *HoistPoint = &*L->getHeader()->getFirstInsertionPt(); + bool CanHoist = llvm::all_of(Conditions, [&](auto &C) { + return isSafeToHoistReq(HoistPoint, C.first, HoistInsts); + }); + + Loop *ClonedLoop = nullptr; + PHINode *ExitCond = nullptr; + if (CanHoist) { + ValueToValueMapTy VMap; + ClonedLoop = + cloneLoopAndRemoveExits(*L, RPO, AllExitEdges, NotComputableExitEdges, + L->getLoopPreheader(), ".melv", VMap); + ExitCond = hoist(*ClonedLoop, HoistInsts, Conditions, VMap); + AnyCloned = true; + } + ClonedLoops.emplace_back(ClonedLoop); + ExitConds.emplace_back(ExitCond); + } + + if (!AnyCloned) + return false; + + PreservedAnalyses CloningPA; + CloningPA.preserve(); + CloningPA.preserve(); + CloningPA.preserve(); + // Preserved? 
+ // CloningPA.preserve(); + GlobalPA.intersect(CloningPA); + + updateAnalysis(getAllLocalAnalysis()); + + LLVM_DEBUG(dbgs() << "*** MELV intermediate IR After LoopCloning ***\n" << F); + assert(!verifyFunction(F, &dbgs()) && + "MELV: Function verification failed after LoopCloning"); + return true; +} + +bool MultiExitLoopVectorizer::vectorizeLoops( + const SmallVectorImpl &Loops, SmallVectorImpl &VecLoops) { + assert(!EnableVPlanNativePath && + "Multi-exit loop vectorization is not supported in VPlan native path"); + + PreservedAnalyses RequiredAnalysis; + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + + auto FindVecLoop = [&](BasicBlock *Start, BasicBlock *End) -> Loop * { + for (auto BBIT = df_begin(Start), E = df_end(Start); BBIT != E;) { + const BasicBlock *BB = *BBIT; + + if (auto *L = LI->getLoopFor(BB)) + return L; + + if (BB == End) { + // Stop traversal when reaching NewHoistPt. + BBIT.skipChildren(); + continue; + } + ++BBIT; + } + return nullptr; + }; + + bool AnyVectorized = false; + for (Loop *L : Loops) { + + if (!L) + continue; + + assert(L->isInnermost() && L->isLCSSAForm(DTU->getDomTree()) && + "Loop has invalid form for LV"); + + // Update analysis before constructing LV. + updateAnalysis(RequiredAnalysis); + + LoopVectorizePass LV; + LV.SE = SE; + LV.LI = LI; + LV.TTI = &FAM.getResult(F); + LV.DT = &DTU->getDomTree(); + LV.BFI = &FAM.getResult(F); + LV.TLI = &FAM.getResult(F); + LV.DB = &FAM.getResult(F); + LV.AC = AC ? 
AC : &FAM.getResult(F); + LV.LAIs = &FAM.getResult(F); + LV.ORE = &FAM.getResult(F); + auto &MAMProxy = FAM.getResult(F); + LV.PSI = MAMProxy.getCachedResult(*F.getParent()); + + BasicBlock *OrigPreheader = L->getLoopPreheader(); + if (!LV.processLoop(L)) { + VecLoops.emplace_back(nullptr); + continue; + } + + AnyVectorized = true; + Loop *VecLoop = FindVecLoop(OrigPreheader, L->getLoopPreheader()); + assert(VecLoop && "Can't find vector body"); + assert(L != VecLoop && "Vectorized without scalar epilogue?"); + VecLoops.emplace_back(VecLoop); + + PreservedAnalyses LVPA; + LVPA.preserve(); + LVPA.preserve(); + GlobalPA.intersect(LVPA); + } + + if (!AnyVectorized) + return false; + + LLVM_DEBUG(dbgs() << "*** MELV intermediate IR After LoopVectorizer ***\n" + << F); + assert(!verifyFunction(F, &dbgs()) && + "MELV: Function verification failed after LoopVectorizer"); + + updateAnalysis(getAllLocalAnalysis()); + + return true; +} + +void MultiExitLoopVectorizer::postProcess(Loop &OrigLoop, Loop &ClonedLoop, + Loop &VecLoop, + PHINode &EarlyExitGuardLcssa) { + BasicBlock *OrigLoopPreheader = OrigLoop.getLoopPreheader(); + BasicBlock *ClonedLoopPreheader = ClonedLoop.getLoopPreheader(); + BasicBlock *ClonedLoopHeader = ClonedLoop.getHeader(); + BasicBlock *ClonedLoopLatch = ClonedLoop.getLoopLatch(); + BasicBlock *ClonedLoopExit = ClonedLoop.getUniqueExitBlock(); + BasicBlock *VecLoopPreheader = VecLoop.getLoopPreheader(); + BasicBlock *VecLoopHeader = VecLoop.getHeader(); + BasicBlock *VecLoopLatch = VecLoop.getLoopLatch(); + BasicBlock *VecLoopExit = VecLoop.getUniqueExitBlock(); + LLVMContext &Context = ClonedLoopHeader->getContext(); + IRBuilder<> Builder(Context); + SmallVector DTUpdates; + + // 1-1) Fix up outcoming values from scalar epilogue. 
+ { + Builder.SetInsertPoint(ClonedLoopPreheader, + ClonedLoopPreheader->getFirstInsertionPt()); + for (PHINode &OrigPhi : ClonedLoopExit->phis()) { + Type *OpType = OrigPhi.getType(); + Use *IncValUse = FindIncomingValueFrom(OrigPhi, ClonedLoopLatch, true); + Instruction *IncValInst = dyn_cast(IncValUse->get()); + if (IncValInst && ClonedLoop.contains(IncValInst->getParent())) { + // Find PHI in ClonedHeader with IVFromClonedLatch incoming value. + PHINode *PHIUser = + FindPhiWithIncomingValueInBB(IncValInst, ClonedLoopHeader); + Value *NewIncVal = nullptr; + if (PHIUser) { + NewIncVal = + FindIncomingValueFrom(*PHIUser, ClonedLoopPreheader, true)->get(); + } else { + assert(pred_size(ClonedLoopPreheader) == 2 && + "Cloned loop expected to have 2 predecessors"); + auto PredIt = pred_begin(ClonedLoopPreheader); + BasicBlock *NotVecLoopPred = + *PredIt != VecLoopExit ? *PredIt : *(++PredIt); + Use *VecLoopOutValUse = + FindIncomingValueFrom(OrigPhi, VecLoopExit, true); + Value *VecLoopOutVal = VecLoopOutValUse ? VecLoopOutValUse->get() + : UndefValue::get(OpType); + PHINode *MergePhi = Builder.CreatePHI( + OpType, 2, Twine(VecLoopOutVal->getName()) + ".vec.out.merge"); + MergePhi->addIncoming(UndefValue::get(OpType), NotVecLoopPred); + MergePhi->addIncoming(VecLoopOutVal, VecLoopExit); + NewIncVal = MergePhi; + } + IncValUse->set(NewIncVal); + } + OrigPhi.setIncomingBlock(*IncValUse, ClonedLoopPreheader); + } + } + + // 1-2) Forward execution from scalar epilogue to original version of + // scalar loop. 
+ { + ClonedLoopHeader->removePredecessor(ClonedLoopPreheader); + ClonedLoopPreheader->getTerminator()->setSuccessor(0, ClonedLoopExit); + ClonedLoopLatch->getTerminator()->eraseFromParent(); + Builder.SetInsertPoint(ClonedLoopLatch); + Builder.CreateUnreachable(); + + DTUpdates.emplace_back(DominatorTree::Delete, ClonedLoopPreheader, + ClonedLoopHeader); + DTUpdates.emplace_back(DominatorTree::Insert, ClonedLoopPreheader, + ClonedLoopExit); + DTUpdates.emplace_back(DominatorTree::Delete, ClonedLoopLatch, + ClonedLoopExit); + } + + // 2-2) Process exiting guard. + { + Value *GuardFromVecLoopExit = nullptr; + if (auto *GuardFromVecLoopExitUse = + FindIncomingValueFrom(EarlyExitGuardLcssa, VecLoopExit, true)) { + GuardFromVecLoopExit = GuardFromVecLoopExitUse->get(); + } else { + PHINode *GuardFromClonedLoopPreheader = cast( + FindIncomingValueFrom(EarlyExitGuardLcssa, ClonedLoopPreheader, true) + ->get()); + GuardFromVecLoopExit = + FindIncomingValueFrom(*GuardFromClonedLoopPreheader, VecLoopExit, + true) + ->get(); + } + EarlyExitGuardLcssa.eraseFromParent(); + + SmallVector EarlyExitGuardDeps; + Instruction *SplitPoint = &VecLoopHeader->front(); + collectDeps(*GuardFromVecLoopExit, VecLoop, SplitPoint, EarlyExitGuardDeps); + + assert(SplitPoint && "Can't find split point"); + BasicBlock *SplitBB = SplitPoint->getParent(); + BasicBlock::iterator SplitPointIt = ++SplitPoint->getIterator(); + BasicBlock *SplitBBCont = SplitBB; + if (SplitPointIt != SplitBB->end()) + SplitBBCont = + SplitBlock(SplitBB, &*SplitPointIt, DTU.get(), LI, MSSAU.get(), + Twine(SplitBB->getName()) + ".cont"); + for (Instruction *I : EarlyExitGuardDeps) + I->moveAfter(SplitPoint); + + // Remove old terminator and insert a new one. + SplitBB->getTerminator()->eraseFromParent(); + Builder.SetInsertPoint(SplitBB); + Builder.CreateCondBr(GuardFromVecLoopExit, VecLoopExit, SplitBBCont); + // Update Header/Latch as they can change. 
+ VecLoopHeader = VecLoop.getHeader(); + VecLoopLatch = VecLoop.getLoopLatch(); + } + + // 2-2) Fix up outcoming values from VecLoop. + { + for (PHINode &OrigPHI : VecLoopExit->phis()) { + Value *IncVal = FindIncomingValueFrom(OrigPHI, VecLoopLatch, true)->get(); + Instruction *IncValInst = dyn_cast(IncVal); + if (IncValInst && IncValInst->getParent() != VecLoopHeader) + // Find PHI in VecLoopHeader with IncVal incoming value. + if (PHINode *HeaderPHI = + FindPhiWithIncomingValueInBB(IncVal, VecLoopHeader)) + IncVal = HeaderPHI; + OrigPHI.addIncoming(IncVal, VecLoopHeader); + } + } + + // 2-3) + { + ICmpInst *LatchCmp = VecLoop.getLatchCmpInst(); + PHINode *IndVar = VecLoop.getInductionVariable(*SE); + Value *LoopBound = LatchCmp->getOperand(0); + if (!VecLoop.isLoopInvariant(LoopBound)) { + LoopBound = LatchCmp->getOperand(1); + assert(VecLoop.isLoopInvariant(LoopBound)); + } + Builder.SetInsertPoint(VecLoopExit, VecLoopExit->begin()); + auto *FinalVal = Builder.CreatePHI(IndVar->getType(), 2, + Twine(IndVar->getName()) + ".vec.out"); + FinalVal->addIncoming(IndVar, VecLoop.getHeader()); + FinalVal->addIncoming(LoopBound, VecLoop.getLoopLatch()); + replaceAllReachableUsesWith(LoopBound, FinalVal); + } + + DTU->applyUpdates(DTUpdates); + MSSAU->applyUpdates(DTUpdates, DTU->getDomTree()); +} + +bool MultiExitLoopVectorizer::postProcessLoops( + const SmallVectorImpl &OrigLoops, + const SmallVectorImpl &ClonedLoops, + const SmallVectorImpl &VecLoops, + const SmallVectorImpl &ExitConds) { + PreservedAnalyses RequiredAnalysis; + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + RequiredAnalysis.preserve(); + + updateAnalysis(RequiredAnalysis); + + for (unsigned i = 0; i < OrigLoops.size(); ++i) { + Loop *OrigLoop = OrigLoops[i]; + Loop *ClonedLoop = ClonedLoops[i]; + Loop *VecLoop = VecLoops[i]; + PHINode *ExitCond = ExitConds[i]; + if (!VecLoop) + continue; + + postProcess(*OrigLoop, *ClonedLoop, *VecLoop, *ExitCond); + } + 
return true; +} + +} // namespace llvm