diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h --- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h +++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h @@ -43,9 +43,9 @@ - /// It uses runtime check provided by the user. If \p UseLAIChecks is true, - /// we will retain the default checks made by LAI. Otherwise, construct an - /// object having no checks and we expect the user to add them. - LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, - DominatorTree *DT, ScalarEvolution *SE, - bool UseLAIChecks = true); + /// It uses the runtime alias checks in \p Checks (copied into this object) + /// and the SCEV union predicate of \p LAI, which is referenced rather than + /// copied, so \p LAI must outlive this object. + LoopVersioning(const LoopAccessInfo &LAI, + ArrayRef<RuntimePointerCheck> Checks, Loop *L, LoopInfo *LI, + DominatorTree *DT, ScalarEvolution *SE); /// Performs the CFG manipulation part of versioning the loop including /// the DominatorTree and LoopInfo updates. @@ -75,12 +75,6 @@ /// loop may alias (i.e. one of the memchecks failed). Loop *getNonVersionedLoop() { return NonVersionedLoop; } - /// Sets the runtime alias checks for versioning the loop. - void setAliasChecks(ArrayRef<RuntimePointerCheck> Checks); - - /// Sets the runtime SCEV checks for versioning the loop. - void setSCEVChecks(SCEVUnionPredicate Check); - /// Annotate memory instructions in the versioned loop with no-alias /// metadata based on the memchecks issued. /// @@ -129,7 +123,7 @@ SmallVector<RuntimePointerCheck, 4> AliasChecks; /// The set of SCEV checks that we are versioning for. - SCEVUnionPredicate Preds; + const SCEVUnionPredicate &Preds; /// Maps a pointer to the pointer checking group that the pointer /// belongs to.
diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -814,9 +814,7 @@ LLVM_DEBUG(dbgs() << "\nPointers:\n"); LLVM_DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); - LoopVersioning LVer(*LAI, L, LI, DT, SE, false); - LVer.setAliasChecks(std::move(Checks)); - LVer.setSCEVChecks(LAI->getPSE().getUnionPredicate()); + LoopVersioning LVer(*LAI, Checks, L, LI, DT, SE); LVer.versionLoop(DefsUsedOutside); LVer.annotateLoopWithNoAlias(); diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -559,9 +559,7 @@ // Point of no-return, start the transformation. First, version the loop // if necessary. - LoopVersioning LV(LAI, L, LI, DT, PSE.getSE(), false); - LV.setAliasChecks(std::move(Checks)); - LV.setSCEVChecks(LAI.getPSE().getUnionPredicate()); + LoopVersioning LV(LAI, Checks, L, LI, DT, PSE.getSE()); LV.versionLoop(); } diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -603,7 +603,8 @@ // Create memcheck for memory accessed inside loop. // Clone original loop, and set blocks properly. DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LoopVersioning LVer(*LAI, CurLoop, LI, DT, SE, true); + LoopVersioning LVer(*LAI, LAI->getRuntimePointerChecking()->getChecks(), + CurLoop, LI, DT, SE); LVer.versionLoop(); // Set Loop Versioning metaData for original loop.
addStringMetadataToLoop(LVer.getNonVersionedLoop(), LICMVersioningMetaData); diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -32,25 +32,16 @@ cl::desc("Add no-alias annotation for instructions that " "are disambiguated by memchecks")); -LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, - DominatorTree *DT, ScalarEvolution *SE, - bool UseLAIChecks) - : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT), +LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, + ArrayRef<RuntimePointerCheck> Checks, Loop *L, + LoopInfo *LI, DominatorTree *DT, + ScalarEvolution *SE) + : VersionedLoop(L), NonVersionedLoop(nullptr), + AliasChecks(Checks.begin(), Checks.end()), + Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT), SE(SE) { assert(L->getExitBlock() && "No single exit block"); assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form"); - if (UseLAIChecks) { - setAliasChecks(LAI.getRuntimePointerChecking()->getChecks()); - setSCEVChecks(LAI.getPSE().getUnionPredicate()); - } -} - -void LoopVersioning::setAliasChecks(ArrayRef<RuntimePointerCheck> Checks) { - AliasChecks = {Checks.begin(), Checks.end()}; -} - -void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) { - Preds = std::move(Check); } void LoopVersioning::versionLoop( @@ -67,11 +58,10 @@ addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop, AliasChecks, RtPtrChecking.getSE()); - const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate(); SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), "scev.check"); SCEVRuntimeCheck = - Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator()); + Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator()); auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck); // Discard the SCEV runtime check if it is always true.
@@ -286,7 +276,8 @@ if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() && (LAI.getNumRuntimePointerChecks() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { - LoopVersioning LVer(LAI, L, LI, DT, SE); + LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), + L, LI, DT, SE); LVer.versionLoop(); LVer.annotateLoopWithNoAlias(); Changed = true; diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -2852,8 +2852,10 @@ // We currently don't use LoopVersioning for the actual loop cloning but we // still use it to add the noalias metadata. - LVer = std::make_unique<LoopVersioning>(*Legal->getLAI(), OrigLoop, LI, DT, - PSE.getSE()); + LVer = std::make_unique<LoopVersioning>( + *Legal->getLAI(), + Legal->getLAI()->getRuntimePointerChecking()->getChecks(), OrigLoop, LI, + DT, PSE.getSE()); LVer->prepareNoAliasMetadata(); }