diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -255,7 +255,7 @@
 void initializeLoopUnswitchPass(PassRegistry&);
 void initializeLoopVectorizePass(PassRegistry&);
 void initializeLoopVersioningLICMPass(PassRegistry&);
-void initializeLoopVersioningPassPass(PassRegistry&);
+void initializeLoopVersioningLegacyPassPass(PassRegistry &);
 void initializeLowerAtomicLegacyPassPass(PassRegistry&);
 void initializeLowerConstantIntrinsicsPass(PassRegistry&);
 void initializeLowerEmuTLSPass(PassRegistry&);
diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
--- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
@@ -16,6 +16,7 @@
 #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H
 
 #include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 
@@ -148,6 +149,14 @@
   DominatorTree *DT;
   ScalarEvolution *SE;
 };
+
+/// Expose LoopVersioning as a pass. Currently this is only used for
+/// unit-testing. It adds all memchecks necessary to remove all may-aliasing
+/// array accesses from the loop.
+class LoopVersioningPass : public PassInfoMixin<LoopVersioningPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
+};
 }
 
 #endif
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -186,6 +186,7 @@
 #include "llvm/Transforms/Utils/LCSSA.h"
 #include "llvm/Transforms/Utils/LibCallsShrinkWrap.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LoopVersioning.h"
 #include "llvm/Transforms/Utils/LowerInvoke.h"
 #include "llvm/Transforms/Utils/Mem2Reg.h"
 #include "llvm/Transforms/Utils/NameAnonGlobals.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -225,6 +225,7 @@
 FUNCTION_PASS("loop-load-elim", LoopLoadEliminationPass())
 FUNCTION_PASS("loop-fusion", LoopFusePass())
 FUNCTION_PASS("loop-distribute", LoopDistributePass())
+FUNCTION_PASS("loop-versioning", LoopVersioningPass())
 FUNCTION_PASS("pgo-memop-opt", PGOMemOPSizeOpt())
 FUNCTION_PASS("print", PrintFunctionPass(dbgs()))
 FUNCTION_PASS("print<assumptions>", AssumptionPrinterPass(dbgs()))
diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -110,7 +110,7 @@
   initializeLoopDistributeLegacyPass(Registry);
   initializeLoopLoadEliminationPass(Registry);
   initializeLoopSimplifyCFGLegacyPassPass(Registry);
-  initializeLoopVersioningPassPass(Registry);
+  initializeLoopVersioningLegacyPassPass(Registry);
   initializeEntryExitInstrumenterPass(Registry);
   initializePostInlineEntryExitInstrumenterPass(Registry);
 }
diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
--- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
@@ -16,10 +16,13 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Scalar/LoopPassManager.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
@@ -253,47 +256,57 @@
 }
 
 namespace {
+/// Version each qualifying innermost loop; returns true if any was versioned.
+bool runImpl(LoopInfo *LI, std::function<const LoopAccessInfo &(Loop &)> GetLAA,
+             DominatorTree *DT, ScalarEvolution *SE) {
+  // Build up a worklist of inner-loops to version. This is necessary as the
+  // act of versioning a loop creates new loops and can invalidate iterators
+  // across the loops.
+  SmallVector<Loop *, 8> Worklist;
+
+  for (Loop *TopLevelLoop : *LI)
+    for (Loop *L : depth_first(TopLevelLoop))
+      // We only handle inner-most loops.
+      if (L->empty())
+        Worklist.push_back(L);
+
+  // Now walk the identified inner loops.
+  bool Changed = false;
+  for (Loop *L : Worklist) {
+    const LoopAccessInfo &LAI = GetLAA(*L);
+    if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() &&
+        (LAI.getNumRuntimePointerChecks() ||
+         !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
+      LoopVersioning LVer(LAI, L, LI, DT, SE);
+      LVer.versionLoop();
+      LVer.annotateLoopWithNoAlias();
+      Changed = true;
+    }
+  }
+
+  return Changed;
+}
+
 /// Also expose this is a pass. Currently this is only used for
 /// unit-testing. It adds all memchecks necessary to remove all may-aliasing
 /// array accesses from the loop.
-class LoopVersioningPass : public FunctionPass {
+class LoopVersioningLegacyPass : public FunctionPass {
 public:
-  LoopVersioningPass() : FunctionPass(ID) {
-    initializeLoopVersioningPassPass(*PassRegistry::getPassRegistry());
+  LoopVersioningLegacyPass() : FunctionPass(ID) {
+    initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry());
   }
 
   bool runOnFunction(Function &F) override {
     auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-    auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
+    std::function<const LoopAccessInfo &(Loop &)> GetLAA =
+        [&](Loop &L) -> const LoopAccessInfo & {
+      return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L);
+    };
+
     auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
 
-    // Build up a worklist of inner-loops to version. This is necessary as the
-    // act of versioning a loop creates new loops and can invalidate iterators
-    // across the loops.
-    SmallVector<Loop *, 8> Worklist;
-
-    for (Loop *TopLevelLoop : *LI)
-      for (Loop *L : depth_first(TopLevelLoop))
-        // We only handle inner-most loops.
-        if (L->empty())
-          Worklist.push_back(L);
-
-    // Now walk the identified inner loops.
-    bool Changed = false;
-    for (Loop *L : Worklist) {
-      const LoopAccessInfo &LAI = LAA->getInfo(L);
-      if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() &&
-          (LAI.getNumRuntimePointerChecks() ||
-           !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
-        LoopVersioning LVer(LAI, L, LI, DT, SE);
-        LVer.versionLoop();
-        LVer.annotateLoopWithNoAlias();
-        Changed = true;
-      }
-    }
-
-    return Changed;
+    return runImpl(LI, GetLAA, DT, SE);
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -312,18 +325,48 @@
 #define LVER_OPTION "loop-versioning"
 #define DEBUG_TYPE LVER_OPTION
 
-char LoopVersioningPass::ID;
+char LoopVersioningLegacyPass::ID;
 static const char LVer_name[] = "Loop Versioning";
 
-INITIALIZE_PASS_BEGIN(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
+INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
+                      false)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
-INITIALIZE_PASS_END(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
+INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false,
+                    false)
 
 namespace llvm {
-FunctionPass *createLoopVersioningPass() {
-  return new LoopVersioningPass();
+FunctionPass *createLoopVersioningLegacyPass() {
+  return new LoopVersioningLegacyPass();
+}
+
+/// New pass manager entry point. Gathers the function-level analyses that
+/// LoopAccessAnalysis needs per loop, then versions the loops via runImpl().
+PreservedAnalyses LoopVersioningPass::run(Function &F,
+                                          FunctionAnalysisManager &AM) {
+  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto &LI = AM.getResult<LoopAnalysis>(F);
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &AA = AM.getResult<AAManager>(F);
+  auto &AC = AM.getResult<AssumptionAnalysis>(F);
+  MemorySSA *MSSA = EnableMSSALoopDependency
+                        ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA()
+                        : nullptr;
+
+  auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
+  std::function<const LoopAccessInfo &(Loop &)> GetLAA =
+      [&](Loop &L) -> const LoopAccessInfo & {
+    LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI, MSSA};
+    return LAM.getResult<LoopAccessAnalysis>(L, AR);
+  };
+
+  // Versioning creates new loops, so claim no analysis as preserved.
+  if (runImpl(&LI, GetLAA, &DT, &SE))
+    return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
 }
 }