Index: include/polly/CodeGen/IslAst.h =================================================================== --- include/polly/CodeGen/IslAst.h +++ include/polly/CodeGen/IslAst.h @@ -24,6 +24,7 @@ #include "polly/Config/config.h" #include "polly/ScopPass.h" +#include "llvm/ADT/SetVector.h" #include "isl/ast.h" @@ -45,6 +46,8 @@ class IslAstInfo : public ScopPass { public: using MemoryAccessSet = SmallPtrSet; + using LoopSet = SetVector, + SmallPtrSet >; /// @brief Payload information used to annotate an AST node. struct IslAstUserPayload { @@ -74,6 +77,9 @@ /// @brief Set of accesses which break reduction dependences. MemoryAccessSet BrokenReductions; + + /// @brief Set of loops surrounding the node. + LoopSet Loops; }; private: @@ -129,6 +135,14 @@ /// @brief Get the nodes broken reductions or a nullptr if not available. static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node); + /// @brief Get the original loops surrounding the new node or a nullptr + /// if not available. + static LoopSet *getLoops(__isl_keep isl_ast_node *Node); + + /// @brief Get the original loop surrounding the new node or a nullptr + /// if there is not exactly one such loop. + static const Loop *getLoop(__isl_keep isl_ast_node *Node); + /// @brief Get the nodes build context or a nullptr if not available. static __isl_give isl_ast_build *getBuild(__isl_keep isl_ast_node *Node); Index: lib/CodeGen/IslAst.cpp =================================================================== --- lib/CodeGen/IslAst.cpp +++ lib/CodeGen/IslAst.cpp @@ -25,7 +25,9 @@ #include "polly/LinkAllPasses.h" #include "polly/Options.h" #include "polly/ScopInfo.h" +#include "llvm/ADT/SmallString.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "isl/union_map.h" #include "isl/list.h" @@ -136,6 +138,18 @@ return str; } +/// @brief Return a pragma string for annotating an isl ast. 
+static std::string getLLVMLoopPragmaStr(const Loop *L) { + SmallString<128> Buf; + raw_svector_ostream OS(Buf); + OS << "#llvm loop(" << L->getHeader()->getName(); + if (BasicBlock *Latch = L->getLoopLatch()) + OS << ", " << Latch->getName(); + OS << ")"; + OS << " depth(" << L->getLoopDepth() << ")"; + return OS.str(); +} + /// @brief Callback executed for each for node in the ast in order to print it. static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, __isl_take isl_ast_print_options *Options, @@ -145,6 +159,10 @@ const std::string SimdPragmaStr = "#pragma simd"; const std::string OmpPragmaStr = "#pragma omp parallel for"; + if (IslAstInfo::LoopSet *Loops = IslAstInfo::getLoops(Node)) + for (const Loop *L : *Loops) + Printer = printLine(Printer, getLLVMLoopPragmaStr(L)); + if (IslAstInfo::isInnermostParallel(Node)) Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr); @@ -202,21 +220,22 @@ /// We will look at all statements in @p Body and all loops formerly surrounding /// those statements and aggregate their loop annotations if they are invovled /// in the __new__ innermost dimension. -static void collectLoopAnnotations(__isl_take isl_ast_node *Body) { +static void collectLoopAnnotations(__isl_take isl_ast_node *Body, + IslAstUserPayload *Payload) { // Recurce for block and conditional statements but extract the annotations // once a user ast node was found. 
switch (isl_ast_node_get_type(Body)) { case isl_ast_node_block: { isl_ast_node_list *List = isl_ast_node_block_get_children(Body); for (int i = 0; i < isl_ast_node_list_n_ast_node(List); ++i) - collectLoopAnnotations(isl_ast_node_list_get_ast_node(List, i)); + collectLoopAnnotations(isl_ast_node_list_get_ast_node(List, i), Payload); isl_ast_node_list_free(List); break; } case isl_ast_node_if: { - collectLoopAnnotations(isl_ast_node_if_get_then(Body)); + collectLoopAnnotations(isl_ast_node_if_get_then(Body), Payload); if (isl_ast_node_if_has_else(Body)) - collectLoopAnnotations(isl_ast_node_if_get_else(Body)); + collectLoopAnnotations(isl_ast_node_if_get_else(Body), Payload); break; } case isl_ast_node_user: { @@ -252,8 +271,9 @@ // for annotations. for (unsigned u = 0, e = Stmt->getNumIterators(); u != e; u++) if (isl_pw_aff_involves_dims(ScatPA, isl_dim_in, u, 1)) - if (const Loop *L = Stmt->getLoopForDimension(u)) - /* TODO Actually check and extract annotations */ L->getLoopID(); + if (const Loop *L = Stmt->getLoopForDimension(u)) { + Payload->Loops.insert(L); + } isl_pw_multi_aff_free(ScatPMA); isl_ast_expr_free(UserExpr); @@ -327,7 +347,7 @@ // For innermost loops collect all loop annotations from the orignal loop(s) // involved in this new innermost dimension. if (Payload->IsInnermost) - collectLoopAnnotations(isl_ast_node_for_get_body(Node)); + collectLoopAnnotations(isl_ast_node_for_get_body(Node), Payload); isl_id_free(Id); return Node; @@ -490,6 +510,18 @@ return Payload ? &Payload->BrokenReductions : nullptr; } +IslAstInfo::LoopSet *IslAstInfo::getLoops(__isl_keep isl_ast_node *Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + return Payload ? 
&Payload->Loops : nullptr; +} + +const Loop *IslAstInfo::getLoop(__isl_keep isl_ast_node *Node) { + if (LoopSet *Loops = getLoops(Node)) + if (Loops->size() == 1) + return *Loops->begin(); + return nullptr; +} + isl_ast_build *IslAstInfo::getBuild(__isl_keep isl_ast_node *Node) { IslAstUserPayload *Payload = getNodePayload(Node); return Payload ? Payload->Build : nullptr; Index: lib/CodeGen/IslCodeGeneration.cpp =================================================================== --- lib/CodeGen/IslCodeGeneration.cpp +++ lib/CodeGen/IslCodeGeneration.cpp @@ -579,6 +579,7 @@ BasicBlock *StartBlock = executeScopConditionally(S, this); isl_ast_node *Ast = AstInfo.getAst(); + DEBUG(dbgs() << "[IslCodeGen] "; AstInfo.printScop(dbgs())); LoopAnnotator Annotator; PollyIRBuilder Builder(StartBlock->getContext(), llvm::ConstantFolder(), polly::IRInserter(Annotator));