Index: polly/trunk/include/polly/ScheduleOptimizer.h =================================================================== --- polly/trunk/include/polly/ScheduleOptimizer.h +++ polly/trunk/include/polly/ScheduleOptimizer.h @@ -64,6 +64,17 @@ static bool isProfitableSchedule(polly::Scop &S, __isl_keep isl_union_map *NewSchedule); + /// @brief Isolate a set of partial tile prefixes. + /// + /// This set should ensure that it contains only partial tile prefixes that + /// have exactly VectorWidth iterations. + /// + /// @param Node A schedule node band, which is a parent of a band node + /// that contains a vector loop. + /// @return Modified isl_schedule_node. + static __isl_give isl_schedule_node * + isolateFullPartialTiles(__isl_take isl_schedule_node *Node, int VectorWidth); + private: /// @brief Tile a schedule node. /// Index: polly/trunk/lib/Transform/ScheduleOptimizer.cpp =================================================================== --- polly/trunk/lib/Transform/ScheduleOptimizer.cpp +++ polly/trunk/lib/Transform/ScheduleOptimizer.cpp @@ -160,6 +160,104 @@ cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated, cl::cat(PollyCategory)); +/// @brief Create an isl_union_set, which describes the isolate option based +/// on IsolateDomain. +/// +/// @param IsolateDomain An isl_set whose last dimension is the only one that +/// should belong to the current band node. 
+static __isl_give isl_union_set * +getIsolateOptions(__isl_take isl_set *IsolateDomain) { + auto Dims = isl_set_dim(IsolateDomain, isl_dim_set); + auto *IsolateRelation = isl_map_from_domain(IsolateDomain); + IsolateRelation = isl_map_move_dims(IsolateRelation, isl_dim_out, 0, + isl_dim_in, Dims - 1, 1); + auto *IsolateOption = isl_map_wrap(IsolateRelation); + auto *Id = isl_id_alloc(isl_set_get_ctx(IsolateOption), "isolate", NULL); + return isl_union_set_from_set(isl_set_set_tuple_id(IsolateOption, Id)); +} + +/// @brief Create an isl_union_set, which describes the atomic option for the +/// dimension of the current node. +/// +/// It may help to reduce the size of generated code. +/// +/// @param Ctx An isl_ctx, which is used to create the isl_union_set. +static __isl_give isl_union_set *getAtomicOptions(__isl_take isl_ctx *Ctx) { + auto *Space = isl_space_set_alloc(Ctx, 0, 1); + auto *AtomicOption = isl_set_universe(Space); + auto *Id = isl_id_alloc(Ctx, "atomic", NULL); + return isl_union_set_from_set(isl_set_set_tuple_id(AtomicOption, Id)); +} + +/// @brief Make the last dimension of Set to take values +/// from 0 to VectorWidth - 1. +/// +/// @param Set A set, which should be modified. +/// @param VectorWidth A parameter, which determines the constraint. 
+static __isl_give isl_set *addExtentConstraints(__isl_take isl_set *Set, + int VectorWidth) { + auto Dims = isl_set_dim(Set, isl_dim_set); + auto Space = isl_set_get_space(Set); + auto *LocalSpace = isl_local_space_from_space(Space); + auto *ExtConstr = + isl_constraint_alloc_inequality(isl_local_space_copy(LocalSpace)); + ExtConstr = isl_constraint_set_constant_si(ExtConstr, 0); + ExtConstr = + isl_constraint_set_coefficient_si(ExtConstr, isl_dim_set, Dims - 1, 1); + Set = isl_set_add_constraint(Set, ExtConstr); + ExtConstr = isl_constraint_alloc_inequality(LocalSpace); + ExtConstr = isl_constraint_set_constant_si(ExtConstr, VectorWidth - 1); + ExtConstr = + isl_constraint_set_coefficient_si(ExtConstr, isl_dim_set, Dims - 1, -1); + return isl_set_add_constraint(Set, ExtConstr); +} + +/// @brief Build the desired set of partial tile prefixes. +/// +/// We build a set of partial tile prefixes, which are prefixes of the vector +/// loop that have exactly VectorWidth iterations. +/// +/// 1. Get all prefixes of the vector loop. +/// 2. Extend it to a set, which has exactly VectorWidth iterations for +/// any prefix from the set that was built on the previous step. +/// 3. Subtract loop domain from it, project out the vector loop dimension and +/// get a set of prefixes, which do not have exactly VectorWidth iterations. +/// 4. Subtract it from all prefixes of the vector loop and get the desired +/// set. +/// +/// @param ScheduleRange A range of a map, which describes a prefix schedule +/// relation. 
+static __isl_give isl_set * +getPartialTilePrefixes(__isl_take isl_set *ScheduleRange, int VectorWidth) { + auto Dims = isl_set_dim(ScheduleRange, isl_dim_set); + auto *LoopPrefixes = isl_set_project_out(isl_set_copy(ScheduleRange), + isl_dim_set, Dims - 1, 1); + auto *ExtentPrefixes = + isl_set_add_dims(isl_set_copy(LoopPrefixes), isl_dim_set, 1); + ExtentPrefixes = addExtentConstraints(ExtentPrefixes, VectorWidth); + auto *BadPrefixes = isl_set_subtract(ExtentPrefixes, ScheduleRange); + BadPrefixes = isl_set_project_out(BadPrefixes, isl_dim_set, Dims - 1, 1); + return isl_set_subtract(LoopPrefixes, BadPrefixes); +} + +__isl_give isl_schedule_node *ScheduleTreeOptimizer::isolateFullPartialTiles( + __isl_take isl_schedule_node *Node, int VectorWidth) { + assert(isl_schedule_node_get_type(Node) == isl_schedule_node_band); + Node = isl_schedule_node_child(Node, 0); + Node = isl_schedule_node_child(Node, 0); + auto *SchedRelUMap = isl_schedule_node_get_prefix_schedule_relation(Node); + auto *ScheduleRelation = isl_map_from_union_map(SchedRelUMap); + auto *ScheduleRange = isl_map_range(ScheduleRelation); + auto *IsolateDomain = getPartialTilePrefixes(ScheduleRange, VectorWidth); + auto *AtomicOption = getAtomicOptions(isl_set_get_ctx(IsolateDomain)); + auto *IsolateOption = getIsolateOptions(IsolateDomain); + Node = isl_schedule_node_parent(Node); + Node = isl_schedule_node_parent(Node); + auto *Options = isl_union_set_union(IsolateOption, AtomicOption); + Node = isl_schedule_node_band_set_ast_build_options(Node, Options); + return Node; +} + __isl_give isl_schedule_node * ScheduleTreeOptimizer::prevectSchedBand(__isl_take isl_schedule_node *Node, unsigned DimToVectorize, @@ -183,6 +281,7 @@ Sizes = isl_multi_val_set_val(Sizes, 0, isl_val_int_from_si(Ctx, VectorWidth)); Node = isl_schedule_node_band_tile(Node, Sizes); + Node = isolateFullPartialTiles(Node, VectorWidth); Node = isl_schedule_node_child(Node, 0); // Make sure the "trivially vectorizable loop" is not 
unrolled. Otherwise, // we will have troubles to match it in the backend. Index: polly/trunk/test/ScheduleOptimizer/full_partial_tile_separation.ll =================================================================== --- polly/trunk/test/ScheduleOptimizer/full_partial_tile_separation.ll +++ polly/trunk/test/ScheduleOptimizer/full_partial_tile_separation.ll @@ -0,0 +1,92 @@ +; RUN: opt -S %loadPolly -polly-vectorizer=stripmine -polly-opt-isl -polly-ast -analyze < %s | FileCheck %s + +; CHECK: // 1st level tiling - Tiles +; CHECK: #pragma known-parallel +; CHECK: for (int c0 = 0; c0 <= floord(ni - 1, 32); c0 += 1) +; CHECK: for (int c1 = 0; c1 <= floord(nj - 1, 32); c1 += 1) +; CHECK: for (int c2 = 0; c2 <= floord(nk - 1, 32); c2 += 1) { +; CHECK: // 1st level tiling - Points +; CHECK: for (int c3 = 0; c3 <= min(31, ni - 32 * c0 - 1); c3 += 1) { +; CHECK: for (int c4 = 0; c4 <= min(7, -8 * c1 + nj / 4 - 1); c4 += 1) +; CHECK: for (int c5 = 0; c5 <= min(31, nk - 32 * c2 - 1); c5 += 1) +; CHECK: #pragma simd +; CHECK: for (int c6 = 0; c6 <= 3; c6 += 1) +; CHECK: Stmt_for_body_6(32 * c0 + c3, 32 * c1 + 4 * c4 + c6, 32 * c2 + c5); +; CHECK: if (nj >= 32 * c1 + 4 && 32 * c1 + 31 >= nj) { +; CHECK: for (int c5 = 0; c5 <= min(31, nk - 32 * c2 - 1); c5 += 1) +; CHECK: #pragma simd +; CHECK: for (int c6 = 0; c6 < nj % 4; c6 += 1) +; CHECK: Stmt_for_body_6(32 * c0 + c3, -((nj - 1) % 4) + nj + c6 - 1, 32 * c2 + c5); +; CHECK: } else if (32 * c1 + 3 >= nj) +; CHECK: for (int c5 = 0; c5 <= min(31, nk - 32 * c2 - 1); c5 += 1) +; CHECK: #pragma simd +; CHECK: for (int c6 = 0; c6 < nj - 32 * c1; c6 += 1) +; CHECK: Stmt_for_body_6(32 * c0 + c3, 32 * c1 + c6, 32 * c2 + c5); +; CHECK: } +; CHECK: } + +; Function Attrs: nounwind uwtable +define void @kernel_gemm(i32 %ni, i32 %nj, i32 %nk, double %alpha, double %beta, [1024 x double]* %C, [1024 x double]* %A, [1024 x double]* %B) #0 { +entry: + %cmp.27 = icmp sgt i32 %ni, 0 + br i1 %cmp.27, label %for.cond.1.preheader.lr.ph, label 
%for.end.22 + +for.cond.1.preheader.lr.ph: ; preds = %entry + br label %for.cond.1.preheader + +for.cond.1.preheader: ; preds = %for.cond.1.preheader.lr.ph, %for.inc.20 + %indvars.iv33 = phi i64 [ 0, %for.cond.1.preheader.lr.ph ], [ %indvars.iv.next34, %for.inc.20 ] + %cmp2.25 = icmp sgt i32 %nj, 0 + br i1 %cmp2.25, label %for.cond.4.preheader.lr.ph, label %for.inc.20 + +for.cond.4.preheader.lr.ph: ; preds = %for.cond.1.preheader + br label %for.cond.4.preheader + +for.cond.4.preheader: ; preds = %for.cond.4.preheader.lr.ph, %for.inc.17 + %indvars.iv29 = phi i64 [ 0, %for.cond.4.preheader.lr.ph ], [ %indvars.iv.next30, %for.inc.17 ] + %cmp5.23 = icmp sgt i32 %nk, 0 + br i1 %cmp5.23, label %for.body.6.lr.ph, label %for.inc.17 + +for.body.6.lr.ph: ; preds = %for.cond.4.preheader + br label %for.body.6 + +for.body.6: ; preds = %for.body.6.lr.ph, %for.body.6 + %indvars.iv = phi i64 [ 0, %for.body.6.lr.ph ], [ %indvars.iv.next, %for.body.6 ] + %arrayidx8 = getelementptr inbounds [1024 x double], [1024 x double]* %A, i64 %indvars.iv33, i64 %indvars.iv + %0 = load double, double* %arrayidx8, align 8 + %arrayidx12 = getelementptr inbounds [1024 x double], [1024 x double]* %B, i64 %indvars.iv, i64 %indvars.iv29 + %1 = load double, double* %arrayidx12, align 8 + %mul = fmul double %0, %1 + %arrayidx16 = getelementptr inbounds [1024 x double], [1024 x double]* %C, i64 %indvars.iv33, i64 %indvars.iv29 + %2 = load double, double* %arrayidx16, align 8 + %add = fadd double %2, %mul + store double %add, double* %arrayidx16, align 8 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %lftr.wideiv = trunc i64 %indvars.iv.next to i32 + %exitcond = icmp ne i32 %lftr.wideiv, %nk + br i1 %exitcond, label %for.body.6, label %for.cond.4.for.inc.17_crit_edge + +for.cond.4.for.inc.17_crit_edge: ; preds = %for.body.6 + br label %for.inc.17 + +for.inc.17: ; preds = %for.cond.4.for.inc.17_crit_edge, %for.cond.4.preheader + %indvars.iv.next30 = add nuw nsw i64 %indvars.iv29, 1 + 
%lftr.wideiv31 = trunc i64 %indvars.iv.next30 to i32 + %exitcond32 = icmp ne i32 %lftr.wideiv31, %nj + br i1 %exitcond32, label %for.cond.4.preheader, label %for.cond.1.for.inc.20_crit_edge + +for.cond.1.for.inc.20_crit_edge: ; preds = %for.inc.17 + br label %for.inc.20 + +for.inc.20: ; preds = %for.cond.1.for.inc.20_crit_edge, %for.cond.1.preheader + %indvars.iv.next34 = add nuw nsw i64 %indvars.iv33, 1 + %lftr.wideiv35 = trunc i64 %indvars.iv.next34 to i32 + %exitcond36 = icmp ne i32 %lftr.wideiv35, %ni + br i1 %exitcond36, label %for.cond.1.preheader, label %for.cond.for.end.22_crit_edge + +for.cond.for.end.22_crit_edge: ; preds = %for.inc.20 + br label %for.end.22 + +for.end.22: ; preds = %for.cond.for.end.22_crit_edge, %entry + ret void +}