Index: include/llvm/Analysis/LoopAccessAnalysis.h
===================================================================
--- include/llvm/Analysis/LoopAccessAnalysis.h
+++ include/llvm/Analysis/LoopAccessAnalysis.h
@@ -525,23 +525,31 @@
   /// If the client speculates (and then issues run-time checks) for the values
   /// of symbolic strides, \p Strides provides the mapping (see
   /// replaceSymbolicStrideSCEV). If there is no cached result available run
-  /// the analysis.
-  const LoopAccessInfo &getInfo(Loop *L, const ValueToValueMap &Strides);
+  /// the analysis. If UseAssumptions is true, the user must add the
+  /// run-time checks collected by the AssumingScalarEvolution pass if
+  /// making any changes to the loop.
+  const LoopAccessInfo &getInfo(Loop *L, const ValueToValueMap &Strides,
+                                bool UseAssumptions = false);
 
   void releaseMemory() override {
     // Invalidate the cache when the pass is freed.
     LoopAccessInfoMap.clear();
+    AssumeLoopAccessInfoMap.clear();
   }
 
   /// \brief Print the result of the analysis when invoked with -analyze.
   void print(raw_ostream &OS, const Module *M = nullptr) const override;
 
 private:
-  /// \brief The cache.
+  /// \brief The cache for answers computed without using the
+  /// AssumingScalarEvolution pass.
   DenseMap<Loop *, std::unique_ptr<LoopAccessInfo>> LoopAccessInfoMap;
-
+  /// \brief A separate cache used for answers computed using the
+  /// AssumingScalarEvolution pass.
+  DenseMap<Loop *, std::unique_ptr<LoopAccessInfo>> AssumeLoopAccessInfoMap;
 
   // The used analysis passes.
   ScalarEvolution *SE;
+  ScalarEvolution *ASE;
   const TargetLibraryInfo *TLI;
   AliasAnalysis *AA;
   DominatorTree *DT;
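
The new UseAssumptions flag changes the caller's contract: a result obtained
with it is only valid if the run-time checks collected by
AssumingScalarEvolution are emitted before the loop is modified. A minimal
sketch of a hypothetical caller follows (only getInfo, setAnalyzedLoop,
canPerformOverflowChecks and addRuntimeOverflowChecks come from this patch;
the function and variable names are illustrative):

    // Sketch of a client honoring the UseAssumptions contract (not part of
    // the patch; assumes LAA and ASE were obtained from the pass manager).
    #include "llvm/Analysis/LoopAccessAnalysis.h"
    #include "llvm/Analysis/ScalarEvolution.h"
    using namespace llvm;

    static bool tryTransform(Loop *L, LoopAccessAnalysis *LAA,
                             AssumingScalarEvolution *ASE,
                             const ValueToValueMap &Strides) {
      ASE->setAnalyzedLoop(L);
      // Result computed on top of the assuming SCEV; note that it is
      // cached separately from the assumption-free result.
      const LoopAccessInfo &LAI = LAA->getInfo(L, Strides,
                                               /*UseAssumptions=*/true);
      if (!LAI.canVectorizeMemory() || !ASE->canPerformOverflowChecks())
        return false; // We could not guard the assumptions at run time.
      // Obligation: materialize the checks before touching the loop.
      ASE->addRuntimeOverflowChecks(L->getLoopPreheader()->getTerminator());
      // ... version the loop on the emitted check, then transform ...
      return true;
    }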
Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -581,6 +581,7 @@
   public:
     static char ID; // Pass identification, replacement for typeid
     ScalarEvolution();
+    ScalarEvolution(bool AddToRegistry, char &ID);
 
     LLVMContext &getContext() const { return F->getContext(); }
 
@@ -602,14 +603,16 @@
     /// getSCEV - Return a SCEV expression for the full generality of the
     /// specified expression.
-    const SCEV *getSCEV(Value *V);
+    virtual const SCEV *getSCEV(Value *V);
 
     const SCEV *getConstant(ConstantInt *V);
     const SCEV *getConstant(const APInt& Val);
     const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false);
     const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty);
-    const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty);
-    const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty);
+
+    virtual const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty);
+    virtual const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty);
+
     const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty);
     const SCEV *getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                            SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
@@ -837,7 +840,7 @@
     /// changed a loop in a way that may effect ScalarEvolution's ability to
     /// compute a trip count, or if the loop is deleted. This call is
     /// potentially expensive for large loop bodies.
-    void forgetLoop(const Loop *L);
+    virtual void forgetLoop(const Loop *L);
 
     /// forgetValue - This method should be called by the client when it has
     /// changed a value in a way that may effect its value, or which may
@@ -981,6 +984,93 @@
     /// to locate them all and call their destructors.
     SCEVUnknown *FirstUnknown;
   };
+
+  /// AssumingScalarEvolution - a version of ScalarEvolution that will
+  /// make and track assumptions. The rationale behind this is that most
+  /// users of ScalarEvolution need a chrec returned as a result in order
+  /// to perform some other action.
+  ///
+  /// The pass provides an interface to generate run-time checks for the
+  /// assumptions that were made.
+  ///
+  /// We will lazily make the assumptions as getSCEV(),
+  /// getBackedgeTakenCount(), etc. are used.
+  ///
+  /// So far we can make the following assumptions to return a chrec where
+  /// otherwise we would not:
+  ///  - if A is a chrec, and we need to return sext/zext(A), we will add
+  ///    an nsw/nuw assumption for A. This allows us to fold the extend
+  ///    operation and return a chrec instead. We expect the resulting
+  ///    run-time check to pass (no overflow) in most cases.
+  class AssumingScalarEvolution : public ScalarEvolution {
+  private:
+    const Loop *AnalyzedLoop;
+    /// Use the normal Scalar Evolution analysis for any queries of
+    /// values outside the loop. This limits the recomputing we have to
+    /// do for each loop.
+    ScalarEvolution *SE;
+
+  public:
+    AssumingScalarEvolution();
+    static char ID; // Pass identification, replacement for typeid
+
+  private:
+    /// Assumption tracking structures.
+
+    /// OverflowAssumption - An assumption that has been made by SCEV
+    /// regarding the possibility of overflow of an expression.
+    struct OverflowAssumption {
+      const SCEV *Expression;
+      SCEV::NoWrapFlags AddedFlags;
+      OverflowAssumption() : Expression(nullptr),
+                             AddedFlags(SCEV::FlagAnyWrap) {}
+    };
+
+    typedef DenseMap<const SCEV *, OverflowAssumption> AssumptionMapTy;
+    DenseMap<const Loop *, AssumptionMapTy> LoopAssumptions;
+
+    AssumptionMapTy::iterator getLoopAssumptionBegin(const Loop *L) {
+      AssumptionMapTy &Map = LoopAssumptions.FindAndConstruct(L).second;
+      return Map.begin();
+    }
+    AssumptionMapTy::iterator getLoopAssumptionEnd(const Loop *L) {
+      AssumptionMapTy &Map = LoopAssumptions.FindAndConstruct(L).second;
+      return Map.end();
+    }
+
+  public:
+    void setAnalyzedLoop(const Loop *L) { AnalyzedLoop = L; }
+
+    /// Override the virtual methods from the ScalarEvolution analysis.
+    void getAnalysisUsage(AnalysisUsage &AU) const override;
+    /// Override the methods for producing extending expressions.
+    /// This allows us to add overflow assumptions whenever SCEV can't
+    /// prove NSW/NUW.
+    const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty) override;
+    const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty) override;
+
+    void releaseMemory() override;
+    void forgetLoop(const Loop *L) override;
+    const SCEV *getSCEV(Value *V) override;
+    bool runOnFunction(Function &F) override;
+
+    /// Add an assumption that the expression S has the wrapping behaviour
+    /// implied by Flags. All added assumptions need to be checked
+    /// at runtime, otherwise the results of this analysis are invalid.
+    void addOverflowAssumption(const SCEV *S, SCEV::NoWrapFlags Flags);
+
+    /// Method for generating run-time checks. The results of the
+    /// analysis are invalid if these run-time checks are not added.
+    std::pair<Instruction *, Instruction *>
+    addRuntimeOverflowChecks(Instruction *Loc);
+
+    /// Checks that the added assumptions can actually be checked
+    /// at runtime (and returns true if we can generate code for these
+    /// checks).
+    bool canPerformOverflowChecks();
+
+    /// Clears the assumption tracking structures for loop L.
+    void forgetLoopAssumptions(const Loop *L);
+  };
 }
 
 #endif
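
To make the lazy assumption mechanism concrete: for an i16 induction
variable that is zero-extended to i32, plain ScalarEvolution has to return
the opaque (zext i16 {0,+,1}<%loop> to i32), because the narrow recurrence
may wrap. AssumingScalarEvolution instead records an NUW assumption on
{0,+,1}<i16> and folds the extend, returning the chrec {0,+,1}<i32>.
Driving this by hand would look roughly like the sketch below (pass-manager
setup elided; ZExt and InsertPt are illustrative names, not from this patch):

    // Sketch only: query the assuming SCEV, then guard the recorded
    // assumptions at run time.
    #include "llvm/Analysis/ScalarEvolution.h"
    using namespace llvm;

    static void queryWithAssumptions(AssumingScalarEvolution *ASE, Loop *L,
                                     Value *ZExt, Instruction *InsertPt) {
      ASE->setAnalyzedLoop(L);            // scope the assumptions to L
      const SCEV *S = ASE->getSCEV(ZExt); // e.g. {0,+,1}<i32>, not a zext
      // The fold is only valid under the recorded nuw assumption, so the
      // client must be able to guard it:
      if (ASE->canPerformOverflowChecks())
        ASE->addRuntimeOverflowChecks(InsertPt);
      // else: discard whatever transformation needed S.
      (void)S;
    }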
Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -64,6 +64,7 @@
 void initializeAAEvalPass(PassRegistry&);
 void initializeAddDiscriminatorsPass(PassRegistry&);
+void initializeAssumingScalarEvolutionPass(PassRegistry&);
 void initializeADCEPass(PassRegistry&);
 void initializeBDCEPass(PassRegistry&);
 void initializeAliasAnalysisAnalysisGroup(PassRegistry&);

Index: lib/Analysis/Analysis.cpp
===================================================================
--- lib/Analysis/Analysis.cpp
+++ lib/Analysis/Analysis.cpp
@@ -25,6 +25,7 @@
   initializeAAEvalPass(Registry);
   initializeAliasDebuggerPass(Registry);
   initializeAliasSetPrinterPass(Registry);
+  initializeAssumingScalarEvolutionPass(Registry);
   initializeNoAAPass(Registry);
   initializeBasicAliasAnalysisPass(Registry);
   initializeBlockFrequencyInfoPass(Registry);

Index: lib/Analysis/LoopAccessAnalysis.cpp
===================================================================
--- lib/Analysis/LoopAccessAnalysis.cpp
+++ lib/Analysis/LoopAccessAnalysis.cpp
@@ -1357,8 +1357,22 @@
 }
 
 const LoopAccessInfo &
-LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides) {
-  auto &LAI = LoopAccessInfoMap[L];
+LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides,
+                            bool UseAssumptions) {
+  DenseMap<Loop *, std::unique_ptr<LoopAccessInfo>> &Map =
+      UseAssumptions ? AssumeLoopAccessInfoMap : LoopAccessInfoMap;
+
+  auto &LAI = Map[L];
+
+  // FIXME: Swapping the Scalar Evolution analysis isn't nice, but it
+  // allows us to only teach some of the users of LoopAccessAnalysis
+  // about the AssumingScalarEvolution pass. Once all the users know
+  // about the AssumingScalarEvolution pass, we can remove this.
+
+  // Swap the normal Scalar Evolution analysis with the assuming one
+  // if the user has requested to use assumptions.
+  if (UseAssumptions) std::swap(SE, ASE);
 
 #ifndef NDEBUG
   assert((!LAI || LAI->NumSymbolicStrides == Strides.size()) &&
@@ -1373,6 +1387,8 @@
     LAI->NumSymbolicStrides = Strides.size();
 #endif
   }
+  if (UseAssumptions) std::swap(SE, ASE);
+
   return *LAI.get();
 }
 
@@ -1391,6 +1407,7 @@
 
 bool LoopAccessAnalysis::runOnFunction(Function &F) {
   SE = &getAnalysis<ScalarEvolution>();
+  ASE = &getAnalysis<AssumingScalarEvolution>();
   auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
   TLI = TLIP ? &TLIP->getTLI() : nullptr;
  AA = &getAnalysis<AliasAnalysis>();
@@ -1402,6 +1419,7 @@
 
 void LoopAccessAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequired<ScalarEvolution>();
+  AU.addRequired<AssumingScalarEvolution>();
   AU.addRequired<AliasAnalysis>();
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<LoopInfoWrapperPass>();
@@ -1416,6 +1434,7 @@
 INITIALIZE_PASS_BEGIN(LoopAccessAnalysis, LAA_NAME, laa_name, false, true)
 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
+INITIALIZE_PASS_DEPENDENCY(AssumingScalarEvolution)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_END(LoopAccessAnalysis, LAA_NAME, laa_name, false, true)

Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -80,9 +80,12 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -109,6 +112,14 @@
                         "derived loop"),
               cl::init(100));
 
+static cl::opt<unsigned>
+OverflowCheckThreshold("force-max-overflow-checks", cl::init(16),
+                       cl::Hidden,
+                       cl::desc("Don't vectorize if the number of added "
+                                "overflow checks would be greater than this "
+                                "value."));
+
+
 // FIXME: Enable this with XDEBUG when the test suite is clean.
 static cl::opt<bool>
 VerifySCEV("verify-scev",
@@ -8054,13 +8065,16 @@
 //===----------------------------------------------------------------------===//
 // ScalarEvolution Class Implementation
 //===----------------------------------------------------------------------===//
-
-ScalarEvolution::ScalarEvolution()
+ScalarEvolution::ScalarEvolution(bool AddToRegistry, char &ID)
     : FunctionPass(ID), WalkingBEDominatingConds(false), ValuesAtScopes(64),
       LoopDispositions(64), BlockDispositions(64), FirstUnknown(nullptr) {
-  initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
+  if (AddToRegistry) {
+    initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
+  }
 }
 
+ScalarEvolution::ScalarEvolution() : ScalarEvolution(true, ID) {}
+
 bool ScalarEvolution::runOnFunction(Function &F) {
   this->F = &F;
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
@@ -8528,3 +8542,318 @@
 
   // TODO: Verify more things.
 }
+
+char AssumingScalarEvolution::ID = 0;
+
+INITIALIZE_PASS_BEGIN(AssumingScalarEvolution, "assuming-scalar-evolution",
+                      "Assuming Scalar Evolution Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
+INITIALIZE_PASS_END(AssumingScalarEvolution, "assuming-scalar-evolution",
+                    "Assuming Scalar Evolution Analysis", false, true)
+
+AssumingScalarEvolution::AssumingScalarEvolution() :
+    ScalarEvolution(false, ID), SE(nullptr) {
+  initializeAssumingScalarEvolutionPass(*PassRegistry::getPassRegistry());
+}
+
+void AssumingScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
+  ScalarEvolution::getAnalysisUsage(AU);
+  AU.addRequired<ScalarEvolution>();
+}
+
+void AssumingScalarEvolution::addOverflowAssumption(const SCEV *S,
+                                                    SCEV::NoWrapFlags Flags) {
+  // We should only be adding overflow assumptions for AddRecExprs.
+  const SCEVAddRecExpr *Expr = static_cast<const SCEVAddRecExpr *>(S);
+
+  SCEVAddRecExpr *RE = const_cast<SCEVAddRecExpr *>(Expr);
+  RE->setNoWrapFlags(Flags);
+
+  const Loop *L = Expr->getLoop();
+  AssumptionMapTy &Map = LoopAssumptions.FindAndConstruct(L).second;
+
+  // FIXME: assumptions can become redundant when stronger assumptions are
+  // added. We should try and reduce the number of assumptions once we have
+  // finished gathering them.
+  auto II = Map.find(S);
+  if (II == Map.end()) {
+    struct OverflowAssumption OA;
+    OA.Expression = S;
+    OA.AddedFlags = Flags;
+    Map[OA.Expression] = OA;
+    return;
+  }
+
+  struct OverflowAssumption &Existing = II->second;
+  Existing.AddedFlags = setFlags(Flags, Existing.AddedFlags);
+}
+
+const SCEV *
+AssumingScalarEvolution::getZeroExtendExpr(const SCEV *OP, Type *T) {
+
+  const SCEV *BaseResult = ScalarEvolution::getZeroExtendExpr(OP, T);
+  if (dyn_cast<SCEVAddRecExpr>(BaseResult))
+    return BaseResult;
+
+  // We only know how to add assumptions for affine expressions.
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(OP);
+  if (!AR) return BaseResult;
+  if (!AR->isAffine()) return BaseResult;
+  if (!(AR->getLoop() == AnalyzedLoop)) return BaseResult;
+
+  addOverflowAssumption(AR, SCEV::FlagNUW);
+
+  const Loop *L = AR->getLoop();
+  const SCEV *Step = AR->getStepRecurrence(*this);
+
+  return getAddRecExpr(
+      getExtendAddRecStart<SCEVZeroExtendExpr>(AR, T, this),
+      getZeroExtendExpr(Step, T), L, AR->getNoWrapFlags());
+}
+
+const SCEV *
+AssumingScalarEvolution::getSignExtendExpr(const SCEV *OP, Type *T) {
+
+  const SCEV *BaseResult = ScalarEvolution::getSignExtendExpr(OP, T);
+  if (dyn_cast<SCEVAddRecExpr>(BaseResult))
+    return BaseResult;
+
+  // We only know how to add assumptions for affine expressions.
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(OP);
+  if (!AR) return BaseResult;
+  if (!AR->isAffine()) return BaseResult;
+  if (!(AR->getLoop() == AnalyzedLoop)) return BaseResult;
+
+  addOverflowAssumption(AR, SCEV::FlagNSW);
+
+  const Loop *L = AR->getLoop();
+  const SCEV *Step = AR->getStepRecurrence(*this);
+
+  return getAddRecExpr(
+      getExtendAddRecStart<SCEVSignExtendExpr>(AR, T, this),
+      getSignExtendExpr(Step, T), L, AR->getNoWrapFlags());
+}
+
+static Value *
+generateOverflowCheck(const SCEVAddRecExpr *AR,
+                      Instruction *Loc,
+                      bool Signed,
+                      ScalarEvolution *SE,
+                      const DataLayout *DL,
+                      SCEVExpander &Exp) {
+
+  Module *M = Loc->getParent()->getParent()->getParent();
+  IRBuilder<> OFBuilder(Loc);
+  Value *AddF, *MulF;
+  if (Signed) {
+    AddF = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow,
+                                     AR->getType());
+    MulF = Intrinsic::getDeclaration(M, Intrinsic::smul_with_overflow,
+                                     AR->getType());
+  } else {
+    AddF = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow,
+                                     AR->getType());
+    MulF = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow,
+                                     AR->getType());
+  }
+  Value *Start;
+  Value *Stride;
+
+  const SCEV *ExitCount = SE->getBackedgeTakenCount(AR->getLoop());
+
+  unsigned DstBits = AR->getType()->getPrimitiveSizeInBits();
+  unsigned SrcBits = ExitCount->getType()->getPrimitiveSizeInBits();
+
+  if (SrcBits < DstBits) {
+    // We need to extend.
+    if (Signed)
+      ExitCount = SE->getNoopOrSignExtend(ExitCount, AR->getType());
+    else
+      ExitCount = SE->getNoopOrZeroExtend(ExitCount, AR->getType());
+  }
+
+  assert(ExitCount != SE->getCouldNotCompute() && "Invalid loop count");
+  Value *TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(), Loc);
+  Value *TripCountCheck = nullptr;
+
+  // We might need to truncate TripCount.
+  // If this is the case, we need to make sure that this is legal.
+  if (SrcBits > DstBits) {
+    APInt CmpMaxValue = Signed ?
+        APInt::getSignedMaxValue(DstBits).sext(SrcBits) :
+        APInt::getMaxValue(DstBits).zext(SrcBits);
+    // The min value only makes sense for signed checks.
+
+    ConstantInt *CTMax = ConstantInt::get(M->getContext(), CmpMaxValue);
+    CmpInst::Predicate P = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
+    TripCountCheck = OFBuilder.CreateICmp(P, TripCount, CTMax);
+
+    if (Signed) {
+      APInt CmpMinValue = APInt::getSignedMinValue(DstBits).sext(SrcBits);
+      ConstantInt *CTMin = ConstantInt::get(M->getContext(), CmpMinValue);
+      Value *MinCheck = OFBuilder.CreateICmp(ICmpInst::ICMP_SLT,
+                                             TripCount,
+                                             CTMin);
+      TripCountCheck = OFBuilder.CreateOr(TripCountCheck, MinCheck);
+    }
+
+    TripCount = OFBuilder.CreateTrunc(TripCount, AR->getType());
+  }
+
+  // We need to truncate or extend TripCount to the type used by the SCEV.
+  // Extension is not a problem.
+  Start = Exp.expandCodeFor(AR->getStart(),
+                            AR->getStart()->getType(), Loc);
+
+  // This is an affine expression.
+  Stride = Exp.expandCodeFor(AR->getOperand(1),
+                             AR->getOperand(1)->getType(), Loc);
+
+  CallInst *Mul = OFBuilder.CreateCall(MulF, {Stride, TripCount}, "mul");
+  Value *MulV = OFBuilder.CreateExtractValue(Mul, 0, "mul.result");
+  Value *OfMul = OFBuilder.CreateExtractValue(Mul, 1, "mul.overflow");
+  CallInst *Add = OFBuilder.CreateCall(AddF, {MulV, Start}, "uadd");
+  Value *OfAdd = OFBuilder.CreateExtractValue(Add, 1, "add.overflow");
+  Value *Overflow = OFBuilder.CreateOr(OfMul, OfAdd, "overflow");
+
+  if (TripCountCheck) {
+    Overflow = OFBuilder.CreateOr(Overflow, TripCountCheck);
+  }
+
+  return Overflow;
+}
+
+static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
+                                 Instruction *Loc) {
+  if (FirstInst)
+    return FirstInst;
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    return I->getParent() == Loc->getParent() ?
I : nullptr;
+  return nullptr;
+}
+
+std::pair<Instruction *, Instruction *>
+AssumingScalarEvolution::addRuntimeOverflowChecks(Instruction *Loc) {
+  Instruction *tnullptr = nullptr;
+
+  if (getLoopAssumptionBegin(AnalyzedLoop) ==
+      getLoopAssumptionEnd(AnalyzedLoop))
+    return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
+
+  IRBuilder<> OFBuilder(Loc);
+  Instruction *FirstInst = nullptr;
+  Module *M = Loc->getParent()->getParent()->getParent();
+  const DataLayout &DL = M->getDataLayout();
+  Value *OverflowRuntimeCheck = nullptr;
+  SCEVExpander Exp(*this, DL, "start");
+
+  for (auto II = getLoopAssumptionBegin(AnalyzedLoop),
+            EE = getLoopAssumptionEnd(AnalyzedLoop);
+       II != EE; ++II) {
+    struct OverflowAssumption &OA = II->second;
+
+    const SCEVAddRecExpr *AR =
+        static_cast<const SCEVAddRecExpr *>(OA.Expression);
+
+    assert(AR->isAffine() &&
+           "We don't know how to check non-affine expressions for overflow");
+
+    if (OA.AddedFlags & SCEV::FlagNUW) {
+      // Add a check for NUW.
+      Value *Overflow = generateOverflowCheck(AR, Loc, false, this, &DL, Exp);
+
+      if (!OverflowRuntimeCheck)
+        OverflowRuntimeCheck = Overflow;
+      else
+        OverflowRuntimeCheck = OFBuilder.CreateOr(OverflowRuntimeCheck,
+                                                  Overflow);
+    }
+    if (OA.AddedFlags & SCEV::FlagNSW) {
+      // Add a check for NSW.
+      Value *Overflow = generateOverflowCheck(AR, Loc, true, this, &DL, Exp);
+
+      if (!OverflowRuntimeCheck)
+        OverflowRuntimeCheck = Overflow;
+      else
+        OverflowRuntimeCheck = OFBuilder.CreateOr(OverflowRuntimeCheck,
+                                                  Overflow);
+    }
+  }
+
+  if (!OverflowRuntimeCheck)
+    return std::make_pair(nullptr, nullptr);
+
+  Instruction *Check =
+      BinaryOperator::CreateOr(OverflowRuntimeCheck,
+                               ConstantInt::getFalse(M->getContext()));
+  OFBuilder.Insert(Check, "overflow.check");
+
+  FirstInst = getFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
+}
+
+bool AssumingScalarEvolution::canPerformOverflowChecks() {
+  unsigned NumChecks = 0;
+  for (auto II = getLoopAssumptionBegin(AnalyzedLoop),
+            EE = getLoopAssumptionEnd(AnalyzedLoop); II != EE; ++II) {
+    NumChecks++;
+    struct OverflowAssumption &OA = II->second;
+    DEBUG(dbgs() << "ASCEV: Processing overflow assumption "
+                 << *OA.Expression << "\n");
+    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(OA.Expression)) {
+      if (AR->isAffine()) {
+        // We know how to check affine expressions.
+        // {u,+,v} executed l times will not overflow as long as v*l does
+        // not overflow and u + v*l does not overflow.
+        continue;
+      }
+    }
+    DEBUG(dbgs() << "ASCEV: Unable to check " << *OA.Expression << "\n");
+    return false;
+  }
+  // FIXME: Some of these checks can be redundant. We should remove them here.
+  // FIXME: The number of overflow checks is not a good metric. A lot of
+  // checks are expected to share code and will be CSE'ed away.
+  // Furthermore, instcombine is expected to simplify these checks further.
+  if (NumChecks > OverflowCheckThreshold) {
+    DEBUG(dbgs() << "ASCEV: Too many overflow checks needed.\n");
+    return false;
+  }
+  return true;
+}
+
+void AssumingScalarEvolution::forgetLoopAssumptions(const Loop *L) {
+  auto It = LoopAssumptions.find(L);
+  if (It == LoopAssumptions.end()) return;
+  It->second.clear();
+}
+
+void AssumingScalarEvolution::forgetLoop(const Loop *L) {
+  ScalarEvolution::forgetLoop(L);
+  SE->forgetLoop(L);
+  forgetLoopAssumptions(L);
+}
+
+void AssumingScalarEvolution::releaseMemory() {
+  ScalarEvolution::releaseMemory();
+  LoopAssumptions.clear();
+}
+
+const SCEV *AssumingScalarEvolution::getSCEV(Value *V) {
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return ScalarEvolution::getSCEV(V);
+  // If this value is not produced in the analyzed loop use
+  // the normal ScalarEvolution.
+  if (!AnalyzedLoop->contains(Inst))
+    return SE->getSCEV(Inst);
+
+  return ScalarEvolution::getSCEV(V);
+}
+
+bool AssumingScalarEvolution::runOnFunction(Function &F) {
+  SE = &getAnalysis<ScalarEvolution>();
+  return ScalarEvolution::runOnFunction(F);
+}
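
The guard that generateOverflowCheck() materializes can be restated in
standalone C++ (a sketch, not LLVM code; the __builtin_*_overflow builtins
stand in for the llvm.*.with.overflow intrinsics, and the i32/i16 widths are
chosen only for the example): for an affine AddRec {u,+,v} with trip count l,
the last value reached is u + v*l, so the assumptions hold iff neither v*l
nor u + v*l wraps. When the backedge-taken count is wider than the AddRec's
type, it must additionally fit in the narrow type's range before being
truncated.

    // Standalone sketch of the emitted conditions (compiles with GCC/Clang).
    #include <cassert>
    #include <cstdint>

    // Same mul-then-add structure as generateOverflowCheck(), unsigned case.
    static bool addRecWraps(uint32_t U, uint32_t V, uint32_t L) {
      uint32_t Mul, End;
      bool OfMul = __builtin_mul_overflow(V, L, &Mul);   // v * l
      bool OfAdd = __builtin_add_overflow(U, Mul, &End); // u + v * l
      return OfMul || OfAdd; // mirrors CreateOr(OfMul, OfAdd)
    }

    // Legality of truncating a wide (i64) trip count to the AddRec type
    // (i16); the emitted TripCountCheck is the negation of this predicate.
    static bool tripCountFitsInI16(int64_t TC, bool Signed) {
      if (Signed)
        return TC >= INT16_MIN && TC <= INT16_MAX;
      return static_cast<uint64_t>(TC) <= UINT16_MAX;
    }

    int main() {
      assert(!addRecWraps(0, 1, 100));     // small loop: no wrap
      assert(addRecWraps(0, 1u << 31, 2)); // stride * trip count wraps
      assert(tripCountFitsInI16(40000, false) &&
             !tripCountFitsInI16(40000, true));
      return 0;
    }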
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -244,7 +244,7 @@
 /// and reduction variables that were found to a given vectorization factor.
 class InnerLoopVectorizer {
 public:
-  InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
+  InnerLoopVectorizer(Loop *OrigLoop, AssumingScalarEvolution *SE, LoopInfo *LI,
                       DominatorTree *DT, const TargetLibraryInfo *TLI,
                       const TargetTransformInfo *TTI, unsigned VecWidth,
                       unsigned UnrollFactor)
@@ -291,6 +291,10 @@
   /// pair as (first, last).
   std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
 
+  // Adds code to check the overflow assumptions made by SCEV.
+  std::pair<Instruction *, Instruction *>
+  addRuntimeOverflowChecks(Instruction *Loc);
+
   /// Create an empty loop, based on the loop ranges of the old loop.
   void createEmptyLoop();
   /// Copy and widen the instructions from the old loop.
@@ -396,8 +400,8 @@
 
   /// The original loop.
   Loop *OrigLoop;
-  /// Scev analysis to use.
-  ScalarEvolution *SE;
+  /// Scev analysis to use (with assumptions).
+  AssumingScalarEvolution *SE;
   /// Loop Info.
   LoopInfo *LI;
   /// Dominator Tree.
@@ -456,7 +460,7 @@
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
 public:
-  InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
+  InnerLoopUnroller(Loop *OrigLoop, AssumingScalarEvolution *SE, LoopInfo *LI,
                     DominatorTree *DT, const TargetLibraryInfo *TLI,
                     const TargetTransformInfo *TTI, unsigned UnrollFactor)
       : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
@@ -560,7 +564,7 @@
 /// induction variable and the different reduction variables.
 class LoopVectorizationLegality {
 public:
-  LoopVectorizationLegality(Loop *L, ScalarEvolution *SE, DominatorTree *DT,
+  LoopVectorizationLegality(Loop *L, AssumingScalarEvolution *SE, DominatorTree *DT,
                             TargetLibraryInfo *TLI, AliasAnalysis *AA,
                             Function *F, const TargetTransformInfo *TTI,
                             LoopAccessAnalysis *LAA)
@@ -777,7 +781,7 @@
   /// The loop that we evaluate.
   Loop *TheLoop;
   /// Scev analysis.
-  ScalarEvolution *SE;
+  AssumingScalarEvolution *SE;
   /// Target Library Info.
   TargetLibraryInfo *TLI;
   /// Parent function
@@ -833,7 +837,7 @@
 /// different operations.
 class LoopVectorizationCostModel {
 public:
-  LoopVectorizationCostModel(Loop *L, ScalarEvolution *SE, LoopInfo *LI,
+  LoopVectorizationCostModel(Loop *L, AssumingScalarEvolution *SE, LoopInfo *LI,
                              LoopVectorizationLegality *Legal,
                              const TargetTransformInfo &TTI,
                              const TargetLibraryInfo *TLI, AssumptionCache *AC,
@@ -909,7 +913,7 @@
   /// The loop that we evaluate.
   Loop *TheLoop;
   /// Scev analysis.
-  ScalarEvolution *SE;
+  AssumingScalarEvolution *SE;
   /// Loop Info analysis.
   LoopInfo *LI;
   /// Vectorization legality.
@@ -1179,7 +1183,7 @@
     initializeLoopVectorizePass(*PassRegistry::getPassRegistry());
   }
 
-  ScalarEvolution *SE;
+  AssumingScalarEvolution *SE;
   LoopInfo *LI;
   TargetTransformInfo *TTI;
   DominatorTree *DT;
@@ -1194,7 +1198,7 @@
   BlockFrequency ColdEntryFreq;
 
   bool runOnFunction(Function &F) override {
-    SE = &getAnalysis<ScalarEvolution>();
+    SE = &getAnalysis<AssumingScalarEvolution>();
     LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
     TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
@@ -1227,8 +1231,11 @@
 
     // Now walk the identified inner loops.
     bool Changed = false;
-    while (!Worklist.empty())
-      Changed |= processLoop(Worklist.pop_back_val());
+    while (!Worklist.empty()) {
+      Loop *L = Worklist.pop_back_val();
+      SE->setAnalyzedLoop(L);
+      Changed |= processLoop(L);
+    }
 
     // Process each loop nest in the function.
     return Changed;
@@ -1450,6 +1457,7 @@
     AU.addRequired<BlockFrequencyInfo>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<AssumingScalarEvolution>();
     AU.addRequired<ScalarEvolution>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addRequired<AliasAnalysis>();
@@ -2222,6 +2230,30 @@
     LastBypassBlock = CheckBlock;
   }
 
+  // Overflow check.
+  Instruction *OFCheck;
+  Instruction *FirstOFCheckInst;
+  std::tie(FirstOFCheckInst, OFCheck) =
+      SE->addRuntimeOverflowChecks(LastBypassBlock->getTerminator());
+  if (OFCheck) {
+    AddedSafetyChecks = true;
+    // Create a new block containing the overflow check.
+    BasicBlock *CheckBlock =
+        LastBypassBlock->splitBasicBlock(FirstOFCheckInst,
+                                         "vector.overflowcheck");
+    if (ParentLoop)
+      ParentLoop->addBasicBlockToLoop(CheckBlock, *LI);
+    LoopBypassBlocks.push_back(CheckBlock);
+
+    // Replace the branch into the overflow check block with a conditional
+    // branch for the "few elements case".
+    Instruction *OldTerm = LastBypassBlock->getTerminator();
+    BranchInst::Create(MiddleBlock, CheckBlock, Cmp, OldTerm);
+    OldTerm->eraseFromParent();
+
+    Cmp = OFCheck;
+    LastBypassBlock = CheckBlock;
+  }
+
   LastBypassBlock->getTerminator()->eraseFromParent();
   BranchInst::Create(MiddleBlock, VectorPH, Cmp,
                      LastBypassBlock);
@@ -3403,6 +3435,16 @@
   // Collect all of the variables that remain uniform after vectorization.
   collectLoopUniforms();
 
+  // Vectorizing some loops requires assuming that certain operations do
+  // not overflow. These assumptions need to be checked at run-time, and
+  // canPerformOverflowChecks() tells us whether we know how to generate
+  // code for the required checks.
+  if (!SE->canPerformOverflowChecks()) {
+    DEBUG(dbgs() << "LV: Can't vectorize - unable to add "
+                 << "overflow checks\n");
+    return false;
+  }
+
   DEBUG(dbgs() << "LV: We can vectorize this loop"
                << (LAI->getRuntimePointerCheck()->Need ?
                    " (with a runtime bound check)" : "")
@@ -3801,7 +3843,7 @@
 }
 
 bool LoopVectorizationLegality::canVectorizeMemory() {
-  LAI = &LAA->getInfo(TheLoop, Strides);
+  LAI = &LAA->getInfo(TheLoop, Strides, true);
   auto &OptionalReport = LAI->getReport();
   if (OptionalReport)
     emitAnalysis(VectorizationReport(*OptionalReport));
@@ -4681,7 +4723,7 @@
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfo)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
+INITIALIZE_PASS_DEPENDENCY(AssumingScalarEvolution)
 INITIALIZE_PASS_DEPENDENCY(LCSSA)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)

Index: test/Transforms/LoopVectorize/safegep.ll
===================================================================
--- test/Transforms/LoopVectorize/safegep.ll
+++ test/Transforms/LoopVectorize/safegep.ll
@@ -9,8 +9,10 @@
 ; PR16592
 
 ; CHECK-LABEL: @safe(
+; CHECK-NOT: vector.overflowcheck
 ; CHECK: <4 x float>
 
+
 define void @safe(float* %A, float* %B, float %K) {
 entry:
   br label %"<bb 3>"

Index: test/Transforms/LoopVectorize/scev-overflow-check.ll
===================================================================
--- test/Transforms/LoopVectorize/scev-overflow-check.ll
+++ test/Transforms/LoopVectorize/scev-overflow-check.ll
@@ -0,0 +1,125 @@
+; RUN: opt -mtriple=aarch64--linux-gnueabi -loop-vectorize < %s -S | FileCheck %s
+
+; CHECK-LABEL: test0
+define void @test0(i32* %A,
+                   i32* %B,
+                   i32* %C, i32 %N) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; If N is greater than 65535, this would loop forever.
+; CHECK: icmp ugt i32 %N, 65535
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i16 [ %indvars.next, %for.body ], [ 0, %for.body.preheader ]
+  %indvars.next = add i16 %indvars.iv, 1
+  %indvars.ext = zext i16 %indvars.iv to i32
+
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %indvars.ext
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %C, i32 %indvars.ext
+  %1 = load i32, i32* %arrayidx3, align 4
+
+  %mul4 = mul i32 %1, %0
+
+  %arrayidx7 = getelementptr inbounds i32, i32* %A, i32 %indvars.ext
+  store i32 %mul4, i32* %arrayidx7, align 4
+
+  %exitcond = icmp eq i32 %indvars.ext, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; CHECK-LABEL: test1
+define void @test1(i32* %A,
+                   i32* %B,
+                   i32* %C, i32 %N, i32 %Offset) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; Because of the GEPs, we need to check that Offset + N does not overflow.
+; CHECK: [[MUL0:%[a-zA-Z_0-9.]+]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 1, i32 %N)
+; CHECK: [[MUL1:%[a-zA-Z_0-9.]+]] = extractvalue { i32, i1 } [[MUL0]], 0
+; CHECK: call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 [[MUL1]], i32 %Offset)
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i16 [ %indvars.next, %for.body ], [ 0, %for.body.preheader ]
+  %indvars.next = add i16 %indvars.iv, 1
+
+  %indvars.ext = zext i16 %indvars.iv to i32
+  %indvars.access = add i32 %Offset, %indvars.ext
+
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %indvars.access
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %C, i32 %indvars.access
+  %1 = load i32, i32* %arrayidx3, align 4
+
+  %mul4 = mul i32 %1, %0
+
+  %arrayidx7 = getelementptr inbounds i32, i32* %A, i32 %indvars.access
+  store i32 %mul4, i32* %arrayidx7, align 4
+
+  %exitcond = icmp eq i32 %indvars.ext, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; CHECK-LABEL: test2
+define void @test2(i32* %A,
+                   i32* %B,
+                   i32* %C, i32 %N, i32 %Offset) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; CHECK: icmp sgt i32 %N, 32767
+; CHECK: icmp slt i32 %N, -32768
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i16 [ %indvars.next, %for.body ], [ 0, %for.body.preheader ]
+  %indvars.next = add i16 %indvars.iv, 1
+
+  %indvars.ext = sext i16 %indvars.iv to i32
+  %indvars.access = add i32 %Offset, %indvars.ext
+
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %indvars.access
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %C, i32 %indvars.access
+  %1 = load i32, i32* %arrayidx3, align 4
+
+  %mul4 = add i32 %1, %0
+
+  %arrayidx7 = getelementptr inbounds i32, i32* %A, i32 %indvars.access
+  store i32 %mul4, i32* %arrayidx7, align 4
+
+  %exitcond = icmp eq i32 %indvars.ext, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}

Index: test/Transforms/LoopVectorize/version-mem-access.ll
===================================================================
--- test/Transforms/LoopVectorize/version-mem-access.ll
+++ test/Transforms/LoopVectorize/version-mem-access.ll
@@ -16,11 +16,12 @@
   %cmp13 = icmp eq i32 %N, 0
   br i1 %cmp13, label %for.end, label %for.body.preheader
 
+; We don't need to check the symbolic stride for B; we can assume instead
+; that {0,+,BStride} will not overflow.
+
 ; CHECK-DAG: icmp ne i64 %AStride, 1
-; CHECK-DAG: icmp ne i32 %BStride, 1
 ; CHECK-DAG: icmp ne i64 %CStride, 1
 ; CHECK: or
-; CHECK: or
 ; CHECK: br
 ; CHECK: vector.body
@@ -56,11 +57,11 @@
 }
 
 ; We used to crash on this function because we removed the fptosi cast when
-; replacing the symbolic stride '%conv'.
-; PR18480
+; replacing the symbolic stride '%conv' (PR18480). However, replacing the
+; symbolic stride is no longer required since we can do an overflow check.
 
 ; CHECK-LABEL: fn1
-; CHECK: load <2 x double>
+; CHECK: store <2 x double>
 
 define void @fn1(double* noalias %x, double* noalias %c, double %a) {
 entry: