Add SCEV predicate run-time checks to LoopVersioning.

LoopVersioning now takes a ScalarEvolution pointer so that it can expand the
SCEV predicates recorded in LoopAccessInfo (LAI.Preds) and OR the result with
the memory run-time check. LoopDistribute versions the loop when either the
cross-partition memchecks or the SCEV predicate require it.

Index: include/llvm/Transforms/Utils/LoopVersioning.h
===================================================================
--- include/llvm/Transforms/Utils/LoopVersioning.h
+++ include/llvm/Transforms/Utils/LoopVersioning.h
@@ -25,6 +25,7 @@
 class Loop;
 class LoopAccessInfo;
 class LoopInfo;
+class ScalarEvolution;
 
 /// \brief This class emits a version of the loop where run-time checks ensure
 /// that may-alias pointers can't overlap.
@@ -37,12 +38,12 @@
   /// as input. It uses runtime check provided by user.
   LoopVersioning(SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
                  const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
-                 DominatorTree *DT);
+                 DominatorTree *DT, ScalarEvolution *SE);
 
   /// \brief Expects LoopAccessInfo, Loop, LoopInfo, DominatorTree as input.
   /// It uses default runtime check provided by LoopAccessInfo.
   LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, LoopInfo *LI,
-                 DominatorTree *DT);
+                 DominatorTree *DT, ScalarEvolution *SE);
 
   /// \brief Performs the CFG manipulation part of versioning the loop including
   /// the DominatorTree and LoopInfo updates.
@@ -98,6 +99,7 @@
   const LoopAccessInfo &LAI;
   LoopInfo *LI;
   DominatorTree *DT;
+  ScalarEvolution *SE;
 };
 }
 
Index: lib/Transforms/Scalar/LoopDistribute.cpp
===================================================================
--- lib/Transforms/Scalar/LoopDistribute.cpp
+++ lib/Transforms/Scalar/LoopDistribute.cpp
@@ -577,6 +577,7 @@
     LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
     LAA = &getAnalysis<LoopAccessAnalysis>();
     DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
 
     // Build up a worklist of inner-loops to vectorize. This is necessary as the
     // act of distributing a loop creates new loops and can invalidate iterators
@@ -599,6 +600,7 @@
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<ScalarEvolutionWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addPreserved<LoopInfoWrapperPass>();
     AU.addRequired<LoopAccessAnalysis>();
@@ -772,10 +774,12 @@
     const auto &AllChecks = RtPtrChecking->getChecks();
     auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition,
                                                   RtPtrChecking);
-    if (!Checks.empty()) {
+
+    const SCEVUnionPredicate &Pred = LAI.Preds;
+    if ((!Pred.isAlwaysTrue()) || !Checks.empty()) {
       DEBUG(dbgs() << "\nPointers:\n");
       DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
-      LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT);
+      LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT, SE);
       LVer.versionLoop(DefsUsedOutside);
     }
 
@@ -802,6 +806,7 @@
   LoopInfo *LI;
   LoopAccessAnalysis *LAA;
   DominatorTree *DT;
+  ScalarEvolution *SE;
 };
 } // anonymous namespace
 
@@ -812,6 +817,7 @@
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
 INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false)
 
 namespace llvm {
Index: lib/Transforms/Utils/LoopVersioning.cpp
===================================================================
--- lib/Transforms/Utils/LoopVersioning.cpp
+++ lib/Transforms/Utils/LoopVersioning.cpp
@@ -17,6 +17,7 @@
 
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
@@ -25,18 +26,20 @@
 
 LoopVersioning::LoopVersioning(
     SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
-    const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT)
+    const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT,
+    ScalarEvolution *SE)
     : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)),
-      LAI(LAI), LI(LI), DT(DT) {
+      LAI(LAI), LI(LI), DT(DT), SE(SE) {
   assert(L->getExitBlock() && "No single exit block");
   assert(L->getLoopPreheader() && "No preheader");
 }
 
 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L,
-                               LoopInfo *LI, DominatorTree *DT)
+                               LoopInfo *LI, DominatorTree *DT,
+                               ScalarEvolution *SE)
     : VersionedLoop(L), NonVersionedLoop(nullptr),
       Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo),
-      LI(LI), DT(DT) {
+      LI(LI), DT(DT), SE(SE) {
   assert(L->getExitBlock() && "No single exit block");
   assert(L->getLoopPreheader() && "No preheader");
 }
@@ -45,12 +48,36 @@
     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
   Instruction *FirstCheckInst;
   Instruction *MemRuntimeCheck;
+  Value *SCEVRuntimeCheck;
+  Value *RuntimeCheck = nullptr;
+
   // Add the memcheck in the original preheader (this is empty initially).
   BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader();
   std::tie(FirstCheckInst, MemRuntimeCheck) =
       LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks);
-  assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
+  // MemRuntimeCheck may be null when only the SCEV predicate needs a check.
 
+  const SCEVUnionPredicate &Pred = LAI.Preds;
+  SCEVExpander Exp(*SE, MemCheckBB->getModule()->getDataLayout(), "scev.check");
+  SCEVRuntimeCheck = Exp.expandCodeForPredicate(&Pred,
+                                                MemCheckBB->getTerminator());
+  auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
+
+  // Discard the SCEV runtime check if it is always true.
+  if (CI && CI->isZero())
+    SCEVRuntimeCheck = nullptr;
+
+  if (MemRuntimeCheck && SCEVRuntimeCheck) {
+    RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
+                                          SCEVRuntimeCheck, "ldist.safe");
+    if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
+      I->insertBefore(MemCheckBB->getTerminator());
+  } else
+    RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
+
+  assert(RuntimeCheck && "called even though we don't need "
+                         "any runtime checks");
+
   // Rename the block to make the IR more readable.
   MemCheckBB->setName(VersionedLoop->getHeader()->getName() +
                       ".lver.memcheck");
@@ -72,8 +99,7 @@
   // Insert the conditional branch based on the result of the memchecks.
   Instruction *OrigTerm = MemCheckBB->getTerminator();
   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
-                     VersionedLoop->getLoopPreheader(), MemRuntimeCheck,
-                     OrigTerm);
+                     VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
   OrigTerm->eraseFromParent();
 
   // The loops merge in the original exit block. This is now dominated by the