Index: polly/trunk/include/polly/ScheduleOptimizer.h
===================================================================
--- polly/trunk/include/polly/ScheduleOptimizer.h
+++ polly/trunk/include/polly/ScheduleOptimizer.h
@@ -147,8 +147,45 @@
   /// - if vectorization is enabled
   ///
   /// @param Node The schedule node to (possibly) optimize.
-  /// @param User A pointer to forward some use information (currently unused).
+  /// @param User A pointer to forward some use information
+  ///        (currently unused).
   static isl_schedule_node *optimizeBand(isl_schedule_node *Node, void *User);
+
+  /// @brief Apply additional optimizations on the bands in the schedule tree.
+  ///
+  /// We apply the following
+  /// transformations:
+  ///
+  /// - Tile the band
+  /// - Prevectorize the schedule of the band (or the point loop in case of
+  ///   tiling).
+  ///   - if vectorization is enabled
+  ///
+  /// @param Node The schedule node to (possibly) optimize.
+  /// @param User A pointer to forward some use information
+  ///        (currently unused).
+  static isl_schedule_node *standardBandOpts(__isl_take isl_schedule_node *Node,
+                                             void *User);
+
+  /// @brief Check if this node contains a partial schedule that could
+  ///        probably be optimized with analytical modeling.
+  ///
+  /// isMatrMultPattern tries to determine whether the following conditions
+  /// are true:
+  /// 1. the partial schedule contains only one statement.
+  /// 2. there are exactly three input dimensions.
+  /// 3. all array memory accesses of the statement will have stride 0 or 1,
+  ///    if we interchange loops (switch the variable used in the inner loop
+  ///    to the outer loop).
+  /// 4. all memory accesses of the statement, except the last one, are
+  ///    reads, and the last one is an array write.
+  /// 5. all subscripts of the last memory access of the statement don't
+  ///    contain the variable used in the inner loop.
+  /// If this is the case, we could try to use an approach that is similar to
+  /// the one used to get close-to-peak performance of matrix multiplications.
+  ///
+  /// @param Node The node to check.
+  static bool isMatrMultPattern(__isl_keep isl_schedule_node *Node);
 };
 
 #endif
Index: polly/trunk/lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- polly/trunk/lib/Transform/ScheduleOptimizer.cpp
+++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp
@@ -166,6 +166,11 @@
                          cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
                          cl::cat(PollyCategory));
 
+static cl::opt<bool>
+    PMBasedOpts("polly-pattern-matching-based-opts",
+                cl::desc("Perform optimizations based on pattern matching"),
+                cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 /// @brief Create an isl_union_set, which describes the isolate option based
 /// on IsoalteDomain.
 ///
@@ -359,11 +364,8 @@
 }
 
 __isl_give isl_schedule_node *
-ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
-                                    void *User) {
-  if (!isTileableBandNode(Node))
-    return Node;
-
+ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node,
+                                        void *User) {
   if (FirstLevelTiling)
     Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
                     FirstLevelDefaultTileSize);
@@ -396,6 +398,110 @@
   return Node;
 }
 
+/// @brief Check whether output dimensions of the map rely on the specified
+///        input dimension.
+///
+/// @param IslMap The isl map to be considered; it is consumed (__isl_take).
+/// @param DimNum The number of an input dimension to be checked.
+static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
+  // Project out input dimension DimNum, re-insert it unconstrained and restore
+  // the tuple ids: the map does not constrain DimNum iff nothing changed.
+  auto *Relaxed =
+      isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
+  Relaxed = isl_map_insert_dims(Relaxed, isl_dim_in, DimNum, 1);
+  auto *InId = isl_map_get_tuple_id(IslMap, isl_dim_in);
+  Relaxed = isl_map_set_tuple_id(Relaxed, isl_dim_in, InId);
+  auto *OutId = isl_map_get_tuple_id(IslMap, isl_dim_out);
+  Relaxed = isl_map_set_tuple_id(Relaxed, isl_dim_out, OutId);
+  // Treat isl_bool_error conservatively as "used" (no false pattern match).
+  auto IsEqual = isl_map_is_equal(Relaxed, IslMap);
+  isl_map_free(Relaxed);
+  isl_map_free(IslMap);
+  return IsEqual != isl_bool_true;
+}
+
+/// @brief Check if the SCoP statement could probably be optimized with
+///        analytical modeling.
+///
+/// containsMatrMult tries to determine whether the following conditions
+/// are true:
+/// 1. all array memory accesses of the statement will have stride 0 or 1,
+///    if we interchange loops (switch the variable used in the inner
+///    loop to the outer loop).
+/// 2. all memory accesses of the statement, except the last one, are
+///    reads, and the last one is an array write.
+/// 3. all subscripts of the last memory access of the statement don't
+///    contain the variable used in the inner loop.
+///
+/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
+///        to check.
+static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
+  auto *InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
+  auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
+  isl_id_free(InputDimsId);
+  if (Stmt->size() <= 1)
+    return false;
+  // Check every access *except* the last one, leaving MemA on the last one.
+  auto MemA = Stmt->begin();
+  for (unsigned i = 0, e = Stmt->size() - 1; i < e; ++i, ++MemA)
+    if (!(*MemA)->isRead() ||
+        ((*MemA)->isArrayKind() &&
+         !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
+           (*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
+      return false;
+  if (!(*MemA)->isWrite() || !(*MemA)->isArrayKind() ||
+      !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) ||
+        (*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
+    return false;
+  auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
+  return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
+}
+
+/// @brief Circular shift of the output dimensions of an integer map: the
+///        last output dimension becomes the first one.
+///
+/// @param IslMap The isl map to be modified; it is consumed (__isl_take).
+static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
+  auto DimNum = isl_map_dim(IslMap, isl_dim_out);
+  if (DimNum < 2) // Nothing to rotate; also keeps DimNum - 1 from wrapping.
+    return IslMap;
+  auto *InputDimsId = isl_map_get_tuple_id(IslMap, isl_dim_in);
+  IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, DimNum - 1, 1);
+  IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
+  return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
+}
+
+bool ScheduleTreeOptimizer::isMatrMultPattern(
+    __isl_keep isl_schedule_node *Node) {
+  auto *PartialSchedule =
+      isl_schedule_node_band_get_partial_schedule_union_map(Node);
+  if (isl_union_map_n_map(PartialSchedule) != 1) {
+    isl_union_map_free(PartialSchedule);
+    return false;
+  }
+  auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
+  if (isl_map_dim(NewPartialSchedule, isl_dim_in) != 3) {
+    isl_map_free(NewPartialSchedule);
+    return false;
+  }
+  NewPartialSchedule = circularShiftOutputDims(NewPartialSchedule);
+  bool IsMatrMult = containsMatrMult(NewPartialSchedule);
+  isl_map_free(NewPartialSchedule);
+  return IsMatrMult;
+}
+
+__isl_give isl_schedule_node *
+ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
+                                    void *User) {
+  if (!isTileableBandNode(Node))
+    return Node;
+
+  if (PMBasedOpts && isMatrMultPattern(Node))
+    DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
+
+  return standardBandOpts(Node, User);
+}
+
 __isl_give isl_schedule *
 ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
   isl_schedule_node *Root = isl_schedule_get_root(Schedule);
Index: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
===================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
+++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts.ll
@@ -0,0 +1,67 @@
+; RUN: opt %loadPolly -polly-opt-isl -debug < %s 2>&1 | FileCheck %s
+; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1 | FileCheck %s --check-prefix=PATTERN-MATCHING-OPTS
+; REQUIRES: asserts
+; The kernel_gemm function below is reported as a matrix multiplication only
+; if -polly-pattern-matching-based-opts is enabled.
+; CHECK-NOT: The matrix multiplication pattern was detected
+; PATTERN-MATCHING-OPTS: The matrix multiplication pattern was detected
+
+define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
+bb:
+  br label %bb8
+
+bb8:                                              ; preds = %bb39, %bb
+  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
+  %tmp9 = icmp slt i32 %tmp, 1056
+  br i1 %tmp9, label %bb10, label %bb41
+
+bb10:                                             ; preds = %bb8
+  br label %bb11
+
+bb11:                                             ; preds = %bb37, %bb10
+  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
+  %tmp13 = icmp slt i32 %tmp12, 1056
+  br i1 %tmp13, label %bb14, label %bb39
+
+bb14:                                             ; preds = %bb11
+  %tmp15 = sext i32 %tmp12 to i64
+  %tmp16 = sext i32 %tmp to i64
+  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
+  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
+  %tmp19 = load double, double* %tmp18, align 8
+  %tmp20 = fmul double %tmp19, %arg4
+  store double %tmp20, double* %tmp18, align 8
+  br label %bb21
+
+bb21:                                             ; preds = %bb24, %bb14
+  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
+  %tmp23 = icmp slt i32 %tmp22, 1024
+  br i1 %tmp23, label %bb24, label %bb37
+
+bb24:                                             ; preds = %bb21
+  %tmp25 = sext i32 %tmp22 to i64
+  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
+  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
+  %tmp28 = load double, double* %tmp27, align 8
+  %tmp29 = fmul double %arg3, %tmp28
+  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
+  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
+  %tmp32 = load double, double* %tmp31, align 8
+  %tmp33 = fmul double %tmp29, %tmp32
+  %tmp34 = load double, double* %tmp18, align 8
+  %tmp35 = fadd double %tmp34, %tmp33
+  store double %tmp35, double* %tmp18, align 8
+  %tmp36 = add nsw i32 %tmp22, 1
+  br label %bb21
+
+bb37:                                             ; preds = %bb21
+  %tmp38 = add nsw i32 %tmp12, 1
+  br label %bb11
+
+bb39:                                             ; preds = %bb11
+  %tmp40 = add nsw i32 %tmp, 1
+  br label %bb8
+
+bb41:                                             ; preds = %bb8
+  ret void
+}
Index: polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
===================================================================
--- polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
+++ polly/trunk/test/ScheduleOptimizer/pattern-matching-based-opts_2.ll
@@ -0,0 +1,65 @@
+; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -debug < %s 2>&1 | FileCheck %s
+; REQUIRES: asserts
+; Same kernel as in pattern-matching-based-opts.ll, except that the %tmp12
+; loop is incremented by 2; the pattern is expected not to be detected.
+; CHECK-NOT: The matrix multiplication pattern was detected
+
+define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) {
+bb:
+  br label %bb8
+
+bb8:                                              ; preds = %bb39, %bb
+  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
+  %tmp9 = icmp slt i32 %tmp, 1056
+  br i1 %tmp9, label %bb10, label %bb41
+
+bb10:                                             ; preds = %bb8
+  br label %bb11
+
+bb11:                                             ; preds = %bb37, %bb10
+  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
+  %tmp13 = icmp slt i32 %tmp12, 1056
+  br i1 %tmp13, label %bb14, label %bb39
+
+bb14:                                             ; preds = %bb11
+  %tmp15 = sext i32 %tmp12 to i64
+  %tmp16 = sext i32 %tmp to i64
+  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
+  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
+  %tmp19 = load double, double* %tmp18, align 8
+  %tmp20 = fmul double %tmp19, %arg4
+  store double %tmp20, double* %tmp18, align 8
+  br label %bb21
+
+bb21:                                             ; preds = %bb24, %bb14
+  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
+  %tmp23 = icmp slt i32 %tmp22, 1024
+  br i1 %tmp23, label %bb24, label %bb37
+
+bb24:                                             ; preds = %bb21
+  %tmp25 = sext i32 %tmp22 to i64
+  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
+  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
+  %tmp28 = load double, double* %tmp27, align 8
+  %tmp29 = fmul double %arg3, %tmp28
+  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
+  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
+  %tmp32 = load double, double* %tmp31, align 8
+  %tmp33 = fmul double %tmp29, %tmp32
+  %tmp34 = load double, double* %tmp18, align 8
+  %tmp35 = fadd double %tmp34, %tmp33
+  store double %tmp35, double* %tmp18, align 8
+  %tmp36 = add nsw i32 %tmp22, 1
+  br label %bb21
+
+bb37:                                             ; preds = %bb21
+  %tmp38 = add nsw i32 %tmp12, 2
+  br label %bb11
+
+bb39:                                             ; preds = %bb11
+  %tmp40 = add nsw i32 %tmp, 1
+  br label %bb8
+
+bb41:                                             ; preds = %bb8
+  ret void
+}