Index: include/polly/ScheduleOptimizer.h
===================================================================
--- include/polly/ScheduleOptimizer.h
+++ include/polly/ScheduleOptimizer.h
@@ -150,4 +150,11 @@
   /// @param User A pointer to forward some use information (currently unused).
   static isl_schedule_node *optimizeBand(isl_schedule_node *Node, void *User);
+
+  /// @brief Apply the standard band optimizations (e.g. tiling) to @p Node.
+  ///
+  /// @param Node The band node to optimize.
+  /// @return The resulting schedule node.
+  static isl_schedule_node *
+  standardBandOpts(__isl_take isl_schedule_node *Node);
 };
 
Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -166,6 +166,11 @@
                       cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
                       cl::cat(PollyCategory));
 
+static cl::opt<bool>
+    PMBasedOpts("polly-pm-based-opts",
+                cl::desc("Perform optimizations based on pattern matching"),
+                cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 /// @brief Create an isl_union_set, which describes the isolate option based
 /// on IsoalteDomain.
 ///
@@ -359,11 +364,7 @@
 }
 
 __isl_give isl_schedule_node *
-ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
-                                    void *User) {
-  if (!isTileableBandNode(Node))
-    return Node;
-
+ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node) {
   if (FirstLevelTiling)
     Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
                     FirstLevelDefaultTileSize);
@@ -396,6 +397,142 @@
   return Node;
 }
 
+/// @brief Check whether input dimension @p DimNum of @p IslMap is used.
+///
+/// The dimension is projected out and an unconstrained input dimension is
+/// appended; if the result differs from @p IslMap, the map constrains the
+/// dimension. Because the new dimension is appended at the end, the check is
+/// only meaningful when @p DimNum is the last input dimension.
+///
+/// @param IslMap The map to inspect. Ownership is taken.
+/// @param DimNum The position of the input dimension to check.
+/// @return True if @p IslMap depends on input dimension @p DimNum.
+static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
+  auto *CheckedAccessRelation =
+      isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
+  CheckedAccessRelation =
+      isl_map_add_dims(CheckedAccessRelation, isl_dim_in, 1);
+  // Re-attach the original tuple ids so that the comparison below only
+  // reflects the constraints.
+  auto *InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
+  CheckedAccessRelation =
+      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_in, InputDimsId);
+  auto *OutputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_out);
+  CheckedAccessRelation =
+      isl_map_set_tuple_id(CheckedAccessRelation, isl_dim_out, OutputDimsId);
+  bool IsUsed = !isl_map_is_equal(CheckedAccessRelation, IslMap);
+  isl_map_free(CheckedAccessRelation);
+  isl_map_free(IslMap);
+  return IsUsed;
+}
+
+/// @brief Check whether the memory accesses of the statement scheduled by
+///        @p PartialSchedule match the matrix multiplication pattern.
+///
+/// All accesses except the last two must be reads, and array reads among them
+/// must satisfy isStrideOne or isStrideZero for @p PartialSchedule. The
+/// second-to-last access is skipped without being inspected (NOTE(review):
+/// presumably the read of the element that the last access writes - confirm).
+/// The last access must be an array write that satisfies isStrideOne or
+/// isStrideZero and whose access relation does not use the last input
+/// dimension of @p PartialSchedule.
+///
+/// @param PartialSchedule A single-statement schedule whose input tuple id
+///                        carries the ScopStmt as its user pointer.
+/// @return True if the pattern matches.
+static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
+  auto *InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
+  auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
+  isl_id_free(InputDimsId);
+  // At least two accesses are needed: the final write and the skipped access
+  // before it.
+  if (Stmt->size() <= 1)
+    return false;
+  auto MemA = Stmt->begin();
+  for (unsigned i = 0; i < Stmt->size() - 2 && MemA != Stmt->end();
+       i++, MemA++)
+    if (!(*MemA)->isRead() ||
+        ((*MemA)->isArrayKind() &&
+         !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
+           (*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
+      return false;
+  // Skip the second-to-last access and move to the last one.
+  MemA++;
+  if (!(*MemA)->isWrite() || !(*MemA)->isArrayKind() ||
+      !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
+        (*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
+    return false;
+  auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
+  return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
+}
+
+/// @brief Move output dimension @p SrcPos of @p IslMap to the front of the
+///        output tuple (for @p DstPos == 0).
+///
+/// NOTE(review): despite its name, this helper does not swap two input
+/// dimensions. The first isl_map_move_dims call moves output dimension
+/// @p SrcPos to input position 0; the second moves input dimension @p DstPos
+/// to output position 0. With @p DstPos == 0, as used in this patch, the net
+/// effect is that output dimension @p SrcPos becomes the first output
+/// dimension while the input dimensions are unchanged. Consider renaming.
+///
+/// @param IslMap The map to transform. Ownership is taken.
+/// @param DstPos The input position moved to output position 0 by the second
+///               move.
+/// @param SrcPos The output dimension moved to input position 0 by the first
+///               move.
+/// @return The transformed map, with the original input tuple id re-attached.
+static __isl_give isl_map *interchangeInputDims(__isl_take isl_map *IslMap,
+                                                unsigned DstPos,
+                                                unsigned SrcPos) {
+  auto *InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
+  IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, SrcPos, 1);
+  IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, DstPos, 1);
+  return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
+}
+
+/// @brief Check whether the band @p Node matches the matrix multiplication
+///        pattern.
+///
+/// Only partial schedules that consist of a single map with a
+/// three-dimensional input and at least three output dimensions are
+/// considered.
+///
+/// @param Node The band node to inspect.
+/// @return True if the matrix multiplication pattern was detected.
+static bool isMatrMultPattern(__isl_keep isl_schedule_node *Node) {
+  auto *PartialSchedule =
+      isl_schedule_node_band_get_partial_schedule_union_map(Node);
+  if (isl_union_map_n_map(PartialSchedule) != 1) {
+    isl_union_map_free(PartialSchedule);
+    return false;
+  }
+  auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
+  auto DimNum = isl_map_dim(NewPartialSchedule, isl_dim_in);
+  // interchangeInputDims below reads output dimension DimNum - 1, so the
+  // partial schedule needs at least DimNum output dimensions.
+  if (DimNum != 3 || isl_map_dim(NewPartialSchedule, isl_dim_out) < DimNum) {
+    isl_map_free(NewPartialSchedule);
+    return false;
+  }
+  NewPartialSchedule = interchangeInputDims(NewPartialSchedule, 0, DimNum - 1);
+  bool IsMatrMult = containsMatrMult(NewPartialSchedule);
+  isl_map_free(NewPartialSchedule);
+  return IsMatrMult;
+}
+
+__isl_give isl_schedule_node *
+ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
+                                    void *User) {
+  if (!isTileableBandNode(Node))
+    return Node;
+
+  if (PMBasedOpts && isMatrMultPattern(Node))
+    DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
+
+  return standardBandOpts(Node);
+}
+
 __isl_give isl_schedule *
 ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
   isl_schedule_node *Root = isl_schedule_get_root(Schedule);