Index: include/llvm/Transforms/Utils/LoopVersioning.h =================================================================== --- include/llvm/Transforms/Utils/LoopVersioning.h +++ include/llvm/Transforms/Utils/LoopVersioning.h @@ -17,6 +17,7 @@ #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H #include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -25,6 +26,7 @@ class Loop; class LoopAccessInfo; class LoopInfo; +class ScalarEvolution; /// \brief This class emits a version of the loop where run-time checks ensure /// that may-alias pointers can't overlap. @@ -33,16 +35,13 @@ /// already has a preheader. class LoopVersioning { public: - /// \brief Expects MemCheck, LoopAccessInfo, Loop, LoopInfo, DominatorTree - /// as input. It uses runtime check provided by user. - LoopVersioning(SmallVector Checks, - const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, - DominatorTree *DT); - /// \brief Expects LoopAccessInfo, Loop, LoopInfo, DominatorTree as input. - /// It uses default runtime check provided by LoopAccessInfo. - LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, LoopInfo *LI, - DominatorTree *DT); + /// It uses runtime checks provided by the user. If \p UseLAIChecks is true, + /// the alias and SCEV checks computed by LAI are added up front. Otherwise + /// the object starts with no checks and the user is expected to add them. + LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, + DominatorTree *DT, ScalarEvolution *SE, + bool UseLAIChecks = true); /// \brief Performs the CFG manipulation part of versioning the loop including /// the DominatorTree and LoopInfo updates. @@ -72,6 +71,13 @@ /// loop may alias (i.e. one of the memchecks failed). Loop *getNonVersionedLoop() { return NonVersionedLoop; } + /// \brief Appends \p Checks to the runtime alias checks we version for.
+ void addAliasChecks( + const SmallVectorImpl &Checks); + + /// \brief Adds the SCEV predicates in \p Check to those we version for. + void addSCEVChecks(const SCEVUnionPredicate &Check); + private: /// \brief Adds the necessary PHI nodes for the versioned loops based on the /// loop-defined values used outside of the loop. @@ -91,13 +97,17 @@ /// in NonVersionedLoop. ValueToValueMapTy VMap; - /// \brief The set of checks that we are versioning for. + /// \brief The set of alias checks that we are versioning for. SmallVector Checks; + /// \brief The set of SCEV checks that we are versioning for. + SCEVUnionPredicate Preds; + /// \brief Analyses used. const LoopAccessInfo &LAI; LoopInfo *LI; DominatorTree *DT; + ScalarEvolution *SE; }; } Index: lib/Transforms/Scalar/LoopDistribute.cpp =================================================================== --- lib/Transforms/Scalar/LoopDistribute.cpp +++ lib/Transforms/Scalar/LoopDistribute.cpp @@ -577,6 +577,7 @@ LI = &getAnalysis().getLoopInfo(); LAA = &getAnalysis(); DT = &getAnalysis().getDomTree(); + SE = &getAnalysis().getSE(); // Build up a worklist of inner-loops to vectorize. This is necessary as the // act of distributing a loop creates new loops and can invalidate iterators @@ -599,6 +600,7 @@ } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); AU.addRequired(); AU.addPreserved(); AU.addRequired(); @@ -764,17 +766,20 @@ if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator()) SplitBlock(PH, PH->getTerminator(), DT, LI); - // If we need run-time checks to disambiguate pointers are run-time, version - // the loop now. + // If we need run-time alias or SCEV checks, version the loop now.
auto PtrToPartition = Partitions.computePartitionSetForPointers(LAI); const auto *RtPtrChecking = LAI.getRuntimePointerChecking(); const auto &AllChecks = RtPtrChecking->getChecks(); auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition, RtPtrChecking); - if (!Checks.empty()) { + + const SCEVUnionPredicate &Pred = LAI.Preds; + if (!Pred.isAlwaysTrue() || !Checks.empty()) { DEBUG(dbgs() << "\nPointers:\n"); DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); - LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT); + LoopVersioning LVer(LAI, L, LI, DT, SE, false); + LVer.addAliasChecks(Checks); + LVer.addSCEVChecks(LAI.Preds); LVer.versionLoop(DefsUsedOutside); } @@ -801,6 +806,7 @@ LoopInfo *LI; LoopAccessAnalysis *LAA; DominatorTree *DT; + ScalarEvolution *SE; }; } // anonymous namespace @@ -811,6 +817,7 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false) namespace llvm { Index: lib/Transforms/Utils/LoopVersioning.cpp =================================================================== --- lib/Transforms/Utils/LoopVersioning.cpp +++ lib/Transforms/Utils/LoopVersioning.cpp @@ -17,46 +17,78 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/IR/Dominators.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" using namespace llvm; -LoopVersioning::LoopVersioning( - SmallVector Checks, - const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT) - : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)), - LAI(LAI), LI(LI), DT(DT) { +LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, + DominatorTree *DT,
ScalarEvolution *SE, + bool UseLAIChecks) + : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT), + SE(SE) { assert(L->getExitBlock() && "No single exit block"); assert(L->getLoopPreheader() && "No preheader"); + if (UseLAIChecks) { + addAliasChecks(LAI.getRuntimePointerChecking()->getChecks()); + addSCEVChecks(LAI.Preds); + } } -LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, - LoopInfo *LI, DominatorTree *DT) - : VersionedLoop(L), NonVersionedLoop(nullptr), - Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo), - LI(LI), DT(DT) { - assert(L->getExitBlock() && "No single exit block"); - assert(L->getLoopPreheader() && "No preheader"); +void LoopVersioning::addAliasChecks( + const SmallVectorImpl &Checks) { + std::copy(Checks.begin(), Checks.end(), std::back_inserter(this->Checks)); +} + +void LoopVersioning::addSCEVChecks(const SCEVUnionPredicate &Check) { + Preds.add(&Check); } void LoopVersioning::versionLoop( const SmallVectorImpl &DefsUsedOutside) { Instruction *FirstCheckInst; Instruction *MemRuntimeCheck; + Value *SCEVRuntimeCheck; + Value *RuntimeCheck = nullptr; + // Add the memcheck in the original preheader (this is empty initially). - BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader(); + BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); std::tie(FirstCheckInst, MemRuntimeCheck) = - LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks); + LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), Checks); - assert(MemRuntimeCheck && "called even though needsAnyChecking = false"); + // MemRuntimeCheck may be null for SCEV-only versioning; checked below. + const SCEVUnionPredicate &Pred = Preds; + SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), + "scev.check"); + SCEVRuntimeCheck = + Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator()); + auto *CI = dyn_cast(SCEVRuntimeCheck); + + // Discard the SCEV runtime check if it is always true.
+ if (CI && CI->isZero()) + SCEVRuntimeCheck = nullptr; + + if (MemRuntimeCheck && SCEVRuntimeCheck) { + RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck, + SCEVRuntimeCheck, "ldist.safe"); + if (auto *I = dyn_cast(RuntimeCheck)) + I->insertBefore(RuntimeCheckBB->getTerminator()); + } else + RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck; + + assert(RuntimeCheck && "called even though we don't need " + "any runtime checks"); + // Rename the block to make the IR more readable. - MemCheckBB->setName(VersionedLoop->getHeader()->getName() + ".lver.memcheck"); + RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() + + ".lver.check"); // Create empty preheader for the loop (and after cloning for the // non-versioned loop). - BasicBlock *PH = SplitBlock(MemCheckBB, MemCheckBB->getTerminator(), DT, LI); + BasicBlock *PH = + SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI); PH->setName(VersionedLoop->getHeader()->getName() + ".ph"); // Clone the loop including the preheader. @@ -65,20 +97,19 @@ // block is a join between the two loops. SmallVector NonVersionedLoopBlocks; NonVersionedLoop = - cloneLoopWithPreheader(PH, MemCheckBB, VersionedLoop, VMap, ".lver.orig", - LI, DT, NonVersionedLoopBlocks); + cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap, + ".lver.orig", LI, DT, NonVersionedLoopBlocks); remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap); - // Insert the conditional branch based on the result of the memchecks. + // Insert the conditional branch based on the result of the runtime checks. - Instruction *OrigTerm = MemCheckBB->getTerminator(); + Instruction *OrigTerm = RuntimeCheckBB->getTerminator(); BranchInst::Create(NonVersionedLoop->getLoopPreheader(), - VersionedLoop->getLoopPreheader(), MemRuntimeCheck, - OrigTerm); + VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm); OrigTerm->eraseFromParent(); // The loops merge in the original exit block. This is now dominated by the - // memchecking block. + // runtime check block.
- DT->changeImmediateDominator(VersionedLoop->getExitBlock(), MemCheckBB); + DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB); // Adds the necessary PHI nodes for the versioned loops based on the // loop-defined values used outside of the loop. Index: test/Transforms/LoopDistribute/basic-with-memchecks.ll =================================================================== --- test/Transforms/LoopDistribute/basic-with-memchecks.ll +++ test/Transforms/LoopDistribute/basic-with-memchecks.ll @@ -36,7 +36,7 @@ ; Since the checks to A and A + 4 get merged, this will give us a ; total of 8 compares. ; -; CHECK: for.body.lver.memcheck: +; CHECK: for.body.lver.check: ; CHECK: = icmp ; CHECK: = icmp