Index: lib/CodeGen/IslAst.cpp
===================================================================
--- lib/CodeGen/IslAst.cpp
+++ lib/CodeGen/IslAst.cpp
@@ -197,6 +197,86 @@
   return true;
 }
 
+/// @brief Collect loop annotations from the original loops surrounding @p Body.
+///
+/// We look at all statements in @p Body and at all loops that formerly
+/// surrounded those statements, and aggregate the annotations of the loops
+/// that are involved in the __new__ innermost dimension.
+///
+/// @param Body The AST (sub)tree to inspect; ownership is taken and it is freed.
+static void collectLoopAnnotations(__isl_take isl_ast_node *Body) {
+  // Recurse into block and if nodes; annotations are extracted once a user
+  // node (a statement) is reached. Any other node type is unexpected here.
+  switch (isl_ast_node_get_type(Body)) {
+  case isl_ast_node_block: {
+    isl_ast_node_list *List = isl_ast_node_block_get_children(Body);
+    for (int i = 0, e = isl_ast_node_list_n_ast_node(List); i < e; ++i)
+      collectLoopAnnotations(isl_ast_node_list_get_ast_node(List, i));
+    isl_ast_node_list_free(List);
+    break;
+  }
+  case isl_ast_node_if: {
+    collectLoopAnnotations(isl_ast_node_if_get_then(Body));
+    if (isl_ast_node_if_has_else(Body))
+      collectLoopAnnotations(isl_ast_node_if_get_else(Body));
+    break;
+  }
+  case isl_ast_node_user: {
+    // The first operand of the user expression is an id whose user pointer
+    // is the ScopStmt executed by this node.
+    isl_ast_expr *UserExpr = isl_ast_node_user_get_expr(Body);
+    isl_ast_expr *Expr = isl_ast_expr_get_op_arg(UserExpr, 0);
+    isl_id *Id = isl_ast_expr_get_id(Expr);
+
+    ScopStmt *Stmt = (ScopStmt *)isl_id_get_user(Id);
+    assert(Stmt->getNumIterators() &&
+           "Unexpected statement without iterators found");
+
+    // Find the highest/innermost dimension which is not constant.
+    isl_pw_multi_aff *ScatPMA =
+        isl_pw_multi_aff_from_map(Stmt->getScattering());
+    unsigned Pos = isl_pw_multi_aff_dim(ScatPMA, isl_dim_out);
+    assert(Pos && "Unexpected scattering found");
+
+    isl_pw_aff *ScatPA = nullptr;
+    do {
+      isl_pw_aff_free(ScatPA);
+      ScatPA = isl_pw_multi_aff_get_pw_aff(ScatPMA, --Pos);
+    } while (Pos && isl_pw_aff_is_cst(ScatPA));
+
+    // If a non-constant dimension was found, check the loops it involves.
+    if (!isl_pw_aff_is_cst(ScatPA)) {
+      // Get rid of the constraints caused by the domain.
+      ScatPA = isl_pw_aff_gist(ScatPA, Stmt->getDomain());
+
+      // For each input dimension check whether it is actually used in the
+      // innermost (now only) dimension. If so, get the corresponding loop
+      // and check it for annotations.
+      for (unsigned u = 0, e = Stmt->getNumIterators(); u != e; u++) {
+        if (!isl_pw_aff_involves_dims(ScatPA, isl_dim_in, u, 1))
+          continue;
+        if (const Loop *L = Stmt->getLoopForDimension(u)) {
+          // TODO: Actually check and extract the annotations; for now the
+          // loop ID is only queried and the result is discarded.
+          (void)L->getLoopID();
+        }
+      }
+    }
+
+    isl_pw_multi_aff_free(ScatPMA);
+    isl_ast_expr_free(UserExpr);
+    isl_ast_expr_free(Expr);
+    isl_pw_aff_free(ScatPA);
+    isl_id_free(Id);
+    break;
+  }
+  default:
+    llvm_unreachable("Loop body was unexpected");
+  }
+
+  isl_ast_node_free(Body);
+}
+
 // This method is executed before the construction of a for node. It creates
 // an isl_id that is used to annotate the subsequently generated ast for nodes.
 //
@@ -253,6 +333,11 @@
   if (Payload->IsOutermostParallel)
     BuildInfo->InParallelFor = false;
 
+  // For innermost loops collect all loop annotations from the original loop(s)
+  // involved in this new innermost dimension.
+  if (Payload->IsInnermost)
+    collectLoopAnnotations(isl_ast_node_for_get_body(Node));
+
   isl_id_free(Id);
   return Node;
 }