Index: include/polly/Support/ISLTools.h
===================================================================
--- include/polly/Support/ISLTools.h
+++ include/polly/Support/ISLTools.h
@@ -410,6 +410,122 @@
 ///
 /// @return { Domain[] -> Range[] }
 isl::map intersectRange(isl::map Map, isl::union_set Range);
+
+/// If @p PwAff maps to a constant, return said constant. If @p Max (or @p Min)
+/// is set, it can also be a piecewise constant and the maximum (or minimum)
+/// value over all pieces is returned. Otherwise, return NaN. If @p PwAff has
+/// no pieces at all, a null value is returned. @p Max and @p Min are exclusive.
+isl::val getConstant(isl::pw_aff PwAff, bool Max, bool Min);
+
+/// Dump a description of the argument to llvm::errs().
+///
+/// In contrast to isl's dump function, there are a few differences:
+/// - Each polyhedron (piece) is written on its own line.
+/// - Spaces are sorted by structure. E.g. maps with the same domain space are
+///   grouped. Isl sorts them according to the space's hash function.
+/// - Pieces of the same space are sorted using their lower bounds.
+/// - A more compact to_str representation is used instead of isl's dump
+///   functions that try to show the internal representation.
+///
+/// The goal is a more readable representation that is also useful for
+/// comparing two sets. Like all dump() functions, it is intended to be called
+/// from a debugger only. It is only defined in builds with assertions enabled
+/// or with LLVM_ENABLE_DUMP.
+///
+/// isl_map_dump example:
+/// [p_0, p_1, p_2] -> { Stmt0[i0] -> [o0, o1] : (o0 = i0 and o1 = 0 and i0 > 0
+/// and i0 <= 5 - p_2) or (i0 = 0 and o0 = 0 and o1 = 0); Stmt3[i0] -> [o0, o1]
+/// : (o0 = i0 and o1 = 3 and i0 > 0 and i0 <= 5 - p_2) or (i0 = 0 and o0 = 0
+/// and o1 = 3); Stmt2[i0] -> [o0, o1] : (o0 = i0 and o1 = 1 and i0 >= 3 + p_0 -
+/// p_1 and i0 > 0 and i0 <= 5 - p_2) or (o0 = i0 and o1 = 1 and i0 > 0 and i0
+/// <= 5 - p_2 and i0 < p_0 - p_1) or (i0 = 0 and o0 = 0 and o1 = 1 and p_1 >= 3
+/// + p_0) or (i0 = 0 and o0 = 0 and o1 = 1 and p_1 < p_0) or (p_0 = 0 and i0 =
+/// 2 - p_1 and o0 = 2 - p_1 and o1 = 1 and p_2 <= 3 + p_1 and p_1 <= 1) or (p_1
+/// = 1 + p_0 and i0 = 0 and o0 = 0 and o1 = 1) or (p_0 = 0 and p_1 = 2 and i0 =
+/// 0 and o0 = 0 and o1 = 1) or (p_0 = -1 and p_1 = -1 and i0 = 0 and o0 = 0 and
+/// o1 = 1); Stmt1[i0] -> [o0, o1] : (p_0 = -1 and i0 = 1 - p_1 and o0 = 1 - p_1
+/// and o1 = 2 and p_2 <= 4 + p_1 and p_1 <= 0) or (p_0 = 0 and i0 = -p_1 and o0
+/// = -p_1 and o1 = 2 and p_2 <= 5 + p_1 and p_1 < 0) or (p_0 = -1 and p_1 = 1
+/// and i0 = 0 and o0 = 0 and o1 = 2) or (p_0 = 0 and p_1 = 0 and i0 = 0 and o0
+/// = 0 and o1 = 2) }
+///
+/// dumpPw example (same set):
+/// [p_0, p_1, p_2] -> {
+///   Stmt0[0] -> [0, 0];
+///   Stmt0[i0] -> [i0, 0] : 0 < i0 <= 5 - p_2;
+///   Stmt1[0] -> [0, 2] : p_1 = 1 and p_0 = -1;
+///   Stmt1[0] -> [0, 2] : p_1 = 0 and p_0 = 0;
+///   Stmt1[1 - p_1] -> [1 - p_1, 2] : p_0 = -1 and p_1 <= 0 and p_2 <= 4 + p_1;
+///   Stmt1[-p_1] -> [-p_1, 2] : p_0 = 0 and p_1 < 0 and p_2 <= 5 + p_1;
+///   Stmt2[0] -> [0, 1] : p_1 >= 3 + p_0;
+///   Stmt2[0] -> [0, 1] : p_1 < p_0;
+///   Stmt2[0] -> [0, 1] : p_1 = 1 + p_0;
+///   Stmt2[0] -> [0, 1] : p_1 = 2 and p_0 = 0;
+///   Stmt2[0] -> [0, 1] : p_1 = -1 and p_0 = -1;
+///   Stmt2[i0] -> [i0, 1] : i0 >= 3 + p_0 - p_1 and 0 < i0 <= 5 - p_2;
+///   Stmt2[i0] -> [i0, 1] : 0 < i0 <= 5 - p_2 and i0 < p_0 - p_1;
+///   Stmt2[2 - p_1] -> [2 - p_1, 1] : p_0 = 0 and p_1 <= 1 and p_2 <= 3 + p_1;
+///   Stmt3[0] -> [0, 3];
+///   Stmt3[i0] -> [i0, 3] : 0 < i0 <= 5 - p_2
+/// }
+/// @{
+void dumpPw(const isl::set &Set);
+void dumpPw(const isl::map &Map);
+void dumpPw(const isl::union_set &USet);
+void dumpPw(const isl::union_map &UMap);
+void dumpPw(__isl_keep isl_set *Set);
+void dumpPw(__isl_keep isl_map *Map);
+void dumpPw(__isl_keep isl_union_set *USet);
+void dumpPw(__isl_keep isl_union_map *UMap);
+/// @}
+
+/// Dump all points of the argument to llvm::errs().
+///
+/// The argument's pieces are expanded to contain only single points and then
+/// printed in the same format as dumpPw(), but without simplification. If a
+/// dimension is unbounded, it keeps its symbolic representation.
+///
+/// This is useful for debugging reduced cases where parameters are set to
+/// constants to keep the example simple. Such sets can still contain
+/// existential dimensions which make the polyhedra hard to compare.
+/// Only defined in builds with assertions enabled or with LLVM_ENABLE_DUMP.
+///
+/// Example:
+///   { [MemRef_A[i0] -> [i1]] : (exists (e0 = floor((1 + i1)/3): i0 = 1 and 3e0
+///   <= i1 and 3e0 >= -1 + i1 and i1 >= 15 and i1 <= 25)) or (exists (e0 =
+///   floor((i1)/3): i0 = 0 and 3e0 < i1 and 3e0 >= -2 + i1 and i1 > 0 and i1 <=
+///   11)) }
+///
+/// dumpExpanded:
+/// {
+///   [MemRef_A[0] ->[1]];
+///   [MemRef_A[0] ->[2]];
+///   [MemRef_A[0] ->[4]];
+///   [MemRef_A[0] ->[5]];
+///   [MemRef_A[0] ->[7]];
+///   [MemRef_A[0] ->[8]];
+///   [MemRef_A[0] ->[10]];
+///   [MemRef_A[0] ->[11]];
+///   [MemRef_A[1] ->[15]];
+///   [MemRef_A[1] ->[16]];
+///   [MemRef_A[1] ->[18]];
+///   [MemRef_A[1] ->[19]];
+///   [MemRef_A[1] ->[21]];
+///   [MemRef_A[1] ->[22]];
+///   [MemRef_A[1] ->[24]];
+///   [MemRef_A[1] ->[25]]
+/// }
+/// @{
+void dumpExpanded(const isl::set &Set);
+void dumpExpanded(const isl::map &Map);
+void dumpExpanded(const isl::union_set &USet);
+void dumpExpanded(const isl::union_map &UMap);
+void dumpExpanded(__isl_keep isl_set *Set);
+void dumpExpanded(__isl_keep isl_map *Map);
+void dumpExpanded(__isl_keep isl_union_set *USet);
+void dumpExpanded(__isl_keep isl_union_map *UMap);
+/// @}
+
 } // namespace polly
 
 #endif /* POLLY_ISLTOOLS_H */
Index: lib/Support/ISLTools.cpp
===================================================================
--- lib/Support/ISLTools.cpp
+++ lib/Support/ISLTools.cpp
@@ -13,6 +13,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "polly/Support/ISLTools.h"
+#include "llvm/ADT/StringRef.h"
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <vector>
 
 using namespace polly;
 
@@ -536,3 +541,353 @@
   isl::set RangeSet = Range.extract_set(Map.get_space().range());
   return Map.intersect_range(RangeSet);
 }
+
+isl::val polly::getConstant(isl::pw_aff PwAff, bool Max, bool Min) {
+  assert(!Max || !Min); // Cannot return min and max at the same time.
+  isl::val Result;
+  PwAff.foreach_piece([=, &Result](isl::set Set, isl::aff Aff) -> isl::stat {
+    if (Result && Result.is_nan())
+      return isl::stat::ok;
+
+    // TODO: If Min/Max, we can also determine a minimum/maximum value if
+    // Set is constant-bounded.
+    if (!Aff.is_cst()) {
+      Result = isl::val::nan(Aff.get_ctx());
+      return isl::stat::error;
+    }
+
+    isl::val ThisVal = Aff.get_constant_val();
+    if (!Result) {
+      Result = ThisVal;
+      return isl::stat::ok;
+    }
+
+    if (Result.eq(ThisVal))
+      return isl::stat::ok;
+
+    if (Max && ThisVal.gt(Result)) {
+      Result = ThisVal;
+      return isl::stat::ok;
+    }
+
+    if (Min && ThisVal.lt(Result)) {
+      Result = ThisVal;
+      return isl::stat::ok;
+    }
+
+    // Not compatible; returning error stops the iteration over the pieces.
+    Result = isl::val::nan(Aff.get_ctx());
+    return isl::stat::error;
+  });
+  return Result;
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+static void foreachPoint(const isl::set &Set,
+                         const std::function<void(isl::point P)> &F) {
+  isl_set_foreach_point(
+      Set.keep(),
+      [](__isl_take isl_point *p, void *User) -> isl_stat {
+        auto &F = *static_cast<const std::function<void(isl::point)> *>(User);
+        F(give(p));
+        return isl_stat_ok;
+      },
+      const_cast<void *>(static_cast<const void *>(&F)));
+}
+
+static void foreachPoint(isl::basic_set BSet,
+                         const std::function<void(isl::point P)> &F) {
+  foreachPoint(give(isl_set_from_basic_set(BSet.take())), F);
+}
+
+/// Determine the sorting order of the sets @p A and @p B without considering
+/// the space structure.
+///
+/// Ordering is based on the lower bounds of the set's dimensions. First
+/// dimensions are considered first. Empty sets are ordered before non-empty
+/// ones.
+static int flatCompare(const isl::basic_set &A, const isl::basic_set &B) {
+  // Quick bail-out on empty sets: they have no lower bound to compare, and
+  // getConstant() returns a null value for a pw_aff without any pieces.
+  bool AIsEmpty = bool(A.is_empty());
+  bool BIsEmpty = bool(B.is_empty());
+  if (AIsEmpty || BIsEmpty)
+    return BIsEmpty - AIsEmpty;
+
+  int ALen = A.dim(isl::dim::set);
+  int BLen = B.dim(isl::dim::set);
+  int Len = std::min(ALen, BLen);
+
+  for (int i = 0; i < Len; i += 1) {
+    isl::basic_set ADim =
+        A.project_out(isl::dim::param, 0, A.dim(isl::dim::param))
+            .project_out(isl::dim::set, i + 1, ALen - i - 1)
+            .project_out(isl::dim::set, 0, i);
+    isl::basic_set BDim =
+        B.project_out(isl::dim::param, 0, B.dim(isl::dim::param))
+            .project_out(isl::dim::set, i + 1, BLen - i - 1)
+            .project_out(isl::dim::set, 0, i);
+
+    isl::basic_set AHull = isl::set(ADim).convex_hull();
+    isl::basic_set BHull = isl::set(BDim).convex_hull();
+
+    bool ALowerBounded =
+        bool(isl::set(AHull).dim_has_any_lower_bound(isl::dim::set, 0));
+    bool BLowerBounded =
+        bool(isl::set(BHull).dim_has_any_lower_bound(isl::dim::set, 0));
+
+    int BoundedCompare = BLowerBounded - ALowerBounded;
+    if (BoundedCompare != 0)
+      return BoundedCompare;
+
+    if (!ALowerBounded || !BLowerBounded)
+      continue;
+
+    isl::pw_aff AMin = isl::set(ADim).dim_min(0);
+    isl::pw_aff BMin = isl::set(BDim).dim_min(0);
+
+    isl::val AMinVal = polly::getConstant(AMin, false, true);
+    isl::val BMinVal = polly::getConstant(BMin, false, true);
+
+    // isl leaves the sign of NaN undefined; treat such bounds as incomparable
+    // so the comparison stays antisymmetric, as std::sort requires.
+    if (!AMinVal || !BMinVal || AMinVal.is_nan() || BMinVal.is_nan())
+      continue;
+
+    int MinCompare = AMinVal.sub(BMinVal).sgn();
+    if (MinCompare != 0)
+      return MinCompare;
+  }
+
+  // If all the dimensions' lower bounds are equal or incomparable, sort based
+  // on the number of dimensions.
+  return ALen - BLen;
+}
+
+/// Compare the sets @p A and @p B according to their nested space structure. If
+/// the structure is the same, sort using the dimension lower bounds.
+static int recursiveCompare(const isl::basic_set &A, const isl::basic_set &B) {
+  isl::space ASpace = A.get_space();
+  isl::space BSpace = B.get_space();
+
+  int WrappingCompare = bool(ASpace.is_wrapping()) - bool(BSpace.is_wrapping());
+  if (WrappingCompare != 0)
+    return WrappingCompare;
+
+  if (ASpace.is_wrapping() && BSpace.is_wrapping()) {
+    isl::basic_map AMap = A.unwrap();
+    isl::basic_map BMap = B.unwrap();
+
+    int FirstResult = recursiveCompare(AMap.domain(), BMap.domain());
+    if (FirstResult != 0)
+      return FirstResult;
+
+    return recursiveCompare(AMap.range(), BMap.range());
+  }
+
+  std::string AName = ASpace.has_tuple_name(isl::dim::set)
+                          ? ASpace.get_tuple_name(isl::dim::set)
+                          : std::string();
+  std::string BName = BSpace.has_tuple_name(isl::dim::set)
+                          ? BSpace.get_tuple_name(isl::dim::set)
+                          : std::string();
+
+  int NameCompare = AName.compare(BName);
+  if (NameCompare != 0)
+    return NameCompare;
+
+  return flatCompare(A, B);
+}
+
+/// Wrapper for recursiveCompare: convert its three-way result (negative, zero
+/// or positive) into the strict "less than" predicate std::sort expects.
+static bool orderComparer(const isl::basic_set &A, const isl::basic_set &B) {
+  return recursiveCompare(A, B) < 0;
+}
+
+/// Print a string representation of @p USet to @p OS.
+///
+/// The pieces of @p USet are printed in a sorted order. Spaces with equal or
+/// similar nesting structure are printed together. Unlike isl's own printing
+/// function, this uses the structure itself as the base of the sorting, not a
+/// hash of it. It ensures that e.g. map spaces with the same domain structure
+/// are printed together. Set pieces with the same structure are printed in
+/// order of their lower bounds.
+///
+/// @param USet     Polyhedra to print.
+/// @param OS       Target stream.
+/// @param Simplify Whether to simplify the polyhedron before printing.
+/// @param IsMap    Whether @p USet is a wrapped map. If true, sets are
+///                 unwrapped before printing to again appear as a map.
+static void printSortedPolyhedra(isl::union_set USet, llvm::raw_ostream &OS,
+                                 bool Simplify, bool IsMap) {
+  if (!USet) {
+    OS << "<null>\n";
+    return;
+  }
+
+  if (Simplify)
+    simplify(USet);
+
+  // Get all the polyhedra.
+  std::vector<isl::basic_set> BSets;
+  USet.foreach_set([&BSets](isl::set Set) -> isl::stat {
+    Set.foreach_basic_set([&BSets](isl::basic_set BSet) -> isl::stat {
+      BSets.push_back(BSet);
+      return isl::stat::ok;
+    });
+    return isl::stat::ok;
+  });
+
+  if (BSets.empty()) {
+    OS << "{\n}\n";
+    return;
+  }
+
+  // Sort the polyhedra.
+  std::sort(BSets.begin(), BSets.end(), orderComparer);
+
+  // Print the polyhedra.
+  bool First = true;
+  for (const isl::basic_set &BSet : BSets) {
+    std::string Str;
+    if (IsMap)
+      Str = isl::map(BSet.unwrap()).to_str();
+    else
+      Str = isl::set(BSet).to_str();
+    size_t OpenPos = Str.find_first_of('{');
+    assert(OpenPos != std::string::npos);
+    size_t ClosePos = Str.find_last_of('}');
+    assert(ClosePos != std::string::npos);
+
+    if (First)
+      OS << llvm::StringRef(Str).substr(0, OpenPos + 1) << "\n ";
+    else
+      OS << ";\n ";
+
+    OS << llvm::StringRef(Str).substr(OpenPos + 1, ClosePos - OpenPos - 2);
+    First = false;
+  }
+  assert(!First);
+  OS << "\n}\n";
+}
+
+static void recursiveExpand(isl::basic_set BSet, int Dim, isl::set &Expanded) {
+  int Dims = BSet.dim(isl::dim::set);
+  if (Dim >= Dims) {
+    Expanded = Expanded.unite(BSet);
+    return;
+  }
+  // Project onto dimension Dim alone; parameters are projected out as well.
+  isl::basic_set DimOnly =
+      BSet.project_out(isl::dim::param, 0, BSet.dim(isl::dim::param))
+          .project_out(isl::dim::set, Dim + 1, Dims - Dim - 1)
+          .project_out(isl::dim::set, 0, Dim);
+  if (!DimOnly.is_bounded()) {
+    recursiveExpand(BSet, Dim + 1, Expanded);
+    return;
+  }
+  // Fix dimension Dim to each of its values in turn and expand the rest.
+  foreachPoint(DimOnly, [&, Dim](isl::point P) {
+    isl::val Val = P.get_coordinate_val(isl::dim::set, 0);
+    isl::basic_set FixBSet = BSet.fix_val(isl::dim::set, Dim, Val);
+    recursiveExpand(FixBSet, Dim + 1, Expanded);
+  });
+}
+
+/// Make each point of a set explicit.
+///
+/// Each dimension with finitely many values is fixed to each of those values
+/// in turn; dimensions that are unbounded keep their symbolic form.
+///
+/// Example:
+///   { [i] : 0 <= i < 2 }
+/// becomes:
+///   { [0]; [1] }
+static isl::set expand(const isl::set &Set) {
+  isl::set Points = isl::set::empty(Set.get_space());
+  Set.foreach_basic_set([&Points](isl::basic_set Piece) -> isl::stat {
+    recursiveExpand(Piece, 0, Points);
+    return isl::stat::ok;
+  });
+  return Points;
+}
+
+/// Make all points of every set in a union set explicit.
+///
+/// @see expand(const isl::set &)
+static isl::union_set expand(const isl::union_set &USet) {
+  isl::union_set Points =
+      give(isl_union_set_empty(isl_union_set_get_space(USet.keep())));
+  USet.foreach_set([&Points](isl::set Component) -> isl::stat {
+    isl::set ComponentPoints = expand(Component);
+    Points = Points.add_set(ComponentPoints);
+    return isl::stat::ok;
+  });
+  return Points;
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(const isl::set &Set) {
+  printSortedPolyhedra(isl::union_set(Set), llvm::errs(), true, false);
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(const isl::map &Map) {
+  printSortedPolyhedra(isl::union_set(Map.wrap()), llvm::errs(), true, true);
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(const isl::union_set &USet) {
+  printSortedPolyhedra(USet, llvm::errs(), true, false);
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(const isl::union_map &UMap) {
+  printSortedPolyhedra(UMap.wrap(), llvm::errs(), true, true);
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_set *Set) {
+  polly::dumpPw(isl::manage(isl_set_copy(Set)));
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_map *Map) {
+  polly::dumpPw(isl::manage(isl_map_copy(Map)));
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_union_set *USet) {
+  polly::dumpPw(isl::manage(isl_union_set_copy(USet)));
+}
+
+LLVM_DUMP_METHOD void polly::dumpPw(__isl_keep isl_union_map *UMap) {
+  polly::dumpPw(isl::manage(isl_union_map_copy(UMap)));
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::set &Set) {
+  printSortedPolyhedra(isl::union_set(expand(Set)), llvm::errs(), false, false);
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::map &Map) {
+  printSortedPolyhedra(expand(Map.wrap()), llvm::errs(), false, true);
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::union_set &USet) {
+  printSortedPolyhedra(expand(USet), llvm::errs(), false, false);
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(const isl::union_map &UMap) {
+  printSortedPolyhedra(expand(UMap.wrap()), llvm::errs(), false, true);
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_set *Set) {
+  polly::dumpExpanded(isl::manage(isl_set_copy(Set)));
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_map *Map) {
+  polly::dumpExpanded(isl::manage(isl_map_copy(Map)));
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_union_set *USet) {
+  polly::dumpExpanded(isl::manage(isl_union_set_copy(USet)));
+}
+
+LLVM_DUMP_METHOD void polly::dumpExpanded(__isl_keep isl_union_map *UMap) {
+  polly::dumpExpanded(isl::manage(isl_union_map_copy(UMap)));
+}
+#endif // !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
Index: lib/Transform/FlattenAlgo.cpp
===================================================================
--- lib/Transform/FlattenAlgo.cpp
+++ lib/Transform/FlattenAlgo.cpp
@@ -14,6 +14,7 @@
 
 #include "polly/FlattenAlgo.h"
 #include "polly/Support/ISLOStream.h"
+#include "polly/Support/ISLTools.h"
 #include "llvm/Support/Debug.h"
 #define DEBUG_TYPE "polly-flatten-algo"
 
@@ -69,49 +70,6 @@
   }) == isl::stat::ok;
 }
 
-/// If @p PwAff maps to a constant, return said constant. If @p Max/@p Min, it
-/// can also be a piecewise constant and it would return the minimum/maximum
-/// value. Otherwise, return NaN.
-isl::val getConstant(isl::pw_aff PwAff, bool Max, bool Min) {
-  assert(!Max || !Min);
-  isl::val Result;
-  PwAff.foreach_piece([=, &Result](isl::set Set, isl::aff Aff) -> isl::stat {
-    if (Result && Result.is_nan())
-      return isl::stat::ok;
-
-    // TODO: If Min/Max, we can also determine a minimum/maximum value if
-    // Set is constant-bounded.
-    if (!Aff.is_cst()) {
-      Result = isl::val::nan(Aff.get_ctx());
-      return isl::stat::error;
-    }
-
-    auto ThisVal = Aff.get_constant_val();
-    if (!Result) {
-      Result = ThisVal;
-      return isl::stat::ok;
-    }
-
-    if (Result.eq(ThisVal))
-      return isl::stat::ok;
-
-    if (Max && ThisVal.gt(Result)) {
-      Result = ThisVal;
-      return isl::stat::ok;
-    }
-
-    if (Min && ThisVal.lt(Result)) {
-      Result = ThisVal;
-      return isl::stat::ok;
-    }
-
-    // Not compatible
-    Result = isl::val::nan(Aff.get_ctx());
-    return isl::stat::error;
-  });
-  return Result;
-}
-
 /// Compute @p UPwAff - @p Val.
 isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
   if (Val.is_zero())