Index: include/polly/Support/GICHelper.h
===================================================================
--- include/polly/Support/GICHelper.h
+++ include/polly/Support/GICHelper.h
@@ -30,6 +30,7 @@
 struct isl_aff;
 struct isl_pw_aff;
 struct isl_val;
+struct isl_ast_node;
 
 namespace llvm {
 class Value;
@@ -69,6 +70,38 @@
 std::string getIslCompatibleName(std::string Prefix, const llvm::Value *Val,
                                  std::string Suffix);
 
+/// @brief Return @p Aff incremented by @p i.
+__isl_give isl_pw_aff *incrementPwAff(__isl_take isl_pw_aff *Aff, int i);
+
+/// @brief Return an overestimation of the number of elements in @p S.
+///
+/// @param S The set of which we want to know a bound on the elements.
+/// @param Dim The number of leading dimensions of @p S that are kept fixed.
+/// @param LexMinPtr If not null the lexmin will be returned in here.
+/// @param LexMaxPtr If not null the lexmax will be returned in here.
+///
+/// @note The returned value is always "off by one" (too low) in each dimension.
+__isl_give isl_pw_multi_aff *
+getNumElementsUB(__isl_take isl_set *S, int Dim,
+                 __isl_give isl_map **LexMinPtr = nullptr,
+                 __isl_give isl_map **LexMaxPtr = nullptr);
+
+/// @brief Return the number of iterations for the last dim of @p Schedule.
+///
+/// All preceding (outer) dims of the schedule range are kept fixed.
+///
+/// @see getNumberOfIterations(isl_ast_node *For).
+__isl_give isl_pw_aff *
+getNumberOfIterationsForSchedule(__isl_take isl_union_map *Schedule);
+
+/// @brief Return the number of iterations for @p For.
+///
+/// @note This function uses getNumElementsUB which guarantees only an
+///       overestimation on the number of elements of the given set.
+///       However, due to the nature of the set we are using, the number of
+///       elements will always be exact. (The set here is the loop domain!)
+__isl_give isl_pw_aff *getNumberOfIterations(__isl_keep isl_ast_node *For);
+
 } // end namespace polly
 
 #endif
Index: lib/Support/GICHelper.cpp
===================================================================
--- lib/Support/GICHelper.cpp
+++ lib/Support/GICHelper.cpp
@@ -11,7 +11,10 @@
 //
 //===----------------------------------------------------------------------===//
 #include "polly/Support/GICHelper.h"
+#include "polly/CodeGen/IslAst.h"
+
 #include "llvm/IR/Value.h"
+
 #include "isl/aff.h"
 #include "isl/map.h"
 #include "isl/schedule.h"
@@ -19,6 +22,7 @@
 #include "isl/union_map.h"
 #include "isl/union_set.h"
 #include "isl/val.h"
+#include "isl/ast.h"
 
 using namespace llvm;
 
@@ -152,3 +156,55 @@
   makeIslCompatible(ValStr);
   return ValStr;
 }
+
+isl_pw_aff *polly::incrementPwAff(isl_pw_aff *Aff, int i) {
+  isl_aff *Zero = isl_aff_zero_on_domain(
+      isl_local_space_from_space(isl_pw_aff_get_domain_space(Aff)));
+  isl_aff *Inc = isl_aff_add_constant_si(Zero, i);
+  return isl_pw_aff_add(Aff, isl_pw_aff_from_aff(Inc));
+}
+
+isl_pw_multi_aff *polly::getNumElementsUB(isl_set *S, int Dim,
+                                          isl_map **LexMinPtr,
+                                          isl_map **LexMaxPtr) {
+  S = isl_set_reset_tuple_id(S);
+
+  // Relate every element of S to all elements of S that agree with it on the
+  // first Dim dimensions.
+  isl_map *Map = isl_map_from_domain_and_range(isl_set_copy(S), S);
+  for (int i = 0; i < Dim; i++)
+    Map = isl_map_equate(Map, isl_dim_in, i, isl_dim_out, i);
+
+  isl_map *LexMax = isl_map_lexmax(isl_map_copy(Map));
+  if (LexMaxPtr)
+    *LexMaxPtr = isl_map_copy(LexMax);
+
+  isl_map *LexMin = isl_map_lexmin(Map);
+  if (LexMinPtr)
+    *LexMinPtr = isl_map_copy(LexMin);
+
+  // Map each element to the difference LexMax - LexMin (per dimension).
+  isl_map *Sub = isl_map_sum(LexMax, isl_map_neg(LexMin));
+  isl_pw_multi_aff *Elements = isl_pw_multi_aff_from_map(Sub);
+  return isl_pw_multi_aff_coalesce(Elements);
+}
+
+isl_pw_aff *polly::getNumberOfIterationsForSchedule(isl_union_map *Schedule) {
+  isl_union_set *Range = isl_union_map_range(Schedule);
+  isl_set *LoopDomain = isl_set_from_union_set(Range);
+  // Keep all but the last dimension fixed so only the last one is counted.
+  int InDim = isl_set_n_dim(LoopDomain) - 1;
+  isl_pw_multi_aff *NumElements = getNumElementsUB(LoopDomain, InDim);
+  unsigned Dim = isl_pw_multi_aff_dim(NumElements, isl_dim_set) - 1;
+  isl_pw_aff *NumIterations = isl_pw_multi_aff_get_pw_aff(NumElements, Dim);
+  isl_pw_multi_aff_free(NumElements);
+  // getNumElementsUB is off by one (it returns max - min), so add one.
+  NumIterations = incrementPwAff(NumIterations, 1);
+  NumIterations = isl_pw_aff_coalesce(NumIterations);
+  return NumIterations;
+}
+
+isl_pw_aff *polly::getNumberOfIterations(isl_ast_node *For) {
+  isl_union_map *Schedule = IslAstInfo::getSchedule(For);
+  return getNumberOfIterationsForSchedule(Schedule);
+}