diff --git a/polly/include/polly/CodeGen/IslAst.h b/polly/include/polly/CodeGen/IslAst.h --- a/polly/include/polly/CodeGen/IslAst.h +++ b/polly/include/polly/CodeGen/IslAst.h @@ -142,7 +142,7 @@ static bool isInnermost(const isl::ast_node &Node); /// Is this loop a parallel loop? - static bool isParallel(__isl_keep isl_ast_node *Node); + static bool isParallel(const isl::ast_node &Node); /// Is this loop an outermost parallel loop? static bool isOutermostParallel(const isl::ast_node &Node); @@ -151,20 +151,19 @@ static bool isInnermostParallel(const isl::ast_node &Node); /// Is this loop a reduction parallel loop? - static bool isReductionParallel(__isl_keep isl_ast_node *Node); + static bool isReductionParallel(const isl::ast_node &Node); /// Will the loop be run as thread parallel? - static bool isExecutedInParallel(__isl_keep isl_ast_node *Node); + static bool isExecutedInParallel(const isl::ast_node &Node); - /// Get the nodes schedule or a nullptr if not available. - static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node); + /// Get the node's schedule or a null isl::union_map if not available. + static isl::union_map getSchedule(const isl::ast_node &Node); - /// Get minimal dependence distance or nullptr if not available. - static __isl_give isl_pw_aff * - getMinimalDependenceDistance(__isl_keep isl_ast_node *Node); + /// Get minimal dependence distance or a null isl::pw_aff if not available. + static isl::pw_aff getMinimalDependenceDistance(const isl::ast_node &Node); /// Get the nodes broken reductions or a nullptr if not available. - static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node); + static MemoryAccessSet *getBrokenReductions(const isl::ast_node &Node); /// Get the nodes build context or a nullptr if not available. static __isl_give isl_ast_build *getBuild(__isl_keep isl_ast_node *Node); diff --git a/polly/include/polly/CodeGen/IslNodeBuilder.h b/polly/include/polly/CodeGen/IslNodeBuilder.h --- a/polly/include/polly/CodeGen/IslNodeBuilder.h +++ b/polly/include/polly/CodeGen/IslNodeBuilder.h @@ -248,7 +248,7 @@ /// this subtree.
/// @param Loops A vector that will be filled with the Loops referenced in /// this subtree. - void getReferencesInSubtree(__isl_keep isl_ast_node *For, + void getReferencesInSubtree(const isl::ast_node &For, SetVector &Values, SetVector &Loops); @@ -398,8 +398,7 @@ /// below this ast node to the scheduling vectors used to enumerate /// them. /// - virtual __isl_give isl_union_map * - getScheduleForAstNode(__isl_take isl_ast_node *Node); + virtual isl::union_map getScheduleForAstNode(const isl::ast_node &Node); private: /// Create code for a copy statement. diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp --- a/polly/lib/CodeGen/IslAst.cpp +++ b/polly/lib/CodeGen/IslAst.cpp @@ -140,7 +140,7 @@ } /// Return all broken reductions as a string of clauses (OpenMP style). -static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) { +static const std::string getBrokenReductionsStr(const isl::ast_node &Node) { IslAstInfo::MemoryAccessSet *BrokenReductions; std::string str; @@ -171,25 +171,26 @@ static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, __isl_take isl_ast_print_options *Options, __isl_keep isl_ast_node *Node, void *) { - isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node); - const std::string BrokenReductionsStr = getBrokenReductionsStr(Node); + isl::pw_aff DD = + IslAstInfo::getMinimalDependenceDistance(isl::manage_copy(Node)); + const std::string BrokenReductionsStr = + getBrokenReductionsStr(isl::manage_copy(Node)); const std::string KnownParallelStr = "#pragma known-parallel"; const std::string DepDisPragmaStr = "#pragma minimal dependence distance: "; const std::string SimdPragmaStr = "#pragma simd"; const std::string OmpPragmaStr = "#pragma omp parallel for"; - if (DD) - Printer = printLine(Printer, DepDisPragmaStr, DD); + if (!DD.is_null()) + Printer = printLine(Printer, DepDisPragmaStr, DD.get()); if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node))) Printer = printLine(Printer,
SimdPragmaStr + BrokenReductionsStr); - if (IslAstInfo::isExecutedInParallel(Node)) + if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node))) Printer = printLine(Printer, OmpPragmaStr); else if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node))) Printer = printLine(Printer, KnownParallelStr + BrokenReductionsStr); - isl_pw_aff_free(DD); return isl_ast_node_for_print(Node, Printer, Options); } @@ -472,15 +473,15 @@ switch (isl_ast_node_get_type(Node)) { case isl_ast_node_for: NumForLoops++; - if (IslAstInfo::isParallel(Node)) + if (IslAstInfo::isParallel(isl::manage_copy(Node))) NumParallel++; if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node))) NumInnermostParallel++; if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node))) NumOutermostParallel++; - if (IslAstInfo::isReductionParallel(Node)) + if (IslAstInfo::isReductionParallel(isl::manage_copy(Node))) NumReductionParallel++; - if (IslAstInfo::isExecutedInParallel(Node)) + if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node))) NumExecutedInParallel++; break; @@ -593,9 +594,9 @@ return Payload && Payload->IsInnermost; } -bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) { - return IslAstInfo::isInnermostParallel(isl::manage_copy(Node)) || - IslAstInfo::isOutermostParallel(isl::manage_copy(Node)); +bool IslAstInfo::isParallel(const isl::ast_node &Node) { + return IslAstInfo::isInnermostParallel(Node) || + IslAstInfo::isOutermostParallel(Node); } bool IslAstInfo::isInnermostParallel(const isl::ast_node &Node) { @@ -608,12 +609,12 @@ return Payload && Payload->IsOutermostParallel; } -bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); +bool IslAstInfo::isReductionParallel(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); return Payload && Payload->IsReductionParallel; } -bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) { +bool
IslAstInfo::isExecutedInParallel(const isl::ast_node &Node) { if (!PollyParallel) return false; @@ -626,28 +627,30 @@ // executed. This can possibly require run-time checks, which again // raises the question of both run-time check overhead and code size // costs. - if (!PollyParallelForce && isInnermost(isl::manage_copy(Node))) + if (!PollyParallelForce && isInnermost(Node)) return false; - return isOutermostParallel(isl::manage_copy(Node)) && - !isReductionParallel(Node); + return isOutermostParallel(Node) && !isReductionParallel(Node); } -__isl_give isl_union_map * -IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); - return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr; +isl::union_map IslAstInfo::getSchedule(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + if (!Payload) + return isl::union_map(); + + isl::ast_build Build = isl::manage_copy(Payload->Build); + return Build.get_schedule(); } -__isl_give isl_pw_aff * -IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); - return Payload ? Payload->MinimalDependenceDistance.copy() : nullptr; +isl::pw_aff +IslAstInfo::getMinimalDependenceDistance(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + return Payload ? Payload->MinimalDependenceDistance : isl::pw_aff(); +} IslAstInfo::MemoryAccessSet * -IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); +IslAstInfo::getBrokenReductions(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); return Payload ?
&Payload->BrokenReductions : nullptr; } diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp --- a/polly/lib/CodeGen/IslNodeBuilder.cpp +++ b/polly/lib/CodeGen/IslNodeBuilder.cpp @@ -300,12 +300,12 @@ addReferencesFromStmtSet(Set, &References); } -__isl_give isl_union_map * -IslNodeBuilder::getScheduleForAstNode(__isl_keep isl_ast_node *For) { - return IslAstInfo::getSchedule(For); +isl::union_map +IslNodeBuilder::getScheduleForAstNode(const isl::ast_node &Node) { + return IslAstInfo::getSchedule(Node); } -void IslNodeBuilder::getReferencesInSubtree(__isl_keep isl_ast_node *For, +void IslNodeBuilder::getReferencesInSubtree(const isl::ast_node &For, SetVector &Values, SetVector &Loops) { SetVector SCEVs; @@ -319,8 +319,7 @@ for (const auto &I : OutsideLoopIterations) Values.insert(cast(I.second)->getValue()); - isl::union_set Schedule = - isl::manage(isl_union_map_domain(getScheduleForAstNode(For))); + isl::union_set Schedule = getScheduleForAstNode(For).domain(); addReferencesFromStmtUnionSet(Schedule, References); for (const SCEV *Expr : SCEVs) { @@ -476,22 +475,22 @@ for (int i = 1; i < VectorWidth; i++) IVS[i] = Builder.CreateAdd(IVS[i - 1], ValueInc, "p_vector_iv"); - isl_union_map *Schedule = getScheduleForAstNode(For); - assert(Schedule && "For statement annotation does not contain its schedule"); + isl::union_map Schedule = getScheduleForAstNode(isl::manage_copy(For)); + assert(!Schedule.is_null() && + "For statement annotation does not contain its schedule"); IDToValue[IteratorID] = ValueLB; switch (isl_ast_node_get_type(Body)) { case isl_ast_node_user: - createUserVector(Body, IVS, isl_id_copy(IteratorID), - isl_union_map_copy(Schedule)); + createUserVector(Body, IVS, isl_id_copy(IteratorID), Schedule.copy()); break; case isl_ast_node_block: { isl_ast_node_list *List = isl_ast_node_block_get_children(Body); for (int i = 0; i < isl_ast_node_list_n_ast_node(List); ++i) createUserVector(isl_ast_node_list_get_ast_node(List,
i), IVS, - isl_id_copy(IteratorID), isl_union_map_copy(Schedule)); + isl_id_copy(IteratorID), Schedule.copy()); isl_ast_node_free(Body); isl_ast_node_list_free(List); @@ -504,7 +503,6 @@ IDToValue.erase(IDToValue.find(IteratorID)); isl_id_free(IteratorID); - isl_union_map_free(Schedule); isl_ast_node_free(For); isl_ast_expr_free(Iterator); @@ -685,7 +683,7 @@ SetVector SubtreeValues; SetVector Loops; - getReferencesInSubtree(For, SubtreeValues, Loops); + getReferencesInSubtree(isl::manage_copy(For), SubtreeValues, Loops); // Create for all loops we depend on values that contain the current loop // iteration. These values are necessary to generate code for SCEVs that @@ -783,7 +781,7 @@ bool Vector = PollyVectorizerChoice == VECTORIZER_POLLY; if (Vector && IslAstInfo::isInnermostParallel(isl::manage_copy(For)) && - !IslAstInfo::isReductionParallel(For)) { + !IslAstInfo::isReductionParallel(isl::manage_copy(For))) { int VectorWidth = getNumberOfIterations(isl::manage_copy(For)); if (1 < VectorWidth && VectorWidth <= 16 && !hasPartialAccesses(For)) { createForVector(For, VectorWidth); @@ -791,12 +789,12 @@ } } - if (IslAstInfo::isExecutedInParallel(For)) { + if (IslAstInfo::isExecutedInParallel(isl::manage_copy(For))) { createForParallel(For); return; } - bool Parallel = - (IslAstInfo::isParallel(For) && !IslAstInfo::isReductionParallel(For)); + bool Parallel = (IslAstInfo::isParallel(isl::manage_copy(For)) && + !IslAstInfo::isReductionParallel(isl::manage_copy(For))); createForSequential(isl::manage(For), Parallel); }