Index: include/polly/FlattenAlgo.h
===================================================================
--- /dev/null
+++ include/polly/FlattenAlgo.h
@@ -0,0 +1,38 @@
+//===------ FlattenAlgo.h --------------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
+// the unittest for this requiring linking against LLVM.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef POLLY_FLATTENALGO_H
+#define POLLY_FLATTENALGO_H
+
+#include "polly/Support/GICHelper.h"
+
+namespace polly {
+/// Recursively flatten a schedule.
+///
+/// Reduce the number of scatter dimensions as much as possible without changing
+/// the relative order of instances in a schedule. Ideally, this results in a
+/// single scatter dimension, but it may not always be possible to combine
+/// dimensions, e.g. if a dimension is unbounded. In the worst case, the
+/// original schedule is returned.
+///
+/// Schedules with fewer dimensions may be easier to understand for humans, but
+/// it should make no difference to the computer.
+///
+/// @param Schedule The input schedule.
+///
+/// @return The flattened schedule.
+IslPtr<isl_union_map> flattenSchedule(IslPtr<isl_union_map> Schedule);
+} // namespace polly
+
+#endif /* POLLY_FLATTENALGO_H */
Index: include/polly/FlattenSchedule.h
===================================================================
--- /dev/null
+++ include/polly/FlattenSchedule.h
@@ -0,0 +1,32 @@
+//===------ FlattenSchedule.h ----------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Try to reduce the number of scatter dimensions. Useful to make isl_union_map
+// schedules more understandable. This is only intended for debugging and
+// unittests, not for optimizations themselves.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef POLLY_FLATTENSCHEDULE_H
+#define POLLY_FLATTENSCHEDULE_H
+
+namespace llvm {
+class PassRegistry;
+class Pass;
+} // namespace llvm
+
+namespace polly {
+llvm::Pass *createFlattenSchedulePass();
+} // namespace polly
+
+namespace llvm {
+void initializeFlattenSchedulePass(llvm::PassRegistry &);
+} // namespace llvm
+
+#endif /* POLLY_FLATTENSCHEDULE_H */
Index: include/polly/LinkAllPasses.h
===================================================================
--- include/polly/LinkAllPasses.h
+++ include/polly/LinkAllPasses.h
@@ -47,6 +47,7 @@
 llvm::Pass *createPPCGCodeGenerationPass();
 #endif
 llvm::Pass *createIslScheduleOptimizerPass();
+llvm::Pass *createFlattenSchedulePass();
 
 extern char &CodePreparationID;
 } // namespace polly
@@ -80,6 +81,7 @@
     polly::createPPCGCodeGenerationPass();
 #endif
     polly::createIslScheduleOptimizerPass();
+    polly::createFlattenSchedulePass();
   }
 } PollyForcePassLinking; // Force link by creating a global definition.
 } // namespace
@@ -97,6 +99,7 @@
 #endif
 void initializeIslScheduleOptimizerPass(llvm::PassRegistry &);
 void initializePollyCanonicalizePass(llvm::PassRegistry &);
+void initializeFlattenSchedulePass(llvm::PassRegistry &);
 } // namespace llvm
 
 #endif
Index: include/polly/Support/GICHelper.h
===================================================================
--- include/polly/Support/GICHelper.h
+++ include/polly/Support/GICHelper.h
@@ -16,21 +16,17 @@
 
 #include "llvm/ADT/APInt.h"
 #include "llvm/Support/raw_ostream.h"
+#include "isl/aff.h"
 #include "isl/ctx.h"
+#include "isl/map.h"
+#include "isl/set.h"
+#include "isl/union_map.h"
+#include "isl/union_set.h"
+#include <functional>
 #include <string>
 
-struct isl_map;
-struct isl_union_map;
-struct isl_set;
-struct isl_union_set;
 struct isl_schedule;
 struct isl_multi_aff;
-struct isl_pw_multi_aff;
-struct isl_union_pw_multi_aff;
-struct isl_aff;
-struct isl_pw_aff;
-struct isl_val;
-struct isl_space;
 
 namespace llvm {
 class Value;
@@ -177,6 +173,239 @@
                                  const std::string &Middle,
                                  const std::string &Suffix);
 
+/// IslObjTraits is a static class to invoke common functions that all
+/// ISL objects have: isl_*_copy, isl_*_free, isl_*_get_ctx and isl_*_to_str.
+/// These functions follow a common naming scheme, but not a base class
+/// hierarchy (as ISL is written in C). As such, the functions are accessible
+/// only by constructing the function name using the preprocessor. This class
+/// serves to make these names accessible to a C++ template scheme.
+///
+/// There is an isl_obj polymorphism layer, but its implementation is
+/// incomplete.
+template <typename T> class IslObjTraits;
+
+// Specialize IslObjTraits<isl_TYPE> by forwarding to the isl_TYPE_copy,
+// isl_TYPE_free, isl_TYPE_get_ctx and isl_TYPE_to_str functions.
+#define DECLARE_TRAITS(TYPE)                                                   \
+  template <> class IslObjTraits<isl_##TYPE> {                                 \
+  public:                                                                      \
+    static __isl_give isl_##TYPE *copy(__isl_keep isl_##TYPE *Obj) {           \
+      return isl_##TYPE##_copy(Obj);                                           \
+    }                                                                          \
+    static void free(__isl_take isl_##TYPE *Obj) { isl_##TYPE##_free(Obj); }   \
+    static isl_ctx *get_ctx(__isl_keep isl_##TYPE *Obj) {                      \
+      return isl_##TYPE##_get_ctx(Obj);                                        \
+    }                                                                          \
+    static std::string to_str(__isl_keep isl_##TYPE *Obj) {                    \
+      if (!Obj)                                                                \
+        return "null";                                                         \
+      char *cstr = isl_##TYPE##_to_str(Obj);                                   \
+      if (!cstr)                                                               \
+        return "null";                                                         \
+      std::string Result{cstr};                                                \
+      ::free(cstr);                                                            \
+      return Result;                                                           \
+    }                                                                          \
+  };
+
+DECLARE_TRAITS(val)
+DECLARE_TRAITS(space)
+DECLARE_TRAITS(basic_map)
+DECLARE_TRAITS(map)
+DECLARE_TRAITS(union_map)
+DECLARE_TRAITS(basic_set)
+DECLARE_TRAITS(set)
+DECLARE_TRAITS(union_set)
+DECLARE_TRAITS(aff)
+DECLARE_TRAITS(pw_aff)
+DECLARE_TRAITS(union_pw_aff)
+DECLARE_TRAITS(multi_union_pw_aff)
+DECLARE_TRAITS(union_pw_multi_aff)
+
+// The helper macro is only needed for the specializations above; do not leak
+// such a generic name into every file that includes this header.
+#undef DECLARE_TRAITS
+
+template <typename T> class NonowningIslPtr;
+
+/// Smart pointer to an ISL object.
+///
+/// An object of this class owns a reference to an ISL object, meaning it will
+/// free it when destroyed. Most ISL objects are reference counted such that we
+/// gain an automatic memory management.
+///
+/// Function parameters in the ISL API are annotated using either __isl_keep or
+/// __isl_take. Return values that are objects are annotated using __isl_give,
+/// meaning the caller is responsible for releasing the object. When annotated
+/// with __isl_keep, use the keep() function to pass a plain pointer to the ISL
+/// object. For __isl_take-annotated parameters, use either copy() to increase
+/// the reference counter by one, or take() to pass the ownership to the called
+/// function. When IslPtr loses ownership, it cannot be used anymore and won't
+/// free the object when destroyed. Use the give() function to wrap the
+/// ownership of a returned isl_* object into an IslPtr.
+///
+/// There is purposefully no implicit conversion from/to plain isl_* pointers to
+/// avoid difficult to find bugs because keep/copy/take would have been
+/// required.
+template <typename T> class IslPtr {
+  typedef IslPtr<T> ThisTy;
+  typedef IslObjTraits<T> Traits;
+
+private:
+  T *Obj;
+
+  /// Wrap @p Obj; if @p TakeOwnership is false, an additional reference is
+  /// acquired instead of adopting the caller's one.
+  explicit IslPtr(__isl_take T *Obj, bool TakeOwnership)
+      : Obj(TakeOwnership ? Obj : Traits::copy(Obj)) {}
+
+public:
+  IslPtr() : Obj(nullptr) {}
+  /* implicit */ IslPtr(std::nullptr_t That) : IslPtr() {}
+
+  /* implicit */ IslPtr(const ThisTy &That) : IslPtr(That.Obj, false) {}
+  /* implicit */ IslPtr(ThisTy &&That) noexcept : IslPtr(That.Obj, true) {
+    That.Obj = nullptr;
+  }
+  /* implicit */ IslPtr(NonowningIslPtr<T> That) : IslPtr(That.copy(), true) {}
+  ~IslPtr() { Traits::free(Obj); }
+
+  ThisTy &operator=(const ThisTy &That) {
+    // Acquire the new reference before releasing the old one. Otherwise a
+    // self-assignment would free the object and then copy the dangling
+    // pointer.
+    T *NewObj = Traits::copy(That.Obj);
+    Traits::free(this->Obj);
+    this->Obj = NewObj;
+    return *this;
+  }
+  ThisTy &operator=(ThisTy &&That) noexcept {
+    // The previous object is released when That is destroyed.
+    swap(*this, That);
+    return *this;
+  }
+
+  explicit operator bool() const { return Obj != nullptr; }
+
+  static void swap(ThisTy &LHS, ThisTy &RHS) noexcept {
+    std::swap(LHS.Obj, RHS.Obj);
+  }
+
+  /// Adopt the reference returned by an __isl_give function.
+  static ThisTy give(T *Obj) { return ThisTy(Obj, true); }
+  /// Borrow the pointer for an __isl_keep parameter.
+  T *keep() const { return Obj; }
+  /// Hand the reference to an __isl_take parameter; this IslPtr becomes null.
+  T *take() {
+    auto *Result = Obj;
+    Obj = nullptr;
+    return Result;
+  }
+  /// Return a new reference for an __isl_take parameter.
+  T *copy() const { return Traits::copy(Obj); }
+
+  isl_ctx *getCtx() const { return Traits::get_ctx(Obj); }
+  std::string toStr() const { return Traits::to_str(Obj); }
+};
+
+/// Wrap an __isl_give result into an IslPtr, deducing the object type.
+template <typename T> inline IslPtr<T> give(__isl_take T *Obj) {
+  return IslPtr<T>::give(Obj);
+}
+
+template <typename T>
+llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const IslPtr<T> &Obj) {
+  OS << IslObjTraits<T>::to_str(Obj.keep());
+  return OS;
+}
+
+/// Smart pointer to an ISL object, but does not release it when destroyed.
+///
+/// This is meant to be used as a function parameter type. The caller
+/// guarantees that the reference is alive during the function's execution and
+/// hence doesn't need to add a reference.
+/// Therefore, it is equivalent to the __isl_keep annotation (IslPtr being
+/// equivalent to __isl_take which can be either copied or moved). A
+/// NonowningIslPtr must not be used after the owner of the object released it.
+///
+/// Just as IslPtr, it has keep() and copy() methods. The take() method is
+/// missing as this would steal the reference from the owner (the caller).
+template <typename T> class NonowningIslPtr {
+  typedef NonowningIslPtr<T> ThisTy;
+  typedef IslObjTraits<T> Traits;
+
+private:
+  T *Obj;
+
+  /* implicit */ NonowningIslPtr(__isl_keep T *Obj) : Obj(Obj) {}
+
+public:
+  NonowningIslPtr() : Obj(nullptr) {}
+  /* implicit */ NonowningIslPtr(std::nullptr_t That) : NonowningIslPtr() {}
+
+  /* implicit */ NonowningIslPtr(const IslPtr<T> &That)
+      : NonowningIslPtr(That.keep()) {}
+
+  explicit operator bool() const { return Obj != nullptr; }
+
+  static void swap(ThisTy &LHS, ThisTy &RHS) noexcept {
+    std::swap(LHS.Obj, RHS.Obj);
+  }
+
+  T *keep() const { return Obj; }
+  T *copy() const { return Traits::copy(Obj); }
+
+  isl_ctx *getCtx() const { return Traits::get_ctx(Obj); }
+  std::string toStr() const { return Traits::to_str(Obj); }
+};
+
+template <typename T>
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
+                                     NonowningIslPtr<T> Obj) {
+  OS << IslObjTraits<T>::to_str(Obj.keep());
+  return OS;
+}
+
+/// Enumerate all isl_maps of an isl_union_map.
+///
+/// This basically wraps isl_union_map_foreach_map() and allows C++11 closures
+/// to be used as callback. The callback receives ownership of each map.
+void foreachElt(NonowningIslPtr<isl_union_map> UMap,
+                const std::function<void(IslPtr<isl_map> Map)> &F);
+
+/// Enumerate all isl_pw_aff of an isl_union_pw_aff.
+///
+/// This basically wraps isl_union_pw_aff_foreach_pw_aff() and allows C++11
+/// closures to be used as callback. The callback receives ownership of each
+/// isl_pw_aff.
+void foreachElt(NonowningIslPtr<isl_union_pw_aff> UPwAff,
+                const std::function<void(IslPtr<isl_pw_aff>)> &F);
+
+/// Enumerate all polyhedra of an isl_map.
+///
+/// This is a wrapper for isl_map_foreach_basic_map() that allows to call back
+/// C++ closures. The callback has the possibility to interrupt (break) the
+/// enumeration by returning isl_stat_error. A return value of isl_stat_ok will
+/// continue enumerations, if any more elements are left.
+///
+/// @param Map Collection to enumerate.
+/// @param F   The callback function, lambda or closure.
+///
+/// @return The isl_stat returned by the last callback invocation; isl_stat_ok
+///         if the collection was empty.
+isl_stat
+foreachEltWithBreak(NonowningIslPtr<isl_map> Map,
+                    const std::function<isl_stat(IslPtr<isl_basic_map>)> &F);
+
+/// Enumerate all isl_maps of an isl_union_map.
+///
+/// This is a wrapper for isl_union_map_foreach_map() that allows to call back
+/// C++ closures. In contrast to foreachElt(), the callback has the possibility
+/// to interrupt (break) the enumeration by returning isl_stat_error. A return
+/// value of isl_stat_ok will continue enumerations, if any more elements are
+/// left.
+///
+/// @param UMap Collection to enumerate.
+/// @param F    The callback function, lambda or closure.
+///
+/// @return The isl_stat returned by the last callback invocation; isl_stat_ok
+///         if the collection was initially empty.
+isl_stat
+foreachEltWithBreak(NonowningIslPtr<isl_union_map> UMap,
+                    const std::function<isl_stat(IslPtr<isl_map> Map)> &F);
+
+/// Enumerate all pieces of an isl_pw_aff.
+///
+/// This is a wrapper around isl_pw_aff_foreach_piece() that allows to call back
+/// C++11 closures. The callback has the possibility to interrupt (break) the
+/// enumeration by returning isl_stat_error. A return value of isl_stat_ok will
+/// continue enumerations, if any more elements are left.
+///
+/// @param PwAff Piecewise affine function whose pieces are enumerated.
+/// @param F     The callback function, lambda or closure. It receives the
+///              domain set and the affine expression of each piece.
+///
+/// @return The isl_stat returned by the last callback invocation; isl_stat_ok
+///         if the collection was initially empty.
+isl_stat foreachPieceWithBreak(
+    NonowningIslPtr<isl_pw_aff> PwAff,
+    const std::function<isl_stat(IslPtr<isl_set>, IslPtr<isl_aff>)> &F);
+
 } // end namespace polly
 
 #endif
Index: lib/CMakeLists.txt
===================================================================
--- lib/CMakeLists.txt
+++ lib/CMakeLists.txt
@@ -55,6 +55,8 @@
   Transform/CodePreparation.cpp
   Transform/DeadCodeElimination.cpp
   Transform/ScheduleOptimizer.cpp
+  Transform/FlattenSchedule.cpp
+  Transform/FlattenAlgo.cpp
   ${POLLY_HEADER_FILES}
   )
 
Index: lib/Support/GICHelper.cpp
===================================================================
--- lib/Support/GICHelper.cpp
+++ lib/Support/GICHelper.cpp
@@ -198,3 +198,76 @@
   ValStr.erase(0, 1);
   return getIslCompatibleName(Prefix, ValStr, Suffix);
 }
+
+// The wrappers below pass the std::function to isl via the opaque `void *User`
+// argument. The captureless lambdas convert to plain C function pointers and
+// act as trampolines: they cast User back to the std::function and invoke it,
+// handing over ownership of each enumerated isl object via give().
+
+void polly::foreachElt(NonowningIslPtr<isl_union_map> UMap,
+                       const std::function<void(IslPtr<isl_map> Map)> &F) {
+  isl_union_map_foreach_map(
+      UMap.keep(),
+      [](__isl_take isl_map *Map, void *User) -> isl_stat {
+        auto &F =
+            *static_cast<const std::function<void(IslPtr<isl_map>)> *>(User);
+        F(give(Map));
+        return isl_stat_ok;
+      },
+      const_cast<void *>(static_cast<const void *>(&F)));
+}
+
+void polly::foreachElt(NonowningIslPtr<isl_union_pw_aff> UPwAff,
+                       const std::function<void(IslPtr<isl_pw_aff>)> &F) {
+  isl_union_pw_aff_foreach_pw_aff(
+      UPwAff.keep(),
+      [](__isl_take isl_pw_aff *PwAff, void *User) -> isl_stat {
+        auto &F =
+            *static_cast<const std::function<void(IslPtr<isl_pw_aff>)> *>(User);
+        F(give(PwAff));
+        return isl_stat_ok;
+      },
+      const_cast<void *>(static_cast<const void *>(&F)));
+}
+
+isl_stat polly::foreachEltWithBreak(
+    NonowningIslPtr<isl_map> Map,
+    const std::function<isl_stat(IslPtr<isl_basic_map>)> &F) {
+  return isl_map_foreach_basic_map(
+      Map.keep(),
+      [](__isl_take isl_basic_map *BMap, void *User) -> isl_stat {
+        auto &F = *static_cast<
+            const std::function<isl_stat(IslPtr<isl_basic_map>)> *>(User);
+        return F(give(BMap));
+      },
+      const_cast<void *>(static_cast<const void *>(&F)));
+}
+
+isl_stat polly::foreachEltWithBreak(
+    NonowningIslPtr<isl_union_map> UMap,
+    const std::function<isl_stat(IslPtr<isl_map> Map)> &F) {
+  return isl_union_map_foreach_map(
+      UMap.keep(),
+      [](__isl_take isl_map *Map, void *User) -> isl_stat {
+        auto &F =
+            *static_cast<const std::function<isl_stat(IslPtr<isl_map> Map)> *>(
+                User);
+        return F(give(Map));
+      },
+      const_cast<void *>(static_cast<const void *>(&F)));
+}
+
+isl_stat polly::foreachPieceWithBreak(
+    NonowningIslPtr<isl_pw_aff> PwAff,
+    const std::function<isl_stat(IslPtr<isl_set>, IslPtr<isl_aff>)> &F) {
+  return isl_pw_aff_foreach_piece(
+      PwAff.keep(),
+      [](__isl_take isl_set *Domain, __isl_take isl_aff *Aff,
+         void *User) -> isl_stat {
+        auto &F = *static_cast<
+            const std::function<isl_stat(IslPtr<isl_set>, IslPtr<isl_aff>)> *>(
+            User);
+        return F(give(Domain), give(Aff));
+      },
+      const_cast<void *>(static_cast<const void *>(&F)));
+}
Index: lib/Support/RegisterPasses.cpp
===================================================================
--- lib/Support/RegisterPasses.cpp
+++ lib/Support/RegisterPasses.cpp
@@ -24,6 +24,7 @@
 #include "polly/CodeGen/CodeGeneration.h"
 #include "polly/CodeGen/CodegenCleanup.h"
 #include "polly/DependenceInfo.h"
+#include "polly/FlattenSchedule.h"
 #include "polly/LinkAllPasses.h"
 #include "polly/Options.h"
 #include "polly/PolyhedralInfo.h"
@@ -180,6 +181,7 @@
   initializeScopInfoRegionPassPass(Registry);
   initializeScopInfoWrapperPassPass(Registry);
   initializeCodegenCleanupPass(Registry);
+  initializeFlattenSchedulePass(Registry);
 }
 
 /// Register Polly passes such that they form a polyhedral optimizer.
Index: lib/Transform/FlattenAlgo.cpp
===================================================================
--- /dev/null
+++ lib/Transform/FlattenAlgo.cpp
@@ -0,0 +1,417 @@
+//===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
+// the unittest for this requiring linking against LLVM.
+//
+//===----------------------------------------------------------------------===//
+
+#include "polly/FlattenAlgo.h"
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "polly-flatten-algo"
+
+using namespace polly;
+using namespace llvm;
+
+namespace {
+
+/// Whether a dimension of a set is bounded (lower and upper) by a constant,
+/// i.e. there are two constants Min and Max, such that every value x of the
+/// chosen dimensions is Min <= x <= Max.
+///
+/// Parameters are projected out first, so only bounds that hold independently
+/// of the parameter values count. Returns false if isl reports an error.
+bool isDimBoundedByConstant(IslPtr<isl_set> Set, unsigned Dim) {
+  auto ParamDims = isl_set_dim(Set.keep(), isl_dim_param);
+  Set = give(isl_set_project_out(Set.take(), isl_dim_param, 0, ParamDims));
+  Set = give(isl_set_project_out(Set.take(), isl_dim_set, 0, Dim));
+  auto SetDims = isl_set_dim(Set.keep(), isl_dim_set);
+  Set = give(isl_set_project_out(Set.take(), isl_dim_set, 1, SetDims - 1));
+  // Compare explicitly: isl_bool_error (-1) would otherwise convert to true
+  // and be mistaken for "bounded", which tryFlattenSequence relies on to
+  // guarantee termination.
+  return isl_set_is_bounded(Set.keep()) == isl_bool_true;
+}
+
+/// Whether a dimension of a set is (lower and upper) bounded by a constant or
+/// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
+/// p, such that every value x of the chosen dimensions is
+/// Min_p <= x <= Max_p.
+///
+/// Returns false if isl reports an error.
+bool isDimBoundedByParameter(IslPtr<isl_set> Set, unsigned Dim) {
+  Set = give(isl_set_project_out(Set.take(), isl_dim_set, 0, Dim));
+  auto SetDims = isl_set_dim(Set.keep(), isl_dim_set);
+  Set = give(isl_set_project_out(Set.take(), isl_dim_set, 1, SetDims - 1));
+  return isl_set_is_bounded(Set.keep()) == isl_bool_true;
+}
+
+/// Whether BMap's first out-dimension is not a constant.
+///
+/// Conservatively answers true if no fixed value can be determined.
+bool isVariableDim(NonowningIslPtr<isl_basic_map> BMap) {
+  auto FixedVal =
+      give(isl_basic_map_plain_get_val_if_fixed(BMap.keep(), isl_dim_out, 0));
+  return !FixedVal || isl_val_is_nan(FixedVal.keep());
+}
+
+/// Whether Map's first out dimension is neither constant nor piecewise
+/// constant.
+bool isVariableDim(NonowningIslPtr<isl_map> Map) {
+  // Returning isl_stat_error from the callback stops the enumeration at the
+  // first variable basic map and makes foreachEltWithBreak report the error.
+  auto Result =
+      foreachEltWithBreak(Map, [](IslPtr<isl_basic_map> BMap) -> isl_stat {
+        if (isVariableDim(BMap))
+          return isl_stat_error;
+        return isl_stat_ok;
+      });
+  return Result == isl_stat_error;
+}
+
+/// Whether UMap's first out dimension is no (piecewise) constant.
+bool isVariableDim(NonowningIslPtr<isl_union_map> UMap) {
+  auto Result =
+      foreachEltWithBreak(UMap, [](IslPtr<isl_map> Map) -> isl_stat {
+        if (isVariableDim(Map))
+          return isl_stat_error;
+        return isl_stat_ok;
+      });
+  return Result == isl_stat_error;
+}
+
+/// If @p PwAff maps to a constant, return said constant. If @p Max (resp.
+/// @p Min) is set, it can also be a piecewise constant and the maximum (resp.
+/// minimum) of the pieces' values is returned. Otherwise, return NaN.
+///
+/// Returns a null IslPtr if @p PwAff has no pieces.
+IslPtr<isl_val> getConstant(IslPtr<isl_pw_aff> PwAff, bool Max, bool Min) {
+  assert(!Max || !Min);
+  IslPtr<isl_val> Result;
+  foreachPieceWithBreak(
+      PwAff,
+      [=, &Result](IslPtr<isl_set> Set, IslPtr<isl_aff> Aff) -> isl_stat {
+        if (Result && isl_val_is_nan(Result.keep()))
+          return isl_stat_ok;
+
+        // TODO: If Min/Max, we can also determine a minimum/maximum value if
+        // Set is constant-bounded.
+        if (isl_aff_is_cst(Aff.keep()) != isl_bool_true) {
+          Result = give(isl_val_nan(Aff.getCtx()));
+          return isl_stat_error;
+        }
+
+        auto ThisVal = give(isl_aff_get_constant_val(Aff.keep()));
+        if (!Result) {
+          Result = ThisVal;
+          return isl_stat_ok;
+        }
+
+        if (isl_val_eq(Result.keep(), ThisVal.keep()) == isl_bool_true)
+          return isl_stat_ok;
+
+        // For an extremum, a piece that is not more extreme than the current
+        // candidate is simply ignored. This keeps the result independent of
+        // the order in which isl enumerates the pieces.
+        if (Max) {
+          if (isl_val_gt(ThisVal.keep(), Result.keep()) == isl_bool_true)
+            Result = ThisVal;
+          return isl_stat_ok;
+        }
+
+        if (Min) {
+          if (isl_val_lt(ThisVal.keep(), Result.keep()) == isl_bool_true)
+            Result = ThisVal;
+          return isl_stat_ok;
+        }
+
+        // Not compatible: the pieces have different constant values, but
+        // neither the maximum nor the minimum was requested.
+        Result = give(isl_val_nan(Aff.getCtx()));
+        return isl_stat_error;
+      });
+  return Result;
+}
+
+/// Compute @p UPwAff - @p Val, subtracting the constant from every piece.
+IslPtr<isl_union_pw_aff> subtract(IslPtr<isl_union_pw_aff> UPwAff,
+                                  IslPtr<isl_val> Val) {
+  if (isl_val_is_zero(Val.keep()))
+    return UPwAff;
+
+  auto Result =
+      give(isl_union_pw_aff_empty(isl_union_pw_aff_get_space(UPwAff.keep())));
+  // The callback runs synchronously, so Val and Result can be captured by
+  // reference instead of copying (and ref-counting) every IslPtr in scope.
+  foreachElt(UPwAff, [&Val, &Result](IslPtr<isl_pw_aff> PwAff) {
+    auto ValAff = give(isl_pw_aff_val_on_domain(
+        isl_set_universe(isl_space_domain(isl_pw_aff_get_space(PwAff.keep()))),
+        Val.copy()));
+    auto Subtracted = give(isl_pw_aff_sub(PwAff.take(), ValAff.take()));
+    Result = give(isl_union_pw_aff_union_add(
+        Result.take(), isl_union_pw_aff_from_pw_aff(Subtracted.take())));
+  });
+  return Result;
+}
+
+/// Compute @p UPwAff * @p Val, multiplying every piece by the constant.
+IslPtr<isl_union_pw_aff> multiply(IslPtr<isl_union_pw_aff> UPwAff,
+                                  IslPtr<isl_val> Val) {
+  if (isl_val_is_one(Val.keep()))
+    return UPwAff;
+
+  auto Result =
+      give(isl_union_pw_aff_empty(isl_union_pw_aff_get_space(UPwAff.keep())));
+  foreachElt(UPwAff, [&Val, &Result](IslPtr<isl_pw_aff> PwAff) {
+    auto ValAff = give(isl_pw_aff_val_on_domain(
+        isl_set_universe(isl_space_domain(isl_pw_aff_get_space(PwAff.keep()))),
+        Val.copy()));
+    auto Multiplied = give(isl_pw_aff_mul(PwAff.take(), ValAff.take()));
+    Result = give(isl_union_pw_aff_union_add(
+        Result.take(), isl_union_pw_aff_from_pw_aff(Multiplied.take())));
+  });
+  return Result;
+}
+
+/// Remove @p n dimensions from @p UMap's range, starting at @p first.
+///
+/// It is assumed that all maps in @p UMap have at least the necessary number
+/// of out dimensions.
+IslPtr<isl_union_map> scheduleProjectOut(NonowningIslPtr<isl_union_map> UMap,
+                                         unsigned first, unsigned n) {
+  // Nothing to remove. (isl_map_project_out would additionally reset the
+  // range tuple, which should have no effect on schedule ranges.)
+  if (n == 0)
+    return UMap;
+
+  auto Result = give(isl_union_map_empty(isl_union_map_get_space(UMap.keep())));
+  foreachElt(UMap, [first, n, &Result](IslPtr<isl_map> Map) {
+    auto Outprojected =
+        give(isl_map_project_out(Map.take(), isl_dim_out, first, n));
+    Result = give(isl_union_map_add_map(Result.take(), Outprojected.take()));
+  });
+  return Result;
+}
+
+/// Return the number of dimensions in the input map's range.
+///
+/// Because this function takes an isl_union_map, the out dimensions could be
+/// different. We return the maximum number in this case. However, a different
+/// number of dimensions is not supported by the other code in this file.
+size_t scheduleScatterDims(NonowningIslPtr<isl_union_map> Schedule) {
+  unsigned Dims = 0;
+  foreachElt(Schedule, [&Dims](IslPtr<isl_map> Map) {
+    Dims = std::max(Dims, isl_map_dim(Map.keep(), isl_dim_out));
+  });
+  return Dims;
+}
+
+/// Return the @p pos'th range dimension, converted to an isl_union_pw_aff.
+///
+/// Every map in @p UMap must have more than @p pos out dimensions.
+IslPtr<isl_union_pw_aff> scheduleExtractDimAff(IslPtr<isl_union_map> UMap,
+                                               unsigned pos) {
+  auto SingleUMap =
+      give(isl_union_map_empty(isl_union_map_get_space(UMap.keep())));
+  foreachElt(UMap, [pos, &SingleUMap](IslPtr<isl_map> Map) {
+    auto MapDims = isl_map_dim(Map.keep(), isl_dim_out);
+    // Otherwise MapDims - pos - 1 below would wrap around.
+    assert(MapDims > pos && "Not enough out dimensions");
+    auto SingleMap = give(isl_map_project_out(Map.take(), isl_dim_out, 0, pos));
+    SingleMap = give(isl_map_project_out(SingleMap.take(), isl_dim_out, 1,
+                                         MapDims - pos - 1));
+    SingleUMap =
+        give(isl_union_map_add_map(SingleUMap.take(), SingleMap.take()));
+  });
+
+  auto UAff = give(isl_union_pw_multi_aff_from_union_map(SingleUMap.take()));
+  auto FirstMAff =
+      give(isl_multi_union_pw_aff_from_union_pw_multi_aff(UAff.take()));
+  return give(isl_multi_union_pw_aff_get_union_pw_aff(FirstMAff.keep(), 0));
+}
+
+/// Flatten a sequence-like first dimension.
+///
+/// A sequence-like scatter dimension is constant, or has only small
+/// variation, typically the result of ordering a sequence of different
+/// statements. An example would be:
+///   { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
+/// to schedule all instances of Stmt_A before any instance of Stmt_B.
+///
+/// To flatten, first begin with an offset of zero. Then determine the lowest
+/// possible value of the dimension, call it "i" [In the example we start at 0].
+/// Considering only instances with that value, determine the extent of the
+/// next dimension.
+/// Let l_X(i) and u_X(i) be its minimum (lower bound) and maximum (upper
+/// bound) value. Add those instances as "Offset + X - l_X(i)" to the new
+/// schedule, then add "u_X(i) - l_X(i) + 1" to Offset and remove all
+/// i-instances from the old schedule. Repeat with the remaining lowest value
+/// i' until there are no instances in the old schedule left.
+/// The example schedule would be transformed to:
+///   { Stmt_A[] -> [X - l_X, ...];
+///     Stmt_B[] -> [u_X - l_X + 1 + Y - l_Y, ...] }
+///
+/// @return The flattened schedule, or nullptr if the dimension cannot be
+///         flattened this way.
+IslPtr<isl_union_map> tryFlattenSequence(IslPtr<isl_union_map> Schedule) {
+  auto IslCtx = Schedule.getCtx();
+  auto ScatterSet =
+      give(isl_set_from_union_set(isl_union_map_range(Schedule.copy())));
+
+  auto ParamSpace = give(isl_union_map_get_space(Schedule.keep()));
+  auto Dims = isl_set_dim(ScatterSet.keep(), isl_dim_set);
+  assert(Dims >= 2);
+
+  // An unbounded first dimension would make the loop below never terminate.
+  if (!isDimBoundedByConstant(ScatterSet, 0)) {
+    DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
+    return nullptr;
+  }
+
+  auto AllDomains = give(isl_union_map_domain(Schedule.copy()));
+  auto AllDomainsToNull =
+      give(isl_union_pw_multi_aff_from_domain(AllDomains.take()));
+
+  auto NewSchedule = give(isl_union_map_empty(ParamSpace.copy()));
+  auto Counter = give(isl_pw_aff_zero_on_domain(isl_local_space_from_space(
+      isl_space_set_from_params(ParamSpace.copy()))));
+
+  while (!isl_set_is_empty(ScatterSet.keep())) {
+    DEBUG(dbgs() << "Next counter:\n " << Counter << "\n");
+    DEBUG(dbgs() << "Remaining scatter set:\n " << ScatterSet << "\n");
+    auto ThisSet =
+        give(isl_set_project_out(ScatterSet.copy(), isl_dim_set, 1, Dims - 1));
+    auto ThisFirst = give(isl_set_lexmin(ThisSet.take()));
+    auto ScatterFirst =
+        give(isl_set_add_dims(ThisFirst.take(), isl_dim_set, Dims - 1));
+
+    auto SubSchedule = give(isl_union_map_intersect_range(
+        Schedule.copy(), isl_union_set_from_set(ScatterFirst.copy())));
+    SubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
+    SubSchedule = flattenSchedule(std::move(SubSchedule));
+
+    auto SubDims = scheduleScatterDims(SubSchedule);
+    auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
+    auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
+    auto RemainingSubSchedule =
+        scheduleProjectOut(std::move(SubSchedule), 0, 1);
+
+    auto FirstSubScatter = give(
+        isl_set_from_union_set(isl_union_map_range(FirstSubSchedule.take())));
+    DEBUG(dbgs() << "Next step in sequence is:\n " << FirstSubScatter << "\n");
+
+    if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
+      DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
+      return nullptr;
+    }
+
+    auto FirstSubScatterMap = give(isl_map_from_range(FirstSubScatter.take()));
+
+    // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
+    // 'none'. It doesn't match with any space including a 0-dimensional
+    // anonymous tuple. Hence isl_map_dim_min/max is used on a map with an
+    // empty domain instead.
+    // Interestingly, one can create such a set using
+    // isl_set_universe(ParamSpace). Bug?
+    auto PartMin = give(isl_map_dim_min(FirstSubScatterMap.copy(), 0));
+    auto PartMax = give(isl_map_dim_max(FirstSubScatterMap.take(), 0));
+    auto One = give(isl_pw_aff_val_on_domain(
+        isl_set_universe(isl_space_set_from_params(ParamSpace.copy())),
+        isl_val_one(IslCtx)));
+    auto PartLen = give(isl_pw_aff_add(
+        isl_pw_aff_add(PartMax.take(), isl_pw_aff_neg(PartMin.copy())),
+        One.take()));
+
+    auto AllPartMin = give(isl_union_pw_aff_pullback_union_pw_multi_aff(
+        isl_union_pw_aff_from_pw_aff(PartMin.take()), AllDomainsToNull.copy()));
+    auto FirstScheduleAffNormalized =
+        give(isl_union_pw_aff_sub(FirstScheduleAff.take(), AllPartMin.take()));
+    auto AllCounter = give(isl_union_pw_aff_pullback_union_pw_multi_aff(
+        isl_union_pw_aff_from_pw_aff(Counter.copy()), AllDomainsToNull.copy()));
+    auto FirstScheduleAffWithOffset = give(isl_union_pw_aff_add(
+        FirstScheduleAffNormalized.take(), AllCounter.take()));
+
+    auto ScheduleWithOffset = give(isl_union_map_flat_range_product(
+        isl_union_map_from_union_pw_aff(FirstScheduleAffWithOffset.take()),
+        RemainingSubSchedule.take()));
+    NewSchedule = give(
+        isl_union_map_union(NewSchedule.take(), ScheduleWithOffset.take()));
+
+    ScatterSet = give(isl_set_subtract(ScatterSet.take(), ScatterFirst.take()));
+    Counter = give(isl_pw_aff_add(Counter.take(), PartLen.take()));
+  }
+
+  DEBUG(dbgs() << "Sequence-flatten result is:\n " << NewSchedule << "\n");
+  return NewSchedule;
+}
+
+/// Flatten a loop-like first dimension.
+///
+/// A loop-like dimension is one that depends on a variable (usually a loop's
+/// induction variable). Let the input schedule look like this:
+///   { Stmt[i] -> [i, X, ...] }
+///
+/// To flatten, we determine the largest extent of X which may not depend on the
+/// actual value of i. Let l_X() be the smallest possible value of X and u_X()
+/// its largest value. Then, construct a new schedule
+///   { Stmt[i] -> [i * (u_X() - l_X() + 1) + X - l_X(), ...] }
+///
+/// @return The flattened schedule, or nullptr if the extent of X could not be
+///         bounded by constants.
+IslPtr<isl_union_map> tryFlattenLoop(IslPtr<isl_union_map> Schedule) {
+  assert(scheduleScatterDims(Schedule) >= 2);
+
+  auto Remaining = scheduleProjectOut(Schedule, 0, 1);
+  auto SubSchedule = flattenSchedule(std::move(Remaining));
+  auto SubDims = scheduleScatterDims(SubSchedule);
+
+  auto SubExtent =
+      give(isl_set_from_union_set(isl_union_map_range(SubSchedule.copy())));
+  auto SubExtentDims = isl_set_dim(SubExtent.keep(), isl_dim_param);
+  SubExtent = give(
+      isl_set_project_out(SubExtent.take(), isl_dim_param, 0, SubExtentDims));
+  SubExtent =
+      give(isl_set_project_out(SubExtent.take(), isl_dim_set, 1, SubDims - 1));
+
+  if (!isDimBoundedByConstant(SubExtent, 0)) {
+    DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
+    return nullptr;
+  }
+
+  auto Min = give(isl_set_dim_min(SubExtent.copy(), 0));
+  DEBUG(dbgs() << "Min bound:\n " << Min << "\n");
+  auto MinVal = getConstant(Min, false, true);
+  auto Max = give(isl_set_dim_max(SubExtent.take(), 0));
+  DEBUG(dbgs() << "Max bound:\n " << Max << "\n");
+  auto MaxVal = getConstant(Max, true, false);
+
+  if (!MinVal || !MaxVal || isl_val_is_nan(MinVal.keep()) ||
+      isl_val_is_nan(MaxVal.keep())) {
+    DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
+    return nullptr;
+  }
+
+  auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
+  auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
+
+  auto LenVal =
+      give(isl_val_add_ui(isl_val_sub(MaxVal.take(), MinVal.copy()), 1));
+  auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);
+
+  // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
+  // subtract it)
+  auto FirstAff = scheduleExtractDimAff(Schedule, 0);
+  auto Offset = multiply(FirstAff, LenVal);
+  auto Index = give(
+      isl_union_pw_aff_add(FirstSubScheduleNormalized.take(), Offset.take()));
+  auto IndexMap = give(isl_union_map_from_union_pw_aff(Index.take()));
+
+  auto Result = give(isl_union_map_flat_range_product(
+      IndexMap.take(), RemainingSubSchedule.take()));
+  DEBUG(dbgs() << "Loop-flatten result is:\n " << Result << "\n");
+  return Result;
+}
+} // anonymous namespace
+
+IslPtr<isl_union_map> polly::flattenSchedule(IslPtr<isl_union_map> Schedule) {
+  auto Dims = scheduleScatterDims(Schedule);
+  DEBUG(dbgs() << "Recursive schedule to process:\n " << Schedule << "\n");
+
+  // Base case; no dimensions left (TODO: Add one dimension?) or already
+  // one-dimensional.
+  if (Dims <= 1)
+    return Schedule;
+
+  // Fixed dimension; no need to preserve variabledness.
+  bool SequenceTried = false;
+  if (!isVariableDim(Schedule)) {
+    DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
+    auto NewScheduleSequence = tryFlattenSequence(Schedule);
+    if (NewScheduleSequence)
+      return NewScheduleSequence;
+    SequenceTried = true;
+  }
+
+  // Constant stride
+  DEBUG(dbgs() << "Try loop flattening\n");
+  auto NewScheduleLoop = tryFlattenLoop(Schedule);
+  if (NewScheduleLoop)
+    return NewScheduleLoop;
+
+  // Try again without loop condition (may blow up the number of pieces!!).
+  // tryFlattenSequence only depends on Schedule, so repeating an attempt that
+  // already failed above cannot succeed and would only redo expensive work.
+  if (!SequenceTried) {
+    DEBUG(dbgs() << "Try sequence flattening again\n");
+    auto NewScheduleSequence = tryFlattenSequence(Schedule);
+    if (NewScheduleSequence)
+      return NewScheduleSequence;
+  }
+
+  // Cannot flatten
+  return Schedule;
+}
Index: lib/Transform/FlattenSchedule.cpp
===================================================================
--- /dev/null
+++ lib/Transform/FlattenSchedule.cpp
@@ -0,0 +1,108 @@
+//===------ FlattenSchedule.cpp --------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Try to reduce the number of scatter dimensions. Useful to make isl_union_map
+// schedules more understandable. This is only intended for debugging and
+// unittests, not for production use.
+//
+//===----------------------------------------------------------------------===//
+
+#include "polly/FlattenSchedule.h"
+#include "polly/FlattenAlgo.h"
+#include "polly/ScopInfo.h"
+#include "polly/ScopPass.h"
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "polly-flatten-schedule"
+
+using namespace polly;
+using namespace llvm;
+
+namespace {
+
+/// Print a schedule to @p OS.
+///
+/// Prints the schedule of each statement on a new line, indented by
+/// @p indent spaces.
+void printSchedule(raw_ostream &OS, NonowningIslPtr<isl_union_map> Schedule,
+                   int indent) {
+  foreachElt(Schedule, [&OS, indent](IslPtr<isl_map> Map) {
+    OS.indent(indent) << Map << "\n";
+  });
+}
+
+/// Flatten the schedule stored in a polly::Scop.
+class FlattenSchedule : public ScopPass { +private: + FlattenSchedule(const FlattenSchedule &) = delete; + const FlattenSchedule &operator=(const FlattenSchedule &) = delete; + + std::shared_ptr IslCtx; // Keeps the isl_ctx alive for as long as OldSchedule. + IslPtr OldSchedule; // Schedule before flattening; reported by printScop(). + +public: + static char ID; + FlattenSchedule() : ScopPass(ID) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredTransitive(); + AU.setPreservesAll(); + } + + bool runOnScop(Scop &S) override { + // Keep a reference to isl_ctx to ensure that it is not freed before we free + // OldSchedule. + IslCtx = S.getSharedIslCtx(); + + DEBUG(dbgs() << "Going to flatten old schedule:\n"); + OldSchedule = give(S.getSchedule()); + DEBUG(printSchedule(dbgs(), OldSchedule, 2)); + + auto Domains = give(S.getDomains()); + auto RestrictedOldSchedule = give( + isl_union_map_intersect_domain(OldSchedule.copy(), Domains.copy())); + DEBUG(dbgs() << "Old schedule with domains:\n"); + DEBUG(printSchedule(dbgs(), RestrictedOldSchedule, 2)); + + auto NewSchedule = flattenSchedule(RestrictedOldSchedule); + + DEBUG(dbgs() << "Flattened new schedule:\n"); + DEBUG(printSchedule(dbgs(), NewSchedule, 2)); + + NewSchedule = + give(isl_union_map_gist_domain(NewSchedule.take(), Domains.take())); + DEBUG(dbgs() << "Gisted, flattened new schedule:\n"); + DEBUG(printSchedule(dbgs(), NewSchedule, 2)); + + S.setSchedule(NewSchedule.take()); + return false; + } + + void printScop(raw_ostream &OS, Scop &S) const override { + OS << "Schedule before flattening {\n"; + printSchedule(OS, OldSchedule, 4); + OS << "}\n\n"; + + OS << "Schedule after flattening {\n"; + printSchedule(OS, give(S.getSchedule()), 4); + OS << "}\n"; + } + + void releaseMemory() override { + OldSchedule = nullptr; // Must be released before the isl_ctx reference below. + IslCtx.reset(); + } +}; + +char FlattenSchedule::ID; +} // anonymous namespace + +Pass *polly::createFlattenSchedulePass() { return new FlattenSchedule(); } + +INITIALIZE_PASS_BEGIN(FlattenSchedule, "polly-flatten-schedule",
+ "Polly - Flatten schedule", false, false) +INITIALIZE_PASS_END(FlattenSchedule, "polly-flatten-schedule", + "Polly - Flatten schedule", false, false) Index: test/FlattenSchedule/gemm.ll =================================================================== --- /dev/null +++ test/FlattenSchedule/gemm.ll @@ -0,0 +1,98 @@ +; RUN: opt %loadPolly -polly-flatten-schedule -analyze < %s | FileCheck %s +; +; dgemm kernel +; C := alpha*A*B + beta*C +; C[ni][nj] +; A[ni][nk] +; B[nk][nj] + +target datalayout = "e-m:x-p:32:32-i64:64-f80:32-n8:16:32-a:0:32-S32" + +define void @gemm(i32 %ni, i32 %nj, i32 %nk, double %alpha, double %beta, double* noalias nonnull %C, double* noalias nonnull %A, double* noalias nonnull %B) { +entry: + br label %ni.for + +ni.for: + %i = phi i32 [0, %entry], [%i.inc, %ni.inc] + %i.cmp = icmp slt i32 %i, 3 ; %ni + br i1 %i.cmp, label %nj.for, label %ni.exit + + nj.for: + %j = phi i32 [0, %ni.for], [%j.inc, %nj.inc] + %j.cmp = icmp slt i32 %j, 7 ; %nj + br i1 %j.cmp, label %nj_beta, label %nj.exit + + nj_beta: + %c_stride = mul nsw i32 %i, 7; %nj + %c_idx_i = getelementptr inbounds double, double* %C, i32 %c_stride + %c_idx_ij = getelementptr inbounds double, double* %c_idx_i, i32 %j + + ; C[i][j] *= beta + %c = load double, double* %c_idx_ij + %c_beta = fmul double %c, %beta + store double %c_beta, double* %c_idx_ij + + br label %nk.for + + nk.for: + %k = phi i32 [0, %nj_beta], [%k.inc, %nk.inc] + %k.cmp = icmp slt i32 %k, 3 ; %nk + br i1 %k.cmp, label %nk_alpha, label %nk.exit + + nk_alpha: + %a_stride = mul nsw i32 %i, 3; %nk + %a_idx_i = getelementptr inbounds double, double* %A, i32 %a_stride + %a_idx_ik = getelementptr inbounds double, double* %a_idx_i, i32 %k + + %b_stride = mul nsw i32 %k, 7; %nj + %b_idx_k = getelementptr inbounds double, double* %B, i32 %b_stride + %b_idx_kj = getelementptr inbounds double, double* %b_idx_k, i32 %j + + ; C[i][j] += alpha * A[i][k] * B[k][j] + %a = load double, double* %a_idx_ik + %b = load double, double* %b_idx_kj +
+ %beta_c = load double, double* %c_idx_ij + + %alpha_a = fmul double %a, %alpha + %alpha_a_b = fmul double %alpha_a, %b + %beta_c_alpha_a_b = fadd double %beta_c, %alpha_a_b + + store double %beta_c_alpha_a_b, double* %c_idx_ij + + br label %nk.inc + + nk.inc: + %k.inc = add nuw nsw i32 %k, 1 + br label %nk.for + + nk.exit: + ; store double %c, double* %c_idx_ij + br label %nj.inc + + nj.inc: + %j.inc = add nuw nsw i32 %j, 1 + br label %nj.for + + nj.exit: + br label %ni.inc + +ni.inc: + %i.inc = add nuw nsw i32 %i, 1 + br label %ni.for + +ni.exit: + br label %return + +return: + ret void +} + + +; CHECK: Schedule before flattening { +; CHECK-NEXT: { Stmt_nk_alpha[i0, i1, i2] -> [i0, i1, 1, i2] } +; CHECK-NEXT: { Stmt_nj_beta[i0, i1] -> [i0, i1, 0, 0] } +; CHECK-NEXT: } +; CHECK: Schedule after flattening { +; CHECK-NEXT: { Stmt_nj_beta[i0, i1] -> [28i0 + 4i1] } +; CHECK-NEXT: { Stmt_nk_alpha[i0, i1, i2] -> [1 + 28i0 + 4i1 + i2] } +; CHECK-NEXT: } Index: test/update_check.py =================================================================== --- test/update_check.py +++ test/update_check.py @@ -170,6 +170,18 @@ continue elif line.startswith("New access function '"): yield {'NewAccessFunction'} + elif line == 'Schedule before flattening {': + while True: + yield {'ScheduleBeforeFlattening'} + if line == '}': + break + line = i.__next__() + elif line == 'Schedule after flattening {': + while True: + yield {'ScheduleAfterFlattening'} + if line == '}': + break + line = i.__next__() else: yield set() line = i.__next__() Index: unittests/CMakeLists.txt =================================================================== --- unittests/CMakeLists.txt +++ unittests/CMakeLists.txt @@ -20,3 +20,4 @@ endfunction() add_subdirectory(Isl) +add_subdirectory(Flatten) Index: unittests/Flatten/CMakeLists.txt =================================================================== --- /dev/null +++ unittests/Flatten/CMakeLists.txt @@ -0,0 +1,3 @@ +add_polly_unittest(FlattenTests +
+ FlattenTest.cpp + ) Index: unittests/Flatten/FlattenTest.cpp =================================================================== --- /dev/null +++ unittests/Flatten/FlattenTest.cpp @@ -0,0 +1,70 @@ +//===- FlattenTest.cpp ----------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "polly/FlattenAlgo.h" +#include "polly/Support/GICHelper.h" +#include "gtest/gtest.h" +#include "isl/union_map.h" + +using namespace llvm; +using namespace polly; + +namespace { + +/// Flatten a schedule and compare to the expected result. +/// +/// @param ScheduleStr The schedule to flatten as string. +/// @param ExpectedStr The expected result as string. +/// +/// @return Whether the flattened schedule is the same as the expected schedule. +bool checkFlatten(const char *ScheduleStr, const char *ExpectedStr) { + auto *Ctx = isl_ctx_alloc(); + isl_bool Success = isl_bool_error; + + { + auto Schedule = give(isl_union_map_read_from_str(Ctx, ScheduleStr)); + auto Expected = give(isl_union_map_read_from_str(Ctx, ExpectedStr)); + + auto Result = flattenSchedule(std::move(Schedule)); + Success = isl_union_map_is_equal(Result.keep(), Expected.keep()); + } + + isl_ctx_free(Ctx); + return Success == isl_bool_true; +} + +TEST(Flatten, FlattenTrivial) { + EXPECT_TRUE(checkFlatten("{ A[] -> [0] }", "{ A[] -> [0] }")); + EXPECT_TRUE(checkFlatten("{ A[i] -> [i, 0] : 0 <= i < 10 }", + "{ A[i] -> [i] : 0 <= i < 10 }")); + EXPECT_TRUE(checkFlatten("{ A[i] -> [0, i] : 0 <= i < 10 }", + "{ A[i] -> [i] : 0 <= i < 10 }")); +} + +TEST(Flatten, FlattenSequence) { + EXPECT_TRUE(checkFlatten( + "[n] -> { A[i] -> [0, i] : 0 <= i < n; B[i] -> [1, i] : 0 <= i < n }", + "[n] -> { A[i] -> [i] : 0 <= i < n; B[i] -> [n + i] : 0 <= i < n }")); + + EXPECT_TRUE(checkFlatten( + "{ A[i] ->
[0, i] : 0 <= i < 10; B[i] -> [1, i] : 0 <= i < 10 }", + "{ A[i] -> [i] : 0 <= i < 10; B[i] -> [10 + i] : 0 <= i < 10 }")); +} + +TEST(Flatten, FlattenLoop) { + EXPECT_TRUE(checkFlatten( + "[n] -> { A[i] -> [i, 0] : 0 <= i < n; B[i] -> [i, 1] : 0 <= i < n }", + "[n] -> { A[i] -> [2i] : 0 <= i < n; B[i] -> [2i + 1] : 0 <= i < n }")); + + EXPECT_TRUE(checkFlatten( + "{ A[i] -> [i, 0] : 0 <= i < 10; B[i] -> [i, 1] : 0 <= i < 10 }", + "{ A[i] -> [2i] : 0 <= i < 10; B[i] -> [2i + 1] : 0 <= i < 10 }")); +} + +} // anonymous namespace