Index: lib/Transform/ScheduleOptimizer.cpp =================================================================== --- lib/Transform/ScheduleOptimizer.cpp +++ lib/Transform/ScheduleOptimizer.cpp @@ -238,14 +238,18 @@ /// Create an isl_union_set, which describes the isolate option based on /// IsoalteDomain. /// -/// @param IsolateDomain An isl_set whose last dimension is the only one that -/// should belong to the current band node. +/// @param IsolateDomain An isl_set whose @p OutDimsNum last dimensions should +/// belong to the current band node. +/// @param OutDimsNum A number of dimensions that should belong to +/// the current band node. static __isl_give isl_union_set * -getIsolateOptions(__isl_take isl_set *IsolateDomain) { +getIsolateOptions(__isl_take isl_set *IsolateDomain, unsigned OutDimsNum) { auto Dims = isl_set_dim(IsolateDomain, isl_dim_set); + assert(OutDimsNum <= Dims && "..."); auto *IsolateRelation = isl_map_from_domain(IsolateDomain); - IsolateRelation = isl_map_move_dims(IsolateRelation, isl_dim_out, 0, - isl_dim_in, Dims - 1, 1); + IsolateRelation = + isl_map_move_dims(IsolateRelation, isl_dim_out, 0, isl_dim_in, + Dims - OutDimsNum, OutDimsNum); auto *IsolateOption = isl_map_wrap(IsolateRelation); auto *Id = isl_id_alloc(isl_set_get_ctx(IsolateOption), "isolate", nullptr); return isl_union_set_from_set(isl_set_set_tuple_id(IsolateOption, Id)); @@ -264,6 +268,24 @@ return isl_union_set_from_set(isl_set_set_tuple_id(AtomicOption, Id)); } +/// Create an isl_union_set, which describes the option of the form +/// [isolate[] -> unroll[x]]. +/// +/// +/// @param Ctx An isl_ctx, which is used to create the isl_union_set. 
+static __isl_give isl_union_set * +getUnrollIsolatedSetOptions(__isl_take isl_ctx *Ctx) { + auto *Space = isl_space_alloc(Ctx, 0, 0, 1); + auto *UnrollIsolatedSetOption = isl_map_universe(Space); + auto *DimInId = isl_id_alloc(Ctx, "isolate", nullptr); + auto *DimOutId = isl_id_alloc(Ctx, "unroll", nullptr); + UnrollIsolatedSetOption = + isl_map_set_tuple_id(UnrollIsolatedSetOption, isl_dim_in, DimInId); + UnrollIsolatedSetOption = + isl_map_set_tuple_id(UnrollIsolatedSetOption, isl_dim_out, DimOutId); + return isl_union_set_from_set(isl_map_wrap(UnrollIsolatedSetOption)); +} + /// Make the last dimension of Set to take values from 0 to VectorWidth - 1. /// /// @param Set A set, which should be modified. @@ -324,7 +346,7 @@ auto *ScheduleRange = isl_map_range(ScheduleRelation); auto *IsolateDomain = getPartialTilePrefixes(ScheduleRange, VectorWidth); auto *AtomicOption = getAtomicOptions(isl_set_get_ctx(IsolateDomain)); - auto *IsolateOption = getIsolateOptions(IsolateDomain); + auto *IsolateOption = getIsolateOptions(IsolateDomain, 1); Node = isl_schedule_node_parent(Node); Node = isl_schedule_node_parent(Node); auto *Options = isl_union_set_union(IsolateOption, AtomicOption); @@ -1103,6 +1125,47 @@ return MapOldIndVar; } +/// Isolate a set of partial tile prefixes and unroll the isolated part. +/// +/// The set should ensure that it contains only partial tile prefixes that have +/// exactly Mr x Nr iterations of the two innermost loops produced by +/// the optimization of the matrix multiplication. Mr and Nr are parameters of +/// the micro-kernel. +/// +/// This helps to auto-vectorize the unrolled innermost loops in case of +/// parametric bounds. +/// +/// @param Node The schedule node to be modified. +/// @param MicroKernelParams Parameters of the micro-kernel +/// to be taken into account. +/// @return The modified isl_schedule_node. 
+static __isl_give isl_schedule_node * +isolateAndUnrollMatMulInnerLoops(__isl_take isl_schedule_node *Node, + struct MicroKernelParamsTy MicroKernelParams) { + auto *Child = isl_schedule_node_get_child(Node, 0); + auto *UnMapOldIndVar = isl_schedule_node_get_prefix_schedule_relation(Child); + isl_schedule_node_free(Child); + auto *Prefix = isl_map_range(isl_map_from_union_map(UnMapOldIndVar)); + auto Dims = isl_set_dim(Prefix, isl_dim_set); + Prefix = isl_set_project_out(Prefix, isl_dim_set, Dims - 1, 1); + Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr); + Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr); + auto *IsolateOption = getIsolateOptions( + isl_set_add_dims(isl_set_copy(Prefix), isl_dim_set, 3), 3); + auto *Ctx = isl_schedule_node_get_ctx(Node); + auto *AtomicOption = getAtomicOptions(Ctx); + auto *Options = + isl_union_set_union(IsolateOption, isl_union_set_copy(AtomicOption)); + Options = isl_union_set_union(Options, getUnrollIsolatedSetOptions(Ctx)); + Node = isl_schedule_node_band_set_ast_build_options(Node, Options); + Node = isl_schedule_node_parent(isl_schedule_node_parent(Node)); + IsolateOption = getIsolateOptions(Prefix, 3); + Options = isl_union_set_union(IsolateOption, AtomicOption); + Node = isl_schedule_node_band_set_ast_build_options(Node, Options); + Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0); + return Node; +} + __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern( __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI, MatMulInfoTy &MMI) { @@ -1129,6 +1192,7 @@ Node, MicroKernelParams, MacroKernelParams); if (!MapOldIndVar) return Node; + Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams); return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams, MacroKernelParams, MMI); } @@ -1164,7 +1228,7 @@ MatMulInfoTy MMI; if (PMBasedOpts && User && isMatrMultPattern(Node, OAI->D, MMI)) { DEBUG(dbgs() << "The matrix 
multiplication pattern was detected\n"); - Node = optimizeMatMulPattern(Node, OAI->TTI, MMI); + return optimizeMatMulPattern(Node, OAI->TTI, MMI); } return standardBandOpts(Node, User); Index: test/ScheduleOptimizer/mat_mul_pattern_data_layout_2.ll =================================================================== --- test/ScheduleOptimizer/mat_mul_pattern_data_layout_2.ll +++ test/ScheduleOptimizer/mat_mul_pattern_data_layout_2.ll @@ -42,8 +42,6 @@ ; CHECK-NEXT: for (int c4 = 0; c4 <= 23; c4 += 1) ; CHECK-NEXT: for (int c5 = 0; c5 <= min(255, -256 * c1 + 1022); c5 += 1) { ; CHECK-NEXT: // Register tiling - Points -; CHECK-NEXT: // 1st level tiling - Tiles -; CHECK-NEXT: // 1st level tiling - Points ; CHECK-NEXT: { ; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 8 * c3, 256 * c1 + c5); ; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 8 * c3 + 1, 256 * c1 + c5); Index: test/ScheduleOptimizer/pattern-matching-based-opts_3.ll =================================================================== --- test/ScheduleOptimizer/pattern-matching-based-opts_3.ll +++ test/ScheduleOptimizer/pattern-matching-based-opts_3.ll @@ -38,8 +38,6 @@ ; CHECK-NEXT: for (int c1 = 0; c1 <= 263; c1 += 1) ; CHECK-NEXT: for (int c2 = 0; c2 <= 1023; c2 += 1) { ; CHECK-NEXT: // Register tiling - Points -; CHECK-NEXT: // 1st level tiling - Tiles -; CHECK-NEXT: // 1st level tiling - Points ; CHECK-NEXT: { ; CHECK-NEXT: Stmt_Copy_0(4 * c1, 8 * c0, c2); ; CHECK-NEXT: Stmt_Copy_0(4 * c1, 8 * c0 + 1, c2); @@ -101,8 +99,6 @@ ; EXTRACTION-OF-MACRO-KERNEL-NEXT: for (int c4 = 0; c4 <= 23; c4 += 1) ; EXTRACTION-OF-MACRO-KERNEL-NEXT: for (int c5 = 0; c5 <= 255; c5 += 1) { ; EXTRACTION-OF-MACRO-KERNEL-NEXT: // Register tiling - Points -; EXTRACTION-OF-MACRO-KERNEL-NEXT: // 1st level tiling - Tiles -; EXTRACTION-OF-MACRO-KERNEL-NEXT: // 1st level tiling - Points ; EXTRACTION-OF-MACRO-KERNEL-NEXT: { ; EXTRACTION-OF-MACRO-KERNEL-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 8 * c3, 256 * c1 + c5); ; EXTRACTION-OF-MACRO-KERNEL-NEXT: 
Stmt_Copy_0(96 * c2 + 4 * c4, 8 * c3 + 1, 256 * c1 + c5); Index: test/ScheduleOptimizer/pattern-matching-based-opts_5.ll =================================================================== --- /dev/null +++ test/ScheduleOptimizer/pattern-matching-based-opts_5.ll @@ -0,0 +1,167 @@ +; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true \ +; RUN: -polly-target-throughput-vector-fma=1 \ +; RUN: -polly-target-latency-vector-fma=8 \ +; RUN: -analyze -polly-ast -polly-target-1st-cache-level-associativity=8 \ +; RUN: -polly-target-2nd-cache-level-associativity=8 \ +; RUN: -polly-target-1st-cache-level-size=32768 \ +; RUN: -polly-target-vector-register-bitwidth=256 \ +; RUN: -polly-target-2nd-cache-level-size=262144 < %s \ +; RUN: | FileCheck %s +; +; /* C := A * B + C */ +; for (i = 0; i < _PB_NI; i++) +; for (j = 0; j < _PB_NJ; j++) +; for (k = 0; k < _PB_NK; ++k) +; C[i][j] += A[i][k] * B[k][j]; +; +; CHECK: if (ni >= 1) { +; CHECK-NEXT: // 1st level tiling - Tiles +; CHECK-NEXT: for (int c0 = 0; c0 <= floord(nj - 1, 2048); c0 += 1) +; CHECK-NEXT: for (int c1 = 0; c1 <= floord(nk - 1, 256); c1 += 1) { +; CHECK-NEXT: for (int c3 = 2048 * c0; c3 <= min(nj - 1, 2048 * c0 + 2047); c3 += 1) +; CHECK-NEXT: for (int c4 = 256 * c1; c4 <= min(nk - 1, 256 * c1 + 255); c4 += 1) +; CHECK-NEXT: CopyStmt_0(0, c3, c4); +; CHECK-NEXT: for (int c2 = 0; c2 <= floord(ni - 1, 96); c2 += 1) { +; CHECK-NEXT: if (c0 == 0) +; CHECK-NEXT: for (int c3 = 96 * c2; c3 <= min(ni - 1, 96 * c2 + 95); c3 += 1) +; CHECK-NEXT: for (int c5 = 256 * c1; c5 <= min(nk - 1, 256 * c1 + 255); c5 += 1) +; CHECK-NEXT: CopyStmt_1(c3, 0, c5); +; CHECK-NEXT: // 1st level tiling - Points +; CHECK-NEXT: // Register tiling - Tiles +; CHECK-NEXT: { +; CHECK-NEXT: if (ni >= 96 * c2 + 4) +; CHECK-NEXT: for (int c3 = 0; c3 <= min(255, -256 * c0 + nj / 8 - 1); c3 += 1) { +; CHECK-NEXT: for (int c4 = 0; c4 <= min(23, -24 * c2 + ni / 4 - 1); c4 += 1) +; CHECK-NEXT: for (int c5 = 0; c5 <= min(255, nk - 256 
* c1 - 1); c5 += 1) { +; CHECK-NEXT: // Register tiling - Points +; CHECK-NEXT: { +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3 + 1, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3 + 2, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3 + 3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3 + 4, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3 + 5, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3 + 6, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4, 2048 * c0 + 8 * c3 + 7, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3 + 1, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3 + 2, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3 + 3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3 + 4, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3 + 5, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3 + 6, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 1, 2048 * c0 + 8 * c3 + 7, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3 + 1, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3 + 2, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3 + 3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3 + 4, 256 * c1 + c5); +; CHECK-NEXT: 
Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3 + 5, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3 + 6, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 2, 2048 * c0 + 8 * c3 + 7, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3 + 1, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3 + 2, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3 + 3, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3 + 4, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3 + 5, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3 + 6, 256 * c1 + c5); +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + 3, 2048 * c0 + 8 * c3 + 7, 256 * c1 + c5); +; CHECK-NEXT: } +; CHECK-NEXT: } +; CHECK-NEXT: if (96 * c2 + 95 >= ni) +; CHECK-NEXT: for (int c5 = 0; c5 <= min(255, nk - 256 * c1 - 1); c5 += 1) { +; CHECK-NEXT: // Register tiling - Points +; CHECK-NEXT: for (int c6 = 0; c6 < ni % 4; c6 += 1) +; CHECK-NEXT: for (int c7 = 0; c7 <= 7; c7 += 1) +; CHECK-NEXT: Stmt_for_body6(-((ni + 4) % 4) + ni + c6, 2048 * c0 + 8 * c3 + c7, 256 * c1 + c5); +; CHECK-NEXT: } +; CHECK-NEXT: } +; CHECK-NEXT: if (96 * c2 + 3 >= ni || (2048 * c0 + 2047 >= nj && nj % 8 >= 1)) +; CHECK-NEXT: for (int c3 = 0; c3 <= min(255, -256 * c0 + (nj - 1) / 8); c3 += 1) +; CHECK-NEXT: if (96 * c2 + 3 >= ni || 2048 * c0 + 8 * c3 + 7 >= nj) +; CHECK-NEXT: for (int c4 = 0; c4 <= min(23, -24 * c2 + (ni - 1) / 4); c4 += 1) +; CHECK-NEXT: if ((ni >= 96 * c2 + 4 && 2048 * c0 + 8 * c3 + 7 >= nj) || 1) +; CHECK-NEXT: for (int c5 = 0; c5 <= min(255, nk - 256 * c1 - 1); c5 += 1) { +; CHECK-NEXT: // Register tiling - Points +; CHECK-NEXT: for (int c6 = 0; c6 <= min(3, ni - 96 * c2 - 4 * 
c4 - 1); c6 += 1) +; CHECK-NEXT: for (int c7 = 0; c7 <= min(7, nj - 2048 * c0 - 8 * c3 - 1); c7 += 1) +; CHECK-NEXT: Stmt_for_body6(96 * c2 + 4 * c4 + c6, 2048 * c0 + 8 * c3 + c7, 256 * c1 + c5); +; CHECK-NEXT: } +; CHECK-NEXT: } +; CHECK-NEXT: } +; CHECK-NEXT: } +; CHECK-NEXT: } +; +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-unknown" + +define internal void @kernel_gemm(i32 %ni, i32 %nj, i32 %nk, double %alpha, double %beta, [1024 x double]* %C, [1024 x double]* %A, [1024 x double]* %B) #0 { +entry: + br label %entry.split + +entry.split: ; preds = %entry + %cmp39 = icmp sgt i32 %ni, 0 + br i1 %cmp39, label %for.cond1.preheader.lr.ph, label %for.end22 + +for.cond1.preheader.lr.ph: ; preds = %entry.split + br label %for.cond1.preheader + +for.cond1.preheader: ; preds = %for.inc20, %for.cond1.preheader.lr.ph + %indvars.iv45 = phi i64 [ 0, %for.cond1.preheader.lr.ph ], [ %indvars.iv.next46, %for.inc20 ] + %cmp237 = icmp sgt i32 %nj, 0 + br i1 %cmp237, label %for.cond4.preheader.lr.ph, label %for.inc20 + +for.cond4.preheader.lr.ph: ; preds = %for.cond1.preheader + br label %for.cond4.preheader + +for.cond4.preheader: ; preds = %for.inc17, %for.cond4.preheader.lr.ph + %indvars.iv41 = phi i64 [ 0, %for.cond4.preheader.lr.ph ], [ %indvars.iv.next42, %for.inc17 ] + %cmp535 = icmp sgt i32 %nk, 0 + br i1 %cmp535, label %for.body6.lr.ph, label %for.inc17 + +for.body6.lr.ph: ; preds = %for.cond4.preheader + br label %for.body6 + +for.body6: ; preds = %for.body6, %for.body6.lr.ph + %indvars.iv = phi i64 [ 0, %for.body6.lr.ph ], [ %indvars.iv.next, %for.body6 ] + %arrayidx8 = getelementptr inbounds [1024 x double], [1024 x double]* %A, i64 %indvars.iv45, i64 %indvars.iv + %tmp = load double, double* %arrayidx8, align 8 + %arrayidx12 = getelementptr inbounds [1024 x double], [1024 x double]* %B, i64 %indvars.iv, i64 %indvars.iv41 + %tmp1 = load double, double* %arrayidx12, align 8 + %mul = fmul double %tmp, %tmp1 + %arrayidx16 = 
getelementptr inbounds [1024 x double], [1024 x double]* %C, i64 %indvars.iv45, i64 %indvars.iv41 + %tmp2 = load double, double* %arrayidx16, align 8 + %add = fadd double %tmp2, %mul + store double %add, double* %arrayidx16, align 8 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %wide.trip.count = zext i32 %nk to i64 + %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count + br i1 %exitcond, label %for.body6, label %for.cond4.for.inc17_crit_edge + +for.cond4.for.inc17_crit_edge: ; preds = %for.body6 + br label %for.inc17 + +for.inc17: ; preds = %for.cond4.for.inc17_crit_edge, %for.cond4.preheader + %indvars.iv.next42 = add nuw nsw i64 %indvars.iv41, 1 + %wide.trip.count43 = zext i32 %nj to i64 + %exitcond44 = icmp ne i64 %indvars.iv.next42, %wide.trip.count43 + br i1 %exitcond44, label %for.cond4.preheader, label %for.cond1.for.inc20_crit_edge + +for.cond1.for.inc20_crit_edge: ; preds = %for.inc17 + br label %for.inc20 + +for.inc20: ; preds = %for.cond1.for.inc20_crit_edge, %for.cond1.preheader + %indvars.iv.next46 = add nuw nsw i64 %indvars.iv45, 1 + %wide.trip.count47 = zext i32 %ni to i64 + %exitcond48 = icmp ne i64 %indvars.iv.next46, %wide.trip.count47 + br i1 %exitcond48, label %for.cond1.preheader, label %for.cond.for.end22_crit_edge + +for.cond.for.end22_crit_edge: ; preds = %for.inc20 + br label %for.end22 + +for.end22: ; preds = %for.cond.for.end22_crit_edge, %entry.split + ret void +} + +attributes #0 = { nounwind uwtable "target-cpu"="x86-64" "target-features"="+aes,+avx,+cmov,+cx16,+fxsr,+mmx,+pclmul,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt" }