Index: include/polly/ScheduleOptimizer.h
===================================================================
--- include/polly/ScheduleOptimizer.h
+++ include/polly/ScheduleOptimizer.h
@@ -13,6 +13,7 @@
 #define POLLY_SCHEDULE_OPTIMIZER_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "isl/ctx.h"
 
 struct isl_schedule;
@@ -37,9 +38,12 @@
   ///
   /// @param Schedule The schedule object the transformations will be applied
   ///                 to.
+  /// @param TTI      Target Transform Info. If null, the matrix
+  ///                 multiplication pattern optimization is skipped.
   /// @returns        The transformed schedule.
   static __isl_give isl_schedule *
-  optimizeSchedule(__isl_take isl_schedule *Schedule);
+  optimizeSchedule(__isl_take isl_schedule *Schedule,
+                   const llvm::TargetTransformInfo *TTI = nullptr);
 
   /// @brief Apply schedule tree transformations.
   ///
@@ -51,9 +55,12 @@
   ///  - Prevectorization
   ///
   /// @param Node The schedule object post-transformations will be applied to.
+  /// @param TTI  Target Transform Info. If null, the matrix multiplication
+  ///             pattern optimization is skipped.
   /// @returns    The transformed schedule.
   static __isl_give isl_schedule_node *
-  optimizeScheduleNode(__isl_take isl_schedule_node *Node);
+  optimizeScheduleNode(__isl_take isl_schedule_node *Node,
+                       const llvm::TargetTransformInfo *TTI = nullptr);
 
   /// @brief Decide if the @p NewSchedule is profitable for @p S.
   ///
@@ -90,6 +97,27 @@
   tileNode(__isl_take isl_schedule_node *Node, const char *Identifier,
            llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
 
+  /// @brief Tile a schedule node and unroll point loops.
+  ///
+  /// @param Node            The node to register tile.
+  /// @param TileSizes       A vector of tile sizes that should be used for
+  ///                        tiling.
+  /// @param DefaultTileSize A default tile size that is used for dimensions
+  ///                        that are not covered by the TileSizes vector.
+  static __isl_give isl_schedule_node *
+  registerTileNode(__isl_take isl_schedule_node *Node,
+                   llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
+
+  /// @brief Apply an algorithm, which is similar to the one that is used
+  ///        to get close-to-peak performance of matrix multiplication.
+  ///
+  /// @param Node The node that contains a band to be optimized.
+  /// @param TTI  Target Transform Info; must not be null.
+  /// @return     Modified isl_schedule_node.
+  static __isl_give isl_schedule_node *
+  matrMultPatternOpt(__isl_take isl_schedule_node *Node,
+                     const llvm::TargetTransformInfo *TTI);
+
   /// @brief Check if this node is a band node we want to tile.
   ///
   /// We look for innermost band nodes where individual dimensions are marked as
Index: lib/Transform/ScheduleOptimizer.cpp
===================================================================
--- lib/Transform/ScheduleOptimizer.cpp
+++ lib/Transform/ScheduleOptimizer.cpp
@@ -53,6 +53,7 @@
 #include "polly/Options.h"
 #include "polly/ScopInfo.h"
 #include "polly/Support/GICHelper.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Support/Debug.h"
 #include "isl/aff.h"
 #include "isl/band.h"
@@ -120,6 +121,20 @@
                         cl::init(true), cl::ZeroOrMore,
                         cl::cat(PollyCategory));
 
+static cl::opt<int>
+    Lvfma("polly-lvfma",
+          cl::desc("The minimum number of cycles between the issuance of two "
+                   "dependent consecutive vector fused multiply-add "
+                   "instructions."),
+          cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));
+
+static cl::opt<int>
+    Nvfma("polly-nvfma",
+          cl::desc("The throughput of the processor floating-point arithmetic "
+                   "units expressed in the number of vector fused "
+                   "multiply-add instructions per clock cycle."),
+          cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 static cl::opt<int> FirstLevelDefaultTileSize(
     "polly-default-tile-size",
     cl::desc("The default tile size (if not enough were provided by"
@@ -336,6 +351,17 @@
   return Node;
 }
 
+__isl_give isl_schedule_node *
+ScheduleTreeOptimizer::registerTileNode(__isl_take isl_schedule_node *Node,
+                                        llvm::ArrayRef<int> TileSizes,
+                                        int DefaultTileSize) {
+  auto *Ctx = isl_schedule_node_get_ctx(Node);
+  Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);
+  Node = isl_schedule_node_band_set_ast_build_options(
+      Node, isl_union_set_read_from_str(Ctx, "{unroll[x]}"));
+  return Node;
+}
+
 bool ScheduleTreeOptimizer::isTileableBandNode(
     __isl_keep isl_schedule_node *Node) {
   if (isl_schedule_node_get_type(Node) != isl_schedule_node_band)
@@ -375,13 +401,8 @@
   Node = tileNode(Node, "2nd level tiling", SecondLevelTileSizes,
                   SecondLevelDefaultTileSize);
 
-  if (RegisterTiling) {
-    auto *Ctx = isl_schedule_node_get_ctx(Node);
-    Node = tileNode(Node, "Register tiling", RegisterTileSizes,
-                    RegisterDefaultTileSize);
-    Node = isl_schedule_node_band_set_ast_build_options(
-        Node, isl_union_set_read_from_str(Ctx, "{unroll[x]}"));
-  }
+  if (RegisterTiling)
+    Node = registerTileNode(Node, RegisterTileSizes, RegisterDefaultTileSize);
 
   if (PollyVectorizerChoice == VECTORIZER_NONE)
     return Node;
@@ -472,6 +493,23 @@
   return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
 }
 
+__isl_give isl_schedule_node *ScheduleTreeOptimizer::matrMultPatternOpt(
+    __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
+  assert(TTI && "The target transform info should be provided.");
+  // Register-tile the band with an Mr x Nr micro-kernel. Nvec is the number
+  // of double-precision floating-point numbers a vector register can hold;
+  // fall back to 2 if the register is narrower than 64 bits.
+  auto Nvec = TTI->getRegisterBitWidth(true) / 64;
+  if (Nvec == 0)
+    Nvec = 2;
+  int Nr = ceil(sqrt(Nvec * Lvfma * Nvfma) / Nvec) * Nvec;
+  // Divide in floating point: integer division would truncate before ceil().
+  int Mr = ceil(static_cast<double>(Nvec * Lvfma * Nvfma) / Nr);
+  std::vector<int> MicroKernelParams{Mr, Nr};
+  Node = registerTileNode(Node, MicroKernelParams, 1);
+  return Node;
+}
+
 bool ScheduleTreeOptimizer::isMatrMultPattern(
     __isl_keep isl_schedule_node *Node) {
   auto *PartialSchedule =
@@ -502,16 +540,21 @@
   if (!isTileableBandNode(Node))
     return Node;
 
-  if (PMBasedOpts && isMatrMultPattern(Node))
+  if (PMBasedOpts && User && isMatrMultPattern(Node)) {
     DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
+    const llvm::TargetTransformInfo *TTI;
+    TTI = static_cast<const llvm::TargetTransformInfo *>(User);
+    Node = matrMultPatternOpt(Node, TTI);
+  }
 
   return standardBandOpts(Node, User);
 }
 
 __isl_give isl_schedule *
-ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
+ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule,
+                                        const llvm::TargetTransformInfo *TTI) {
   isl_schedule_node *Root = isl_schedule_get_root(Schedule);
-  Root = optimizeScheduleNode(Root);
+  Root = optimizeScheduleNode(Root, TTI);
   isl_schedule_free(Schedule);
   auto S = isl_schedule_node_get_schedule(Root);
   isl_schedule_node_free(Root);
@@ -519,8 +562,9 @@
 }
 
 __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeScheduleNode(
-    __isl_take isl_schedule_node *Node) {
-  Node = isl_schedule_node_map_descendant_bottom_up(Node, optimizeBand, NULL);
+    __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
+  Node = isl_schedule_node_map_descendant_bottom_up(
+      Node, optimizeBand, const_cast<void *>(static_cast<const void *>(TTI)));
   return Node;
 }
 
@@ -708,7 +752,10 @@
     isl_printer_free(P);
   });
 
-  isl_schedule *NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule);
+  Function &F = S.getFunction();
+  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  isl_schedule *NewSchedule =
+      ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
   isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
 
   if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
@@ -746,6 +793,7 @@
 void IslScheduleOptimizer::getAnalysisUsage(AnalysisUsage &AU) const {
   ScopPass::getAnalysisUsage(AU);
   AU.addRequired<DependenceInfo>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
 }
 
 Pass *polly::createIslScheduleOptimizerPass() {
@@ -756,5 +804,6 @@
                       "Polly - Optimize schedule of SCoP", false, false);
 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
 INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass);
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass);
 INITIALIZE_PASS_END(IslScheduleOptimizer, "polly-opt-isl",
                     "Polly - Optimize schedule of SCoP", false, false)
Index: test/ScheduleOptimizer/pattern-matching-based-opts_3.ll
===================================================================
--- /dev/null
+++ test/ScheduleOptimizer/pattern-matching-based-opts_3.ll
@@ -0,0 +1,128 @@
+; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -polly-nvfma=1 -polly-lvfma=8 -analyze -polly-ast < %s 2>&1 | FileCheck %s
+;
+;    /* C := alpha*A*B + beta*C */
+;    for (i = 0; i < _PB_NI; i++)
+;      for (j = 0; j < _PB_NJ; j++)
+;        {
+;          C[i][j] *= beta;
+;          for (k = 0; k < _PB_NK; ++k)
+;            C[i][j] += alpha * A[i][k] * B[k][j];
+;        }
+;
+; CHECK:    {
+; CHECK:      // 1st level tiling - Tiles
+; CHECK:      for (int c0 = 0; c0 <= 32; c0 += 1)
+; CHECK:        for (int c1 = 0; c1 <= 32; c1 += 1) {
+; CHECK:          // 1st level tiling - Points
+; CHECK:          for (int c2 = 0; c2 <= 31; c2 += 1)
+; CHECK:            for (int c3 = 0; c3 <= 31; c3 += 1)
+; CHECK:              Stmt_bb14(32 * c0 + c2, 32 * c1 + c3);
+; CHECK:        }
+; CHECK:      // Register tiling - Tiles
+; CHECK:      for (int c0 = 0; c0 <= 263; c0 += 1)
+; CHECK:        for (int c1 = 0; c1 <= 131; c1 += 1)
+; CHECK:          for (int c2 = 0; c2 <= 1023; c2 += 1) {
+; CHECK:            // Register tiling - Points
+; CHECK:            // 1st level tiling - Tiles
+; CHECK:            // 1st level tiling - Points
+; CHECK:            {
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0, 8 * c1 + 7, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 1, 8 * c1 + 7, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 2, 8 * c1 + 7, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 1, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 2, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 3, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 4, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 5, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 6, c2);
+; CHECK:              Stmt_bb24(4 * c0 + 3, 8 * c1 + 7, c2);
+; CHECK:            }
+; CHECK:          }
+; CHECK:    }
+;
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-unknown"
+
+define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) #0 {
+bb:
+  br label %bb8
+
+bb8:                                              ; preds = %bb39, %bb
+  %tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
+  %tmp9 = icmp slt i32 %tmp, 1056
+  br i1 %tmp9, label %bb10, label %bb41
+
+bb10:                                             ; preds = %bb8
+  br label %bb11
+
+bb11:                                             ; preds = %bb37, %bb10
+  %tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
+  %tmp13 = icmp slt i32 %tmp12, 1056
+  br i1 %tmp13, label %bb14, label %bb39
+
+bb14:                                             ; preds = %bb11
+  %tmp15 = sext i32 %tmp12 to i64
+  %tmp16 = sext i32 %tmp to i64
+  %tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
+  %tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
+  %tmp19 = load double, double* %tmp18, align 8
+  %tmp20 = fmul double %tmp19, %arg4
+  store double %tmp20, double* %tmp18, align 8
+  br label %bb21
+
+bb21:                                             ; preds = %bb24, %bb14
+  %tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
+  %tmp23 = icmp slt i32 %tmp22, 1024
+  br i1 %tmp23, label %bb24, label %bb37
+
+bb24:                                             ; preds = %bb21
+  %tmp25 = sext i32 %tmp22 to i64
+  %tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
+  %tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
+  %tmp28 = load double, double* %tmp27, align 8
+  %tmp29 = fmul double %arg3, %tmp28
+  %tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
+  %tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
+  %tmp32 = load double, double* %tmp31, align 8
+  %tmp33 = fmul double %tmp29, %tmp32
+  %tmp34 = load double, double* %tmp18, align 8
+  %tmp35 = fadd double %tmp34, %tmp33
+  store double %tmp35, double* %tmp18, align 8
+  %tmp36 = add nsw i32 %tmp22, 1
+  br label %bb21
+
+bb37:                                             ; preds = %bb21
+  %tmp38 = add nsw i32 %tmp12, 1
+  br label %bb11
+
+bb39:                                             ; preds = %bb11
+  %tmp40 = add nsw i32 %tmp, 1
+  br label %bb8
+
+bb41:                                             ; preds = %bb8
+  ret void
+}
+
+attributes #0 = { nounwind uwtable "target-cpu"="x86-64" "target-features"="+aes,+avx,+cmov,+cx16,+fxsr,+mmx,+pclmul,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt" }