Record the number of loops generated by an OpenMP loop transformation.

OMPLoopTransformationDirective gains a NumGeneratedLoops field (default 0)
with a protected setter and a public const getter. OMPTileDirective sets it
in its constructor to 2 * NumLoops (one floor loop and one tile loop per
associated loop). OMPUnrollDirective::Create takes the count as a parameter
and asserts it is at most 1; Sema passes 1 when a partial clause is present
and 0 otherwise. The loop traversal in clang/lib/AST/StmtOpenMP.cpp now
distinguishes two reasons for a null transformed statement: it breaks when the
directive generates no loops, and returns true when loops are expected but
have not been generated yet (dependent context). The count is written and
read as a UInt32 by ASTStmtWriter/ASTStmtReader. Two AST print/dump tests
cover a tile size and an unroll factor taken from a template parameter.
Also fixes the "after after" typo in the getTransformedStmt documentation.

diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -959,6 +959,9 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective { friend class ASTStmtReader; + /// Number of loops generated by this loop transformation. + unsigned NumGeneratedLoops = 0; + protected: explicit OMPLoopTransformationDirective(StmtClass SC, OpenMPDirectiveKind Kind, @@ -967,10 +970,16 @@ unsigned NumAssociatedLoops) : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} + /// Set the number of loops generated by this loop transformation. + void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; } + public: /// Return the number of associated (consumed) loops. unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } + /// Return the number of loops generated by this loop transformation. + unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; } + - /// Get the de-sugared statements after after the loop transformation. + /// Get the de-sugared statements after the loop transformation. /// /// Might be nullptr if either the directive generates no loops and is handled @@ -5058,7 +5067,9 @@ unsigned NumLoops) : OMPLoopTransformationDirective(OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc, - NumLoops) {} + NumLoops) { + setNumGeneratedLoops(2 * NumLoops); + } void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5163,7 +5174,7 @@ static OMPUnrollDirective * Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, - Stmt *TransformedStmt, Stmt *PreInits); + unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits); /// Build an empty '#pragma omp unroll' AST node for deserialization. 
/// diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -138,9 +138,15 @@ Stmt *TransformedStmt = Dir->getTransformedStmt(); if (!TransformedStmt) { - // May happen if the loop transformation does not result in a - // generated loop (such as full unrolling). - break; + if (Dir->getNumGeneratedLoops() == 0) { + // May happen if the loop transformation does not result in a + // generated loop (such as full unrolling). + break; + } + // The loop transformation construct has generated loops, but these + // may not have been generated yet due to being in a dependent + // context. + return true; } CurStmt = TransformedStmt; @@ -419,10 +425,13 @@ OMPUnrollDirective * OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, - Stmt *AssociatedStmt, Stmt *TransformedStmt, - Stmt *PreInits) { + Stmt *AssociatedStmt, unsigned NumGeneratedLoops, + Stmt *TransformedStmt, Stmt *PreInits) { + assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop"); + auto *Dir = createDirective<OMPUnrollDirective>( C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc); + Dir->setNumGeneratedLoops(NumGeneratedLoops); Dir->setTransformedStmt(TransformedStmt); Dir->setPreInits(PreInits); return Dir; diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -12919,10 +12919,12 @@ Body, OriginalInits)) return StmtError(); + unsigned NumGeneratedLoops = PartialClause ? 1 : 0; + // Delay unrolling to when template is completely instantiated. 
if (CurContext->isDependentContext()) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - nullptr, nullptr); + NumGeneratedLoops, nullptr, nullptr); OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front(); @@ -12941,9 +12943,9 @@ // The generated loop may only be passed to other loop-associated directive // when a partial clause is specified. Without the requirement it is // sufficient to generate loop unroll metadata at code-generation. - if (!PartialClause) + if (NumGeneratedLoops == 0) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - nullptr, nullptr); + NumGeneratedLoops, nullptr, nullptr); // Otherwise, we need to provide a de-sugared/transformed AST that can be // associated with another loop directive. @@ -13164,7 +13166,8 @@ LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc()); return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - OuterFor, buildPreInits(Context, PreInits)); + NumGeneratedLoops, OuterFor, + buildPreInits(Context, PreInits)); } OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr, diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2327,6 +2327,7 @@ void ASTStmtReader::VisitOMPLoopTransformationDirective( OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); + D->setNumGeneratedLoops(Record.readUInt32()); } void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2226,6 +2226,7 @@ void ASTStmtWriter::VisitOMPLoopTransformationDirective( OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); + Record.writeUInt32(D->getNumGeneratedLoops()); } 
void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { diff --git a/clang/test/OpenMP/tile_ast_print.cpp b/clang/test/OpenMP/tile_ast_print.cpp --- a/clang/test/OpenMP/tile_ast_print.cpp +++ b/clang/test/OpenMP/tile_ast_print.cpp @@ -162,4 +162,25 @@ } +// PRINT-LABEL: template <int Tile> void foo7(int start, int stop, int step) { +// DUMP-LABEL: FunctionTemplateDecl {{.*}} foo7 +template <int Tile> +void foo7(int start, int stop, int step) { + // PRINT: #pragma omp tile sizes(Tile) + // DUMP: OMPTileDirective + // DUMP-NEXT: OMPSizesClause + // DUMP-NEXT: DeclRefExpr {{.*}} 'Tile' 'int' + #pragma omp tile sizes(Tile) + // PRINT-NEXT: for (int i = start; i < stop; i += step) + // DUMP-NEXT: ForStmt + for (int i = start; i < stop; i += step) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} +void tfoo7() { + foo7<5>(0, 42, 2); +} + + #endif diff --git a/clang/test/OpenMP/unroll_ast_print.cpp b/clang/test/OpenMP/unroll_ast_print.cpp --- a/clang/test/OpenMP/unroll_ast_print.cpp +++ b/clang/test/OpenMP/unroll_ast_print.cpp @@ -124,4 +124,26 @@ unroll_templated(); } + +// PRINT-LABEL: template <int Factor> void unroll_templated_factor(int start, int stop, int step) { +// DUMP-LABEL: FunctionTemplateDecl {{.*}} unroll_templated_factor +template <int Factor> +void unroll_templated_factor(int start, int stop, int step) { + // PRINT: #pragma omp unroll partial(Factor) + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + // DUMP-NEXT: DeclRefExpr {{.*}} 'Factor' 'int' + #pragma omp unroll partial(Factor) + // PRINT-NEXT: for (int i = start; i < stop; i += step) + // DUMP-NEXT: ForStmt + for (int i = start; i < stop; i += step) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} +void unroll_template_factor() { + unroll_templated_factor<4>(0, 42, 2); +} + + #endif