Index: include/polly/ScopDetection.h
===================================================================
--- include/polly/ScopDetection.h
+++ include/polly/ScopDetection.h
@@ -299,6 +299,22 @@
   /// @note An OpenMP subfunction will be marked as invalid.
   bool isValidFunction(llvm::Function &F);
 
+  /// @brief Check whether SCEV can compute the trip count of a loop.
+  ///
+  /// @param L       The loop to check.
+  /// @param Context The context of scop detection.
+  ///
+  /// @return True if SCEV can compute an affine trip count for the loop.
+  bool canUseSCEVTripCount(Loop *L, DetectionContext &Context) const;
+
+  /// @brief Check whether ISL can compute the trip count of a loop.
+  ///
+  /// @param L       The loop to check.
+  /// @param Context The context of scop detection.
+  ///
+  /// @return True if ISL can compute the trip count of the loop.
+  bool canUseISLTripCount(Loop *L, DetectionContext &Context) const;
+
   /// @brief Print the locations of all detected scops.
   void printLocations(llvm::Function &F);
 
Index: include/polly/ScopInfo.h
===================================================================
--- include/polly/ScopInfo.h
+++ include/polly/ScopInfo.h
@@ -557,6 +557,7 @@
   __isl_give isl_set *buildConditionSet(const Comparison &Cmp);
   void addConditionsToDomain(TempScop &tempScop, const Region &CurRegion);
   void addLoopBoundsToDomain(TempScop &tempScop);
+  void addLoopTripCountToDomain(const Loop *L);
   void buildDomain(TempScop &tempScop, const Region &CurRegion);
 
   /// @brief Create the accesses for instructions in @p Block.
@@ -1221,6 +1222,9 @@
   ///
   /// @return true if a change was made
   bool restrictDomains(__isl_take isl_union_set *Domain);
+
+  /// @brief Get the depth of @p L relative to its outermost loop in the Scop.
+  unsigned getRelativeLoopDepth(const Loop *L) const;
 };
 
 /// @brief Print Scop scop to raw_ostream O.
Index: include/polly/Support/SCEVAffinator.h
===================================================================
--- include/polly/Support/SCEVAffinator.h
+++ include/polly/Support/SCEVAffinator.h
@@ -71,8 +71,6 @@
   llvm::ScalarEvolution &SE;
   const ScopStmt *Stmt;
 
-  int getLoopDepth(const llvm::Loop *L);
-
   __isl_give isl_pw_aff *visit(const llvm::SCEV *E);
   __isl_give isl_pw_aff *visitConstant(const llvm::SCEVConstant *E);
   __isl_give isl_pw_aff *visitTruncateExpr(const llvm::SCEVTruncateExpr *E);
Index: lib/Analysis/ScopDetection.cpp
===================================================================
--- lib/Analysis/ScopDetection.cpp
+++ lib/Analysis/ScopDetection.cpp
@@ -166,6 +166,11 @@
                 cl::Hidden, cl::init(false), cl::ZeroOrMore,
                 cl::cat(PollyCategory));
 
+static cl::opt<bool> AllowNonSCEVBackedgeTakenCount(
+    "polly-allow-non-scev-backedge-taken-count",
+    cl::desc("Allow loops even if SCEV cannot provide a trip count"),
+    cl::Hidden, cl::init(true), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 bool polly::PollyTrackFailures = false;
 bool polly::PollyDelinearize = false;
 StringRef polly::PollySkipFnAttr = "polly.skip.fn";
@@ -728,10 +733,71 @@
   return invalid<ReportUnknownInst>(Context, /*Assert=*/true, &Inst);
 }
 
+bool ScopDetection::canUseSCEVTripCount(Loop *L,
+                                        DetectionContext &Context) const {
+
+  Region &CurRegion = Context.CurRegion;
+
+  // Ensure SCEV can compute the loop trip count.
+  const SCEV *LoopCount = SE->getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(LoopCount))
+    return false;
+
+  // Ensure the trip count is affine.
+  if (!isAffineExpr(&CurRegion, LoopCount, *SE))
+    return false;
+
+  // We can use SCEV to compute the trip count of L.
+  return true;
+}
+
+bool ScopDetection::canUseISLTripCount(Loop *L,
+                                       DetectionContext &Context) const {
+
+  Region &CurRegion = Context.CurRegion;
+
+  // Ensure the loop has a single back edge.
+  if (L->getNumBackEdges() != 1)
+    return false;
+
+  // Ensure the loop has a single exiting block.
+  BasicBlock *ExitingBB = L->getExitingBlock();
+  if (!ExitingBB)
+    return false;
+
+  // Ensure the exiting block is terminated by a conditional branch.
+  BranchInst *Term = dyn_cast<BranchInst>(ExitingBB->getTerminator());
+  if (!Term || !Term->isConditional())
+    return false;
+
+  Value *Cond = Term->getCondition();
+
+  // If the terminating condition is an integer comparison, ensure that it is a
+  // comparison between a recurrence and an invariant value.
+  if (ICmpInst *I = dyn_cast<ICmpInst>(Cond)) {
+    const Value *Op0 = I->getOperand(0);
+    const Value *Op1 = I->getOperand(1);
+    const SCEV *LHS = SE->getSCEVAtScope(const_cast<Value *>(Op0), L);
+    const SCEV *RHS = SE->getSCEVAtScope(const_cast<Value *>(Op1), L);
+    if ((isa<SCEVAddRecExpr>(LHS) && !isInvariant(*Op1, CurRegion)) ||
+        (isa<SCEVAddRecExpr>(RHS) && !isInvariant(*Op0, CurRegion)))
+      return false;
+  }
+
+  // If the terminating condition is not an integer comparison, ensure that it
+  // is a constant.
+  else if (!isa<ConstantInt>(Cond))
+    return false;
+
+  // We can use ISL to compute the trip count of L.
+  return true;
+}
+
 bool ScopDetection::isValidLoop(Loop *L, DetectionContext &Context) const {
   // Is the loop count affine?
-  const SCEV *LoopCount = SE->getBackedgeTakenCount(L);
-  if (isAffineExpr(&Context.CurRegion, LoopCount, *SE)) {
+  // ISL-derived trip counts are only used if the option allows them.
+  if (canUseSCEVTripCount(L, Context) ||
+      (AllowNonSCEVBackedgeTakenCount && canUseISLTripCount(L, Context))) {
     Context.hasAffineLoops = true;
     return true;
   }
@@ -743,7 +809,8 @@
         return true;
   }
 
-  return invalid<ReportLoopBound>(Context, /*Assert=*/true, L, LoopCount);
+  return invalid<ReportLoopBound>(Context, /*Assert=*/true, L,
+                                  SE->getBackedgeTakenCount(L));
 }
 
 Region *ScopDetection::expandRegion(Region &R) {
Index: lib/Analysis/ScopInfo.cpp
===================================================================
--- lib/Analysis/ScopInfo.cpp
+++ lib/Analysis/ScopInfo.cpp
@@ -736,6 +736,71 @@
   Domain = isl_set_align_params(Domain, Parent.getParamSpace());
 }
 
+void ScopStmt::addLoopTripCountToDomain(const Loop *L) {
+
+  unsigned loopDimension = getParent()->getRelativeLoopDepth(L);
+  ScalarEvolution *SE = getParent()->getSE();
+  isl_space *DomSpace = isl_set_get_space(Domain);
+
+  // Build the map [..., i, ...] -> [..., i + 1, ...] that advances the
+  // dimension of L by one.
+  isl_space *MapSpace = isl_space_map_from_set(isl_space_copy(DomSpace));
+  isl_multi_aff *LoopMAff = isl_multi_aff_identity(MapSpace);
+  isl_aff *LoopAff = isl_multi_aff_get_aff(LoopMAff, loopDimension);
+  LoopAff = isl_aff_add_constant_si(LoopAff, 1);
+  LoopMAff = isl_multi_aff_set_aff(LoopMAff, loopDimension, LoopAff);
+  isl_map *TranslationMap = isl_map_from_multi_aff(LoopMAff);
+
+  BasicBlock *ExitingBB = L->getExitingBlock();
+  assert(ExitingBB && "Loop has more than one exiting block");
+
+  BranchInst *Term = dyn_cast<BranchInst>(ExitingBB->getTerminator());
+  assert(Term && Term->isConditional() && "Terminator is not conditional");
+
+  const SCEV *LHS = nullptr;
+  const SCEV *RHS = nullptr;
+  Value *Cond = Term->getCondition();
+  CmpInst::Predicate Pred = CmpInst::Predicate::BAD_ICMP_PREDICATE;
+
+  ICmpInst *CondICmpInst = dyn_cast<ICmpInst>(Cond);
+  ConstantInt *CondConstant = dyn_cast<ConstantInt>(Cond);
+  if (CondICmpInst) {
+    LHS = SE->getSCEVAtScope(CondICmpInst->getOperand(0), L);
+    RHS = SE->getSCEVAtScope(CondICmpInst->getOperand(1), L);
+    Pred = CondICmpInst->getPredicate();
+  } else if (CondConstant) {
+    LHS = SE->getConstant(CondConstant);
+    RHS = SE->getConstant(ConstantInt::getTrue(SE->getContext()));
+    Pred = CmpInst::Predicate::ICMP_EQ;
+  } else {
+    llvm_unreachable("Condition is neither a ConstantInt nor a ICmpInst");
+  }
+
+  // Invert the predicate if the branch's first successor leaves the loop.
+  if (!L->contains(Term->getSuccessor(0)))
+    Pred = ICmpInst::getInversePredicate(Pred);
+  Comparison Comp(LHS, RHS, Pred);
+
+  // Relate each iteration to all strictly later iterations of L (with every
+  // other dimension fixed), then drop the iterations that follow one in which
+  // the condition does not hold.
+  isl_set *CondSet = buildConditionSet(Comp);
+  isl_map *ForwardMap = isl_map_lex_le(isl_space_copy(DomSpace));
+  for (unsigned i = 0; i < isl_set_n_dim(Domain); i++)
+    if (i != loopDimension)
+      ForwardMap = isl_map_equate(ForwardMap, isl_dim_in, i, isl_dim_out, i);
+
+  ForwardMap = isl_map_apply_range(ForwardMap, isl_map_copy(TranslationMap));
+  isl_set *CondDom = isl_set_subtract(isl_set_copy(Domain), CondSet);
+  isl_set *ForwardCond = isl_set_apply(CondDom, isl_map_copy(ForwardMap));
+  isl_set *ForwardDomain = isl_set_apply(isl_set_copy(Domain), ForwardMap);
+  ForwardCond = isl_set_gist(ForwardCond, ForwardDomain);
+  Domain = isl_set_subtract(Domain, ForwardCond);
+
+  isl_map_free(TranslationMap);
+  isl_space_free(DomSpace);
+}
+
 __isl_give isl_set *ScopStmt::buildConditionSet(const Comparison &Comp) {
   isl_pw_aff *L = getPwAff(Comp.getLHS());
   isl_pw_aff *R = getPwAff(Comp.getRHS());
@@ -786,9 +851,15 @@
     // IV <= LatchExecutions.
     const Loop *L = getLoopForDimension(i);
     const SCEV *LatchExecutions = SE->getBackedgeTakenCount(L);
-    isl_pw_aff *UpperBound = getPwAff(LatchExecutions);
-    isl_set *UpperBoundSet = isl_pw_aff_le_set(IV, UpperBound);
-    Domain = isl_set_intersect(Domain, UpperBoundSet);
+    if (!isa<SCEVCouldNotCompute>(LatchExecutions)) {
+      isl_pw_aff *UpperBound = getPwAff(LatchExecutions);
+      isl_set *UpperBoundSet = isl_pw_aff_le_set(IV, UpperBound);
+      Domain = isl_set_intersect(Domain, UpperBoundSet);
+    } else {
+      // If SCEV cannot provide a loop trip count we compute it with ISL.
+      addLoopTripCountToDomain(L);
+      isl_pw_aff_free(IV);
+    }
   }
 
   isl_local_space_free(LocalSpace);
@@ -2056,6 +2127,12 @@
   return StmtMapIt->second;
 }
 
+unsigned Scop::getRelativeLoopDepth(const Loop *L) const {
+  Loop *OuterLoop = R.outermostLoopInRegion(const_cast<Loop *>(L));
+  assert(OuterLoop && "Scop does not contain this loop");
+  return L->getLoopDepth() - OuterLoop->getLoopDepth();
+}
+
 //===----------------------------------------------------------------------===//
 ScopInfo::ScopInfo() : RegionPass(ID), scop(0) {
   ctx = isl_ctx_alloc();
Index: lib/Support/SCEVAffinator.cpp
===================================================================
--- lib/Support/SCEVAffinator.cpp
+++ lib/Support/SCEVAffinator.cpp
@@ -160,7 +160,7 @@
   isl_space *Space = isl_space_set_alloc(Ctx, 0, NumIterators);
   isl_local_space *LocalSpace = isl_local_space_from_space(Space);
 
-  int loopDimension = getLoopDepth(Expr->getLoop());
+  unsigned loopDimension = S->getRelativeLoopDepth(Expr->getLoop());
 
   isl_aff *LAff = isl_aff_set_coefficient_si(
       isl_aff_zero_on_domain(LocalSpace), isl_dim_in, loopDimension, 1);
@@ -248,9 +248,3 @@
   llvm_unreachable(
       "Unknowns SCEV was neither parameter nor a valid instruction.");
 }
-
-int SCEVAffinator::getLoopDepth(const Loop *L) {
-  Loop *outerLoop = S->getRegion().outermostLoopInRegion(const_cast<Loop *>(L));
-  assert(outerLoop && "Scop does not contain this loop");
-  return L->getLoopDepth() - outerLoop->getLoopDepth();
-}
Index: test/ScopInfo/isl_trip_count_01.ll
===================================================================
--- /dev/null
+++ test/ScopInfo/isl_trip_count_01.ll
@@ -0,0 +1,38 @@
+; RUN: opt %loadPolly -polly-detect-unprofitable -polly-allow-non-scev-backedge-taken-count -polly-scops -analyze < %s | FileCheck %s
+;
+; CHECK: [M, N] -> { Stmt_while_body[i0] : i0 >= 0 and 4i0 <= -M + N }
+;
+; void f(int *A, int N, int M) {
+;   int i = 0;
+;   while (M <= N) {
+;     A[i++] = 1;
+;     M += 4;
+;   }
+; }
+;
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
+define void @f(i32* nocapture %A, i32 %N, i32 %M) {
+entry:
+  %cmp3 = icmp sgt i32 %M, %N
+  br i1 %cmp3, label %while.end, label %while.body.preheader
+
+while.body.preheader:
+  br label %while.body
+
+while.body:
+  %i.05 = phi i32 [ %inc, %while.body ], [ 0, %while.body.preheader ]
+  %M.addr.04 = phi i32 [ %add, %while.body ], [ %M, %while.body.preheader ]
+  %inc = add nuw nsw i32 %i.05, 1
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %i.05
+  store i32 1, i32* %arrayidx, align 4
+  %add = add nsw i32 %M.addr.04, 4
+  %cmp = icmp sgt i32 %add, %N
+  br i1 %cmp, label %while.end.loopexit, label %while.body
+
+while.end.loopexit:
+  br label %while.end
+
+while.end:
+  ret void
+}
Index: test/ScopInfo/isl_trip_count_02.ll
===================================================================
--- /dev/null
+++ test/ScopInfo/isl_trip_count_02.ll
@@ -0,0 +1,33 @@
+; RUN: opt %loadPolly -polly-detect-unprofitable -polly-allow-non-scev-backedge-taken-count -polly-scops -analyze < %s | FileCheck %s
+;
+; CHECK: [M, N] -> { Stmt_for_body[i0] : i0 >= 0 and N <= -1 + M };
+;
+; void f(int *A, int N, int M) {
+;   for (int i = M; i > N; i++)
+;     A[i] = i;
+; }
+;
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
+define void @f(i32* %A, i32 %N, i32 %M) {
+entry:
+  br label %entry.split
+
+entry.split:
+  %cmp.1 = icmp sgt i32 %M, %N
+  br i1 %cmp.1, label %for.body, label %for.end
+
+for.body:
+  %indvars.iv = phi i32 [ %indvars.iv.next, %for.body ], [ %M, %entry.split ]
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %indvars.iv
+  store i32 %indvars.iv, i32* %arrayidx, align 4
+  %cmp = icmp slt i32 %M, %N
+  %indvars.iv.next = add i32 %indvars.iv, 1
+  br i1 %cmp, label %for.cond.for.end_crit_edge, label %for.body
+
+for.cond.for.end_crit_edge:
+  br label %for.end
+
+for.end:
+  ret void
+}