diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2576,7 +2576,11 @@
    */
   CXCursor_OMPCanonicalLoop = 289,
 
-  CXCursor_LastStmt = CXCursor_OMPCanonicalLoop,
+  /** OpenMP unroll directive.
+   */
+  CXCursor_OMPUnrollDirective = 290,
+
+  CXCursor_LastStmt = CXCursor_OMPUnrollDirective,
 
   /**
    * Cursor that represents the translation unit itself.
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -888,6 +888,119 @@
   }
 };
 
+/// Representation of the 'full' clause of the '#pragma omp unroll' directive.
+///
+/// \code
+/// #pragma omp unroll full
+/// for (int i = 0; i < 64; ++i)
+/// \endcode
+class OMPFullClause final : public OMPClause {
+  friend class OMPClauseReader;
+
+  /// Build an empty clause.
+  explicit OMPFullClause()
+      : OMPClause(llvm::omp::OMPC_full, SourceLocation(), SourceLocation()) {}
+
+public:
+  /// Build an AST node for a 'full' clause.
+  ///
+  /// \param C        Context of the AST.
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc   Ending location of the clause.
+  static OMPFullClause *Create(const ASTContext &C, SourceLocation StartLoc,
+                               SourceLocation EndLoc);
+
+  /// Build an empty 'full' AST node for deserialization.
+  ///
+  /// \param C Context of the AST.
+  static OMPFullClause *CreateEmpty(const ASTContext &C);
+
+  child_range children() { return {child_iterator(), child_iterator()}; }
+  const_child_range children() const {
+    return {const_child_iterator(), const_child_iterator()};
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_full;
+  }
+};
+
+/// Representation of the 'partial' clause of the '#pragma omp unroll'
+/// directive.
+///
+/// \code
+/// #pragma omp unroll partial(4)
+/// for (int i = start; i < end; ++i)
+/// \endcode
+class OMPPartialClause final : public OMPClause {
+  friend class OMPClauseReader;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+  /// Optional expression of the unroll factor; nullptr if the clause was
+  /// written without an argument.
+  Stmt *Factor = nullptr;
+
+  /// Build an empty clause.
+  explicit OMPPartialClause()
+      : OMPClause(llvm::omp::OMPC_partial, SourceLocation(), SourceLocation()) {
+  }
+
+public:
+  /// Build an AST node for a 'partial' clause.
+  ///
+  /// \param C         Context of the AST.
+  /// \param StartLoc  Location of the 'partial' identifier.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc    Location of ')'.
+  /// \param Factor    Clause argument, or nullptr if none was given.
+  static OMPPartialClause *Create(const ASTContext &C, SourceLocation StartLoc,
+                                  SourceLocation LParenLoc,
+                                  SourceLocation EndLoc, Expr *Factor);
+
+  /// Build an empty 'partial' AST node for deserialization.
+  ///
+  /// \param C Context of the AST.
+  static OMPPartialClause *CreateEmpty(const ASTContext &C);
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns the argument of the clause or nullptr if not set.
+  Expr *getFactor() const { return cast_or_null<Expr>(Factor); }
+
+  /// Sets the argument of the clause.
+  void setFactor(Expr *E) { Factor = E; }
+
+  child_range children() { return child_range(&Factor, &Factor + 1); }
+
+  const_child_range children() const {
+    return const_child_range(&Factor, &Factor + 1);
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_partial;
+  }
+};
+
 /// This represents 'collapse' clause in the '#pragma omp ...'
 /// directive.
 ///
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2810,6 +2810,9 @@
 DEF_TRAVERSE_STMT(OMPTileDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(OMPUnrollDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 DEF_TRAVERSE_STMT(OMPForDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
@@ -3057,6 +3060,17 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPFullClause(OMPFullClause *C) {
+  return true;
+}
+
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPPartialClause(OMPPartialClause *C) {
+  TRY_TO(TraverseStmt(C->getFactor()));
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPCollapseClause(
     OMPCollapseClause *C) {
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -5034,6 +5034,78 @@
   }
 };
 
+/// This represents the '#pragma omp unroll' loop transformation directive.
+class OMPUnrollDirective final : public OMPLoopBasedDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+
+  /// Default list of offsets.
+  enum {
+    PreInitsOffset = 0,
+    TransformedStmtOffset,
+  };
+
+  explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPLoopBasedDirective(OMPUnrollDirectiveClass, llvm::omp::OMPD_unroll,
+                              StartLoc, EndLoc, /*NumAssociatedLoops=*/1) {}
+
+  void setPreInits(Stmt *PreInits) {
+    Data->getChildren()[PreInitsOffset] = PreInits;
+  }
+
+  void setTransformedStmt(Stmt *S) {
+    Data->getChildren()[TransformedStmtOffset] = S;
+  }
+
+public:
+  /// Create a new AST node representation for '#pragma omp unroll'.
+  ///
+  /// \param C Context of the AST.
+  /// \param StartLoc Location of the introducer (e.g. the 'omp' token).
+  /// \param EndLoc Location of the directive's end (e.g. the tok::eod).
+  /// \param Clauses The directive's clauses.
+  /// \param AssociatedStmt The associated loop.
+  /// \param TransformedStmt The loop after partial unrolling, or nullptr if
+  ///        the directive has no 'partial' clause or is in a dependent
+  ///        context.
+  /// \param PreInits Helper preinits statements for the loop, or nullptr if
+  ///        there are none.
+  static OMPUnrollDirective *
+  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+         ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+         Stmt *TransformedStmt, Stmt *PreInits);
+
+  /// Build an empty '#pragma omp unroll' AST node for deserialization.
+  ///
+  /// \param C Context of the AST.
+  /// \param NumClauses Number of clauses to allocate. The directive always
+  ///        has exactly one associated loop.
+  static OMPUnrollDirective *CreateEmpty(const ASTContext &C,
+                                         unsigned NumClauses);
+
+  /// Get the loop after partial unrolling, or nullptr if the directive has
+  /// no 'partial' clause or is in a dependent context.
+  ///
+  /// The partially unrolled loop is represented as a strip-mined loop nest:
+  /// \code
+  ///   for (.floor.iv = 0; .floor.iv < NumIters; .floor.iv += Factor)
+  ///     for (.tile.iv = .floor.iv; .tile.iv < min(...); ++.tile.iv)
+  /// \endcode
+  ///
+  /// Note that if the generated loop becomes the associated loop of another
+  /// directive, its preinits may need to be hoisted before that directive.
+  Stmt *getTransformedStmt() const {
+    return Data->getChildren()[TransformedStmtOffset];
+  }
+
+  /// Return preinits statement.
+  Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPUnrollDirectiveClass;
+  }
+};
+
 /// This represents '#pragma omp scan' directive.
 ///
 /// \code
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -223,6 +223,7 @@
 def OMPParallelDirective : StmtNode<OMPExecutableDirective>;
 def OMPSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPTileDirective : StmtNode<OMPLoopBasedDirective>;
+def OMPUnrollDirective : StmtNode<OMPLoopBasedDirective>;
 def OMPForDirective : StmtNode<OMPLoopDirective>;
 def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -32,6 +32,7 @@
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/PrettyPrinter.h"
 #include "clang/AST/StmtCXX.h"
+#include "clang/AST/StmtOpenMP.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/AST/TypeOrdering.h"
 #include "clang/Basic/BitmaskEnum.h"
@@ -10234,6 +10235,23 @@
                  MapT &Map, unsigned Selector = 0,
                  SourceRange SrcRange = SourceRange());
 
+  /// Analyzes the loop nest associated with a loop transformation directive.
+  ///
+  /// \param Kind The loop transformation directive kind.
+  /// \param AStmt The statement associated with the directive.
+  /// \param NumLoops How many nested loops the directive expects.
+  /// \param LoopHelpers [out] The loop analysis results; must already have
+  ///        \p NumLoops elements.
+  /// \param Body [out] The body nested in the innermost analyzed loop.
+  /// \param OriginalInits [out] The init-statement (or, for a range-based
+  ///        for, the begin-statement) of each analyzed loop.
+  ///
+  /// \returns false if any of the loops failed the checks, true otherwise.
+  bool checkTransformableLoopNest(
+      OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
+      SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
+      Stmt *&Body, SmallVectorImpl<Stmt *> &OriginalInits);
+
   /// Helper to keep information about the current `omp begin/end declare
   /// variant` nesting.
   struct OMPDeclareVariantScope {
@@ -10530,6 +10536,11 @@
   StmtResult ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
                                       Stmt *AStmt, SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed '#pragma omp unroll' after parsing of its clauses
+  /// and the associated statement.
+  StmtResult ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
+                                        Stmt *AStmt, SourceLocation StartLoc,
+                                        SourceLocation EndLoc);
   /// Called on well-formed '\#pragma omp for' after parsing
   /// of the associated statement.
   StmtResult
@@ -10871,6 +10882,13 @@
                                     SourceLocation StartLoc,
                                     SourceLocation LParenLoc,
                                     SourceLocation EndLoc);
+  /// Called on well-formed 'full' clause.
+  OMPClause *ActOnOpenMPFullClause(SourceLocation StartLoc,
+                                   SourceLocation EndLoc);
+  /// Called on well-formed 'partial' clause.
+  OMPClause *ActOnOpenMPPartialClause(Expr *FactorExpr, SourceLocation StartLoc,
+                                      SourceLocation LParenLoc,
+                                      SourceLocation EndLoc);
   /// Called on well-formed 'collapse' clause.
   OMPClause *ActOnOpenMPCollapseClause(Expr *NumForLoops,
                                        SourceLocation StartLoc,
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1888,6 +1888,7 @@
   STMT_OMP_PARALLEL_DIRECTIVE,
   STMT_OMP_SIMD_DIRECTIVE,
   STMT_OMP_TILE_DIRECTIVE,
+  STMT_OMP_UNROLL_DIRECTIVE,
   STMT_OMP_FOR_DIRECTIVE,
   STMT_OMP_FOR_SIMD_DIRECTIVE,
   STMT_OMP_SECTIONS_DIRECTIVE,
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -922,6 +922,36 @@
   return new (Mem) OMPSizesClause(NumSizes);
 }
 
+OMPFullClause *OMPFullClause::Create(const ASTContext &C,
+                                     SourceLocation StartLoc,
+                                     SourceLocation EndLoc) {
+  OMPFullClause *Clause = CreateEmpty(C);
+  Clause->setLocStart(StartLoc);
+  Clause->setLocEnd(EndLoc);
+  return Clause;
+}
+
+OMPFullClause *OMPFullClause::CreateEmpty(const ASTContext &C) {
+  return new (C) OMPFullClause();
+}
+
+OMPPartialClause *OMPPartialClause::Create(const ASTContext &C,
+                                           SourceLocation StartLoc,
+                                           SourceLocation LParenLoc,
+                                           SourceLocation EndLoc,
+                                           Expr *Factor) {
+  OMPPartialClause *Clause = CreateEmpty(C);
+  Clause->setLocStart(StartLoc);
+  Clause->setLParenLoc(LParenLoc);
+  Clause->setLocEnd(EndLoc);
+  Clause->setFactor(Factor);
+  return Clause;
+}
+
+OMPPartialClause *OMPPartialClause::CreateEmpty(const ASTContext &C) {
+  return new (C) OMPPartialClause();
+}
+
 OMPAllocateClause *
 OMPAllocateClause::Create(const ASTContext &C, SourceLocation StartLoc,
                           SourceLocation LParenLoc, Expr *Allocator,
@@ -1561,6 +1591,18 @@
   OS << ")";
 }
 
+void OMPClausePrinter::VisitOMPFullClause(OMPFullClause *Node) { OS << "full"; }
+
+void OMPClausePrinter::VisitOMPPartialClause(OMPPartialClause *Node) {
+  OS << "partial";
+
+  if (Expr *Factor = Node->getFactor()) {
+    OS << '(';
+    Factor->printPretty(OS, nullptr, Policy, 0);
+    OS << ')';
+  }
+}
+
 void OMPClausePrinter::VisitOMPAllocatorClause(OMPAllocatorClause *Node) {
   OS << "allocator(";
   Node->getAllocator()->printPretty(OS, nullptr, Policy, 0);
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -127,10 +127,23 @@
     llvm::function_ref<bool(unsigned, Stmt *)> Callback) {
   CurStmt = CurStmt->IgnoreContainers();
   for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) {
-    if (auto *Dir = dyn_cast<OMPTileDirective>(CurStmt))
-      CurStmt = Dir->getTransformedStmt();
-    if (auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(CurStmt))
-      CurStmt = CanonLoop->getLoopStmt();
+    // Look through loop transformations and canonical loop wrappers to the
+    // loop they stand for. A transformation without a transformed statement
+    // (e.g. '#pragma omp unroll' without a 'partial' clause, or any
+    // transformation in a dependent context) is left in place instead of being
+    // replaced by a null statement that would then be dereferenced.
+    while (true) {
+      Stmt *Next = nullptr;
+      if (auto *TileDir = dyn_cast<OMPTileDirective>(CurStmt))
+        Next = TileDir->getTransformedStmt();
+      else if (auto *UnrollDir = dyn_cast<OMPUnrollDirective>(CurStmt))
+        Next = UnrollDir->getTransformedStmt();
+      else if (auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(CurStmt))
+        Next = CanonLoop->getLoopStmt();
+      if (!Next)
+        break;
+      CurStmt = Next;
+    }
     if (Callback(Cnt, CurStmt))
       return false;
     // Move on to the next nested for loop, or to the loop body.
@@ -355,6 +361,25 @@
       SourceLocation(), SourceLocation(), NumLoops);
 }
 
+OMPUnrollDirective *
+OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+                           SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
+                           Stmt *AssociatedStmt, Stmt *TransformedStmt,
+                           Stmt *PreInits) {
+  OMPUnrollDirective *Dir = createDirective<OMPUnrollDirective>(
+      C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
+  Dir->setTransformedStmt(TransformedStmt);
+  Dir->setPreInits(PreInits);
+  return Dir;
+}
+
+OMPUnrollDirective *OMPUnrollDirective::CreateEmpty(const ASTContext &C,
+                                                    unsigned NumClauses) {
+  return createEmptyDirective<OMPUnrollDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+      SourceLocation(), SourceLocation());
+}
+
 OMPForSimdDirective *
 OMPForSimdDirective::Create(const ASTContext &C, SourceLocation StartLoc,
                             SourceLocation EndLoc, unsigned CollapsedNum,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -669,6 +669,11 @@
   PrintOMPExecutableDirective(Node);
 }
 
+void StmtPrinter::VisitOMPUnrollDirective(OMPUnrollDirective *Node) {
+  Indent() << "#pragma omp unroll";
+  PrintOMPExecutableDirective(Node);
+}
+
 void StmtPrinter::VisitOMPForDirective(OMPForDirective *Node) {
   Indent() << "#pragma omp for";
   PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -463,11 +463,18 @@
 }
 
 void OMPClauseProfiler::VisitOMPSizesClause(const OMPSizesClause *C) {
-  for (auto E : C->getSizesRefs())
+  for (Expr *E : C->getSizesRefs())
     if (E)
       Profiler->VisitExpr(E);
 }
 
+void OMPClauseProfiler::VisitOMPFullClause(const OMPFullClause *C) {}
+
+void OMPClauseProfiler::VisitOMPPartialClause(const OMPPartialClause *C) {
+  if (Expr *Factor = C->getFactor())
+    Profiler->VisitExpr(Factor);
+}
+
 void OMPClauseProfiler::VisitOMPAllocatorClause(const OMPAllocatorClause *C) {
   if (C->getAllocator())
     Profiler->VisitStmt(C->getAllocator());
@@ -878,6 +885,10 @@
   VisitOMPLoopBasedDirective(S);
 }
 
+void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) {
+  VisitOMPLoopBasedDirective(S);
+}
+
 void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {
   VisitOMPLoopDirective(S);
 }
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -448,7 +448,8 @@
          DKind == OMPD_target_teams_distribute ||
          DKind == OMPD_target_teams_distribute_parallel_for ||
          DKind == OMPD_target_teams_distribute_parallel_for_simd ||
-         DKind == OMPD_target_teams_distribute_simd || DKind == OMPD_tile;
+         DKind == OMPD_target_teams_distribute_simd || DKind == OMPD_tile ||
+         DKind == OMPD_unroll;
 }
 
 bool clang::isOpenMPWorksharingDirective(OpenMPDirectiveKind DKind) {
@@ -576,7 +577,7 @@
 }
 
 bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) {
-  return DKind == OMPD_tile;
+  return DKind == OMPD_tile || DKind == OMPD_unroll;
 }
 
 void clang::getOpenMPCaptureRegions(
@@ -663,6 +664,7 @@
     CaptureRegions.push_back(OMPD_unknown);
     break;
   case OMPD_tile:
+  case OMPD_unroll:
     // loop transformations do not introduce captures.
     break;
   case OMPD_threadprivate:
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6636,6 +6636,7 @@
   case OMPD_task:
   case OMPD_simd:
   case OMPD_tile:
+  case OMPD_unroll:
   case OMPD_sections:
   case OMPD_section:
   case OMPD_single:
@@ -6954,6 +6955,7 @@
   case OMPD_task:
   case OMPD_simd:
   case OMPD_tile:
+  case OMPD_unroll:
   case OMPD_sections:
   case OMPD_section:
   case OMPD_single:
@@ -9525,6 +9527,7 @@
   case OMPD_task:
   case OMPD_simd:
   case OMPD_tile:
+  case OMPD_unroll:
   case OMPD_sections:
   case OMPD_section:
   case OMPD_single:
@@ -10349,6 +10352,7 @@
   case OMPD_task:
   case OMPD_simd:
   case OMPD_tile:
+  case OMPD_unroll:
   case OMPD_sections:
   case OMPD_section:
   case OMPD_single:
@@ -11032,6 +11036,7 @@
   case OMPD_task:
   case OMPD_simd:
   case OMPD_tile:
+  case OMPD_unroll:
   case OMPD_sections:
   case OMPD_section:
   case OMPD_single:
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -206,6 +206,9 @@
   case Stmt::OMPTileDirectiveClass:
     EmitOMPTileDirective(cast<OMPTileDirective>(*S));
     break;
+  case Stmt::OMPUnrollDirectiveClass:
+    EmitOMPUnrollDirective(cast<OMPUnrollDirective>(*S));
+    break;
   case Stmt::OMPForDirectiveClass:
     EmitOMPForDirective(cast<OMPForDirective>(*S));
     break;
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -176,6 +176,8 @@
       PreInits = cast_or_null<DeclStmt>(LD->getPreInits());
     } else if (const auto *Tile = dyn_cast<OMPTileDirective>(&S)) {
       PreInits = cast_or_null<DeclStmt>(Tile->getPreInits());
+    } else if (const auto *Unroll = dyn_cast<OMPUnrollDirective>(&S)) {
+      PreInits = cast_or_null<DeclStmt>(Unroll->getPreInits());
     } else {
       llvm_unreachable("Unknown loop-based directive kind.");
     }
@@ -1803,6 +1805,8 @@
                                                           SimplifiedS);
     if (auto *Dir = dyn_cast<OMPTileDirective>(SimplifiedS))
       SimplifiedS = Dir->getTransformedStmt();
+    if (auto *Dir = dyn_cast<OMPUnrollDirective>(SimplifiedS))
+      SimplifiedS = Dir->getTransformedStmt();
     if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(SimplifiedS))
       SimplifiedS = CanonLoop->getLoopStmt();
     if (const auto *For = dyn_cast<ForStmt>(SimplifiedS)) {
@@ -2561,6 +2565,30 @@
   EmitStmt(S.getTransformedStmt());
 }
 
+void CodeGenFunction::EmitOMPUnrollDirective(const OMPUnrollDirective &S) {
+  // This function is only called if the unrolled loop is not consumed by any
+  // other loop-associated construct. Such a loop-associated construct will have
+  // used the transformed AST.
+
+  // Unrolling is always requested; 'partial(N)' and 'full' refine how.
+  LoopStack.setUnrollState(LoopAttributes::Enable);
+  if (const auto *PartialClause = S.getSingleClause<OMPPartialClause>()) {
+    if (Expr *FactorExpr = PartialClause->getFactor()) {
+      // Evaluate the factor as a constant instead of emitting it as code: an
+      // integer constant expression (e.g. a constexpr function call) is not
+      // necessarily emitted as an llvm::ConstantInt.
+      uint64_t Factor =
+          FactorExpr->EvaluateKnownConstInt(getContext()).getZExtValue();
+      assert(Factor >= 1 && "Only positive factors are valid");
+      LoopStack.setUnrollCount(Factor);
+    }
+  } else if (S.hasClausesOfKind<OMPFullClause>()) {
+    LoopStack.setUnrollState(LoopAttributes::Full);
+  }
+
+  EmitStmt(S.getAssociatedStmt());
+}
+
 void CodeGenFunction::EmitOMPOuterLoop(
     bool DynamicOrOrdered, bool IsMonotonic, const OMPLoopDirective &S,
     CodeGenFunction::OMPPrivateScope &LoopScope,
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3417,6 +3417,7 @@
   void EmitOMPParallelDirective(const OMPParallelDirective &S);
   void EmitOMPSimdDirective(const OMPSimdDirective &S);
   void EmitOMPTileDirective(const OMPTileDirective &S);
+  void EmitOMPUnrollDirective(const OMPUnrollDirective &S);
   void EmitOMPForDirective(const OMPForDirective &S);
   void EmitOMPForSimdDirective(const OMPForSimdDirective &S);
   void EmitOMPSectionsDirective(const OMPSectionsDirective &S);
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2154,6 +2154,7 @@
   case OMPD_parallel:
   case OMPD_simd:
   case OMPD_tile:
+  case OMPD_unroll:
   case OMPD_task:
   case OMPD_taskyield:
   case OMPD_barrier:
@@ -2389,6 +2390,7 @@
   case OMPD_parallel:
   case OMPD_simd:
   case OMPD_tile:
+  case OMPD_unroll:
   case OMPD_for:
   case OMPD_for_simd:
   case OMPD_sections:
@@ -2773,6 +2775,7 @@
   case OMPC_allocator:
   case OMPC_depobj:
   case OMPC_detach:
+  case OMPC_partial:
     // OpenMP [2.5, Restrictions]
     //  At most one num_threads clause can appear on the directive.
     // OpenMP [2.8.1, simd construct, Restrictions]
@@ -2801,7 +2804,8 @@
       ErrorFound = true;
     }
 
-    if (CKind == OMPC_ordered && PP.LookAhead(/*N=*/0).isNot(tok::l_paren))
+    if ((CKind == OMPC_ordered || CKind == OMPC_partial) &&
+        PP.LookAhead(/*N=*/0).isNot(tok::l_paren))
       Clause = ParseOpenMPClause(CKind, WrongDirective);
     else
       Clause = ParseOpenMPSingleExprClause(CKind, WrongDirective);
@@ -2865,6 +2869,7 @@
   case OMPC_reverse_offload:
   case OMPC_dynamic_allocators:
   case OMPC_destroy:
+  case OMPC_full:
     // OpenMP [2.7.1, Restrictions, p. 9]
     //  Only one ordered clause can appear on a loop directive.
     // OpenMP [2.7.1, Restrictions, C/C++, p. 4]
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1460,6 +1460,7 @@
   case Stmt::OMPSectionsDirectiveClass:
   case Stmt::OMPSimdDirectiveClass:
   case Stmt::OMPTileDirectiveClass:
+  case Stmt::OMPUnrollDirectiveClass:
   case Stmt::OMPSingleDirectiveClass:
   case Stmt::OMPTargetDataDirectiveClass:
   case Stmt::OMPTargetDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -3804,6 +3804,11 @@
     VisitStmt(S);
   }
 
+  void VisitOMPUnrollDirective(OMPUnrollDirective *S) {
+    // #pragma omp unroll does not introduce data sharing.
+    VisitStmt(S);
+  }
+
   void VisitStmt(Stmt *S) {
     for (Stmt *C : S->children()) {
       if (C) {
@@ -3969,6 +3974,7 @@
   case OMPD_section:
   case OMPD_master:
   case OMPD_tile:
+  case OMPD_unroll:
     break;
   case OMPD_simd:
   case OMPD_for:
@@ -5825,6 +5831,10 @@
     Res = ActOnOpenMPTileDirective(ClausesWithImplicit, AStmt, StartLoc,
                                    EndLoc);
     break;
+  case OMPD_unroll:
+    Res = ActOnOpenMPUnrollDirective(ClausesWithImplicit, AStmt, StartLoc,
+                                     EndLoc);
+    break;
   case OMPD_for:
     Res = ActOnOpenMPForDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc,
                                   VarsWithInheritedDSA);
@@ -12438,6 +12448,35 @@
       Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
 }
 
+bool Sema::checkTransformableLoopNest(
+    OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
+    SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
+    Stmt *&Body, SmallVectorImpl<Stmt *> &OriginalInits) {
+  return OMPLoopBasedDirective::doForAllLoops(
+      AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops,
+      [this, &LoopHelpers, &Body, &OriginalInits, Kind](unsigned Cnt,
+                                                        Stmt *CurStmt) {
+        VarsWithInheritedDSAType TmpDSA;
+        unsigned SingleNumLoops =
+            checkOpenMPLoop(Kind, nullptr, nullptr, CurStmt, *this, *DSAStack,
+                            TmpDSA, LoopHelpers[Cnt]);
+        if (SingleNumLoops == 0)
+          return true;
+        assert(SingleNumLoops == 1 && "Expect single loop iteration space");
+        if (auto *For = dyn_cast<ForStmt>(CurStmt)) {
+          OriginalInits.push_back(For->getInit());
+          Body = For->getBody();
+        } else {
+          assert(isa<CXXForRangeStmt>(CurStmt) &&
+                 "Expected canonical for or range-based for loops.");
+          auto *CXXFor = cast<CXXForRangeStmt>(CurStmt);
+          OriginalInits.push_back(CXXFor->getBeginStmt());
+          Body = CXXFor->getBody();
+        }
+        return false;
+      });
+}
+
 StmtResult Sema::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
                                           Stmt *AStmt, SourceLocation StartLoc,
                                           SourceLocation EndLoc) {
@@ -12458,30 +12497,8 @@
   SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
   Stmt *Body = nullptr;
   SmallVector<Stmt *, 4> OriginalInits;
-  if (!OMPLoopBasedDirective::doForAllLoops(
-          AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false,
-          NumLoops,
-          [this, &LoopHelpers, &Body, &OriginalInits](unsigned Cnt,
-                                                      Stmt *CurStmt) {
-            VarsWithInheritedDSAType TmpDSA;
-            unsigned SingleNumLoops =
-                checkOpenMPLoop(OMPD_tile, nullptr, nullptr, CurStmt, *this,
-                                *DSAStack, TmpDSA, LoopHelpers[Cnt]);
-            if (SingleNumLoops == 0)
-              return true;
-            assert(SingleNumLoops == 1 && "Expect single loop iteration space");
-            if (auto *For = dyn_cast<ForStmt>(CurStmt)) {
-              OriginalInits.push_back(For->getInit());
-              Body = For->getBody();
-            } else {
-              assert(isa<CXXForRangeStmt>(CurStmt) &&
-                     "Expected canonical for or range-based for loops.");
-              auto *CXXFor = cast<CXXForRangeStmt>(CurStmt);
-              OriginalInits.push_back(CXXFor->getBeginStmt());
-              Body = CXXFor->getBody();
-            }
-            return false;
-          }))
+  if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body,
+                                  OriginalInits))
     return StmtError();
 
   // Delay tiling to when template is completely instantiated.
@@ -12666,6 +12683,233 @@
                                   buildPreInits(Context, PreInits));
 }
 
+StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
+                                            Stmt *AStmt,
+                                            SourceLocation StartLoc,
+                                            SourceLocation EndLoc) {
+  const auto *PartialClause =
+      OMPExecutableDirective::getSingleClause<OMPPartialClause>(Clauses);
+  // TODO: Diagnose a directive with both a 'full' and a 'partial' clause
+  // instead of asserting.
+  assert(!(PartialClause &&
+           OMPExecutableDirective::getSingleClause<OMPFullClause>(Clauses)) &&
+         "'full' and 'partial' clauses are mutually exclusive");
+
+  // Empty statement should only be possible if there already was an error.
+  if (!AStmt)
+    return StmtError();
+
+  // Without a 'partial' clause no loop is generated; codegen attaches unroll
+  // metadata to the associated loop instead.
+  if (!PartialClause)
+    return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+                                      nullptr, nullptr);
+
+  constexpr unsigned NumLoops = 1;
+  SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
+      NumLoops);
+  Stmt *Body = nullptr;
+  SmallVector<Stmt *, NumLoops> OriginalInits;
+  if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers,
+                                  Body, OriginalInits))
+    return StmtError();
+
+  // Delay unrolling to when template is completely instantiated.
+  if (CurContext->isDependentContext())
+    return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+                                      nullptr, nullptr);
+
+  // Determine the unroll factor. It is used as the step of the generated outer
+  // loop, so it must be a positive integer constant: a factor of 0 would
+  // create a loop that never terminates. If 'partial' has no argument, the
+  // factor is implementation defined; use 2.
+  Expr *FactorExpr = PartialClause->getFactor();
+  if (FactorExpr) {
+    ExprResult FactorResult = VerifyPositiveIntegerConstantInClause(
+        FactorExpr, OMPC_partial, /*StrictlyPositive=*/true);
+    if (FactorResult.isInvalid())
+      return StmtError();
+    FactorExpr = FactorResult.get();
+  } else {
+    FactorExpr = ActOnIntegerConstant(PartialClause->getBeginLoc(), 2).get();
+  }
+
+  // Collection of generated variable declaration.
+  SmallVector<Decl *, 4> PreInits;
+
+  // Create iteration variables for the generated loops.
+  SmallVector<VarDecl *, 4> FloorIndVars;
+  SmallVector<VarDecl *, 4> TileIndVars;
+  FloorIndVars.resize(NumLoops);
+  TileIndVars.resize(NumLoops);
+  for (unsigned I = 0; I < NumLoops; ++I) {
+    OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
+    if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
+      PreInits.append(PI->decl_begin(), PI->decl_end());
+    assert(LoopHelper.Counters.size() == 1 &&
+           "Expect single-dimensional loop iteration space");
+    auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+    std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+    DeclRefExpr *IterVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
+    QualType CntTy = IterVarRef->getType();
+
+    // Iteration variable for the floor (i.e. outer) loop.
+    {
+      std::string FloorCntName =
+          (Twine(".floor_") + llvm::utostr(I) + ".iv." + OrigVarName).str();
+      VarDecl *FloorCntDecl =
+          buildVarDecl(*this, {}, CntTy, FloorCntName, nullptr, OrigCntVar);
+      FloorIndVars[I] = FloorCntDecl;
+    }
+
+    // Iteration variable for the tile (i.e. inner) loop.
+    {
+      std::string TileCntName =
+          (Twine(".tile_") + llvm::utostr(I) + ".iv." + OrigVarName).str();
+
+      // Reuse the iteration variable created by checkOpenMPLoop. It is also
+      // used by the expressions to derive the original iteration variable's
+      // value from the logical iteration number.
+      auto *TileCntDecl = cast<VarDecl>(IterVarRef->getDecl());
+      TileCntDecl->setDeclName(&PP.getIdentifierTable().get(TileCntName));
+      TileIndVars[I] = TileCntDecl;
+    }
+    if (auto *PI = dyn_cast_or_null<DeclStmt>(OriginalInits[I]))
+      PreInits.append(PI->decl_begin(), PI->decl_end());
+    // Gather declarations for the data members used as counters.
+    for (Expr *CounterRef : LoopHelper.Counters) {
+      auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
+      if (isa<OMPCapturedExprDecl>(CounterDecl))
+        PreInits.push_back(CounterDecl);
+    }
+  }
+
+  // Once the original iteration values are set, append the innermost body.
+  Stmt *Inner = Body;
+
+  // Create tile loops from the inside to the outside.
+  for (int I = NumLoops - 1; I >= 0; --I) {
+    OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
+    Expr *NumIterations = LoopHelper.NumIterations;
+    auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
+    QualType CntTy = OrigCntVar->getType();
+    Expr *DimTileSize = FactorExpr;
+    Scope *CurScope = getCurScope();
+
+    // Commonly used variables.
+    DeclRefExpr *TileIV = buildDeclRefExpr(*this, TileIndVars[I], CntTy,
+                                           OrigCntVar->getExprLoc());
+    DeclRefExpr *FloorIV = buildDeclRefExpr(*this, FloorIndVars[I], CntTy,
+                                            OrigCntVar->getExprLoc());
+
+    // For init-statement: auto .tile.iv = .floor.iv
+    AddInitializerToDecl(TileIndVars[I], DefaultLvalueConversion(FloorIV).get(),
+                         /*DirectInit=*/false);
+    Decl *CounterDecl = TileIndVars[I];
+    StmtResult InitStmt = new (Context)
+        DeclStmt(DeclGroupRef::Create(Context, &CounterDecl, 1),
+                 OrigCntVar->getBeginLoc(), OrigCntVar->getEndLoc());
+    if (!InitStmt.isUsable())
+      return StmtError();
+
+    // For cond-expression: .tile.iv < min(.floor.iv + DimTileSize,
+    // NumIterations)
+    ExprResult EndOfTile = BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(),
+                                      BO_Add, FloorIV, DimTileSize);
+    if (!EndOfTile.isUsable())
+      return StmtError();
+    ExprResult IsPartialTile =
+        BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), BO_LT,
+                   NumIterations, EndOfTile.get());
+    if (!IsPartialTile.isUsable())
+      return StmtError();
+    ExprResult MinTileAndIterSpace = ActOnConditionalOp(
+        LoopHelper.Cond->getBeginLoc(), LoopHelper.Cond->getEndLoc(),
+        IsPartialTile.get(), NumIterations, EndOfTile.get());
+    if (!MinTileAndIterSpace.isUsable())
+      return StmtError();
+    ExprResult CondExpr = BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(),
+                                     BO_LT, TileIV, MinTileAndIterSpace.get());
+    if (!CondExpr.isUsable())
+      return StmtError();
+
+    // For incr-statement: ++.tile.iv
+    ExprResult IncrStmt =
+        BuildUnaryOp(CurScope, LoopHelper.Inc->getExprLoc(), UO_PreInc, TileIV);
+    if (!IncrStmt.isUsable())
+      return StmtError();
+
+    // Statements to set the original iteration variable's value from the
+    // logical iteration number.
+    // Generated for loop is:
+    // Original_for_init;
+    // for (auto .tile.iv = .floor.iv; .tile.iv < min(.floor.iv + DimTileSize,
+    // NumIterations); ++.tile.iv) {
+    //   Original_Body;
+    //   Original_counter_update;
+    // }
+    // FIXME: If the innermost body is an loop itself, inserting these
+    // statements stops it being recognized as a perfectly nested loop (e.g.
+    // for applying tiling again). If this is the case, sink the expressions
+    // further into the inner loop.
+    SmallVector<Stmt *, 4> BodyParts;
+    BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
+    BodyParts.push_back(Inner);
+    Inner = CompoundStmt::Create(Context, BodyParts, Inner->getBeginLoc(),
+                                 Inner->getEndLoc());
+    Inner = new (Context)
+        ForStmt(Context, InitStmt.get(), CondExpr.get(), nullptr,
+                IncrStmt.get(), Inner, LoopHelper.Init->getBeginLoc(),
+                LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
+  }
+
+  // Create floor loops from the inside to the outside.
+  for (int I = NumLoops - 1; I >= 0; --I) {
+    auto &LoopHelper = LoopHelpers[I];
+    Expr *NumIterations = LoopHelper.NumIterations;
+    DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
+    QualType CntTy = OrigCntVar->getType();
+    Expr *DimTileSize = FactorExpr;
+    Scope *CurScope = getCurScope();
+
+    // Commonly used variables.
+    DeclRefExpr *FloorIV = buildDeclRefExpr(*this, FloorIndVars[I], CntTy,
+                                            OrigCntVar->getExprLoc());
+
+    // For init-statement: auto .floor.iv = 0
+    AddInitializerToDecl(
+        FloorIndVars[I],
+        ActOnIntegerConstant(LoopHelper.Init->getExprLoc(), 0).get(),
+        /*DirectInit=*/false);
+    Decl *CounterDecl = FloorIndVars[I];
+    StmtResult InitStmt = new (Context)
+        DeclStmt(DeclGroupRef::Create(Context, &CounterDecl, 1),
+                 OrigCntVar->getBeginLoc(), OrigCntVar->getEndLoc());
+    if (!InitStmt.isUsable())
+      return StmtError();
+
+    // For cond-expression: .floor.iv < NumIterations
+    ExprResult CondExpr = BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(),
+                                     BO_LT, FloorIV, NumIterations);
+    if (!CondExpr.isUsable())
+      return StmtError();
+
+    // For incr-statement: .floor.iv += DimTileSize
+    ExprResult IncrStmt = BuildBinOp(CurScope, LoopHelper.Inc->getExprLoc(),
+                                     BO_AddAssign, FloorIV, DimTileSize);
+    if (!IncrStmt.isUsable())
+      return StmtError();
+
+    Inner = new (Context)
+        ForStmt(Context, InitStmt.get(), CondExpr.get(), nullptr,
+                IncrStmt.get(), Inner, LoopHelper.Init->getBeginLoc(),
+                LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
+  }
+
+  return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+                                    Inner, buildPreInits(Context, PreInits));
+}
+
 OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,
                                              SourceLocation StartLoc,
                                              SourceLocation LParenLoc,
@@ -12717,6 +12971,9 @@
   case OMPC_detach:
     Res = ActOnOpenMPDetachClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_partial:
+    Res = ActOnOpenMPPartialClause(Expr, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_device:
   case OMPC_if:
   case OMPC_default:
@@ -12919,6 +13176,7 @@
     case OMPD_end_declare_target:
     case OMPD_teams:
     case OMPD_tile:
+    case OMPD_unroll:
     case OMPD_for:
     case OMPD_sections:
     case OMPD_section:
@@ -12996,6 +13254,7 @@
     case OMPD_teams:
     case OMPD_simd:
     case OMPD_tile:
+    case OMPD_unroll:
     case OMPD_for:
     case OMPD_for_simd:
     case OMPD_sections:
@@
-13076,6 +13335,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13154,6 +13414,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13233,6 +13494,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_sections: case OMPD_section: case OMPD_single: @@ -13311,6 +13573,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13388,6 +13651,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13468,6 +13732,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -14139,6 +14404,30 @@ SizeExprs); } +OMPClause *Sema::ActOnOpenMPFullClause(SourceLocation StartLoc, + SourceLocation EndLoc) { + return OMPFullClause::Create(Context, StartLoc, EndLoc); +} + +OMPClause *Sema::ActOnOpenMPPartialClause(Expr *FactorExpr, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + // A null FactorExpr means the optional '(factor)' argument was omitted. + if (FactorExpr) { + // The factor comes from user code and may be invalid (e.g. zero or not a + // constant); return no clause instead of asserting on the failed check. + ExprResult FactorResult = VerifyPositiveIntegerConstantInClause( + FactorExpr, OMPC_partial, /*StrictlyPositive=*/true); + if (FactorResult.isInvalid()) + return nullptr; + FactorExpr = FactorResult.get(); + } + + return OMPPartialClause::Create(Context, StartLoc, LParenLoc, EndLoc, + FactorExpr); +} + OMPClause *Sema::ActOnOpenMPSingleExprWithArgClause( OpenMPClauseKind Kind, ArrayRef<unsigned> Argument, Expr *Expr, SourceLocation StartLoc, SourceLocation LParenLoc, @@ -14437,6 +14726,12 @@ case OMPC_destroy: Res = ActOnOpenMPDestroyClause(StartLoc, EndLoc); break; + case OMPC_full: + Res = ActOnOpenMPFullClause(StartLoc, EndLoc); + break; + case OMPC_partial: + Res = ActOnOpenMPPartialClause(nullptr, StartLoc, {},
EndLoc); + break; case OMPC_if: case OMPC_final: case OMPC_num_threads: diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -1633,6 +1633,20 @@ return getSema().ActOnOpenMPSizesClause(Sizes, StartLoc, LParenLoc, EndLoc); } + /// Build a new OpenMP 'full' clause. + OMPClause *RebuildOMPFullClause(SourceLocation StartLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPFullClause(StartLoc, EndLoc); + } + + /// Build a new OpenMP 'partial' clause. + OMPClause *RebuildOMPPartialClause(Expr *Factor, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPPartialClause(Factor, StartLoc, LParenLoc, + EndLoc); + } + /// Build a new OpenMP 'allocator' clause. /// /// By default, performs semantic analysis to build the new OpenMP clause. @@ -8441,6 +8455,17 @@ return Res; } +template <typename Derived> +StmtResult +TreeTransform<Derived>::TransformOMPUnrollDirective(OMPUnrollDirective *D) { + DeclarationNameInfo DirName; + getDerived().getSema().StartOpenMPDSABlock(D->getDirectiveKind(), DirName, + nullptr, D->getBeginLoc()); + StmtResult Res = getDerived().TransformOMPExecutableDirective(D); + getDerived().getSema().EndOpenMPDSABlock(Res.get()); + return Res; +} + template <typename Derived> StmtResult TreeTransform<Derived>::TransformOMPForDirective(OMPForDirective *D) { @@ -9108,6 +9133,29 @@ C->getLParenLoc(), C->getEndLoc()); } +template <typename Derived> +OMPClause *TreeTransform<Derived>::TransformOMPFullClause(OMPFullClause *C) { + if (!getDerived().AlwaysRebuild()) + return C; + return getDerived().RebuildOMPFullClause(C->getBeginLoc(), C->getEndLoc()); +} + +template <typename Derived> +OMPClause * +TreeTransform<Derived>::TransformOMPPartialClause(OMPPartialClause *C) { + ExprResult T = getDerived().TransformExpr(C->getFactor()); + if (T.isInvalid()) + return nullptr; + Expr *Factor = T.get(); + bool Changed = Factor != C->getFactor(); + + if (!Changed && !getDerived().AlwaysRebuild()) + return C; + return getDerived().RebuildOMPPartialClause(Factor, C->getBeginLoc(), + C->getLParenLoc(), + C->getEndLoc()); +} +
template <typename Derived> OMPClause * TreeTransform<Derived>::TransformOMPCollapseClause(OMPCollapseClause *C) { diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -11745,6 +11745,12 @@ C = OMPSizesClause::CreateEmpty(Context, NumSizes); break; } + case llvm::omp::OMPC_full: + C = OMPFullClause::CreateEmpty(Context); + break; + case llvm::omp::OMPC_partial: + C = OMPPartialClause::CreateEmpty(Context); + break; case llvm::omp::OMPC_allocator: C = new (Context) OMPAllocatorClause(); break; @@ -12042,6 +12048,13 @@ C->setLParenLoc(Record.readSourceLocation()); } +void OMPClauseReader::VisitOMPFullClause(OMPFullClause *C) {} + +void OMPClauseReader::VisitOMPPartialClause(OMPPartialClause *C) { + C->setFactor(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); +} + void OMPClauseReader::VisitOMPAllocatorClause(OMPAllocatorClause *C) { C->setAllocator(Record.readExpr()); C->setLParenLoc(Record.readSourceLocation()); diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2287,7 +2287,7 @@ void ASTStmtReader::VisitOMPLoopBasedDirective(OMPLoopBasedDirective *D) { VisitStmt(D); - // Field CollapsedNum was read in ReadStmtFromStream. + // Field NumAssociatedLoops was read in ReadStmtFromStream.
Record.skipInts(1); VisitOMPExecutableDirective(D); } @@ -2310,6 +2310,10 @@ VisitOMPLoopBasedDirective(D); } +void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) { + VisitOMPLoopBasedDirective(D); +} + void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) { VisitOMPLoopDirective(D); D->setHasCancel(Record.readBool()); @@ -3170,6 +3174,14 @@ break; } + case STMT_OMP_UNROLL_DIRECTIVE: { + assert(Record[ASTStmtReader::NumStmtFields] == 1 && + "Unroll directive accepts only a single loop"); + unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1]; + S = OMPUnrollDirective::CreateEmpty(Context, NumClauses); + break; + } + case STMT_OMP_FOR_DIRECTIVE: { unsigned CollapsedNum = Record[ASTStmtReader::NumStmtFields]; unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1]; diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -6128,6 +6128,13 @@ Record.AddSourceLocation(C->getLParenLoc()); } +void OMPClauseWriter::VisitOMPFullClause(OMPFullClause *C) {} + +void OMPClauseWriter::VisitOMPPartialClause(OMPPartialClause *C) { + Record.AddStmt(C->getFactor()); + Record.AddSourceLocation(C->getLParenLoc()); +} + void OMPClauseWriter::VisitOMPAllocatorClause(OMPAllocatorClause *C) { Record.AddStmt(C->getAllocator()); Record.AddSourceLocation(C->getLParenLoc()); diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2210,6 +2210,11 @@ Code = serialization::STMT_OMP_TILE_DIRECTIVE; } +void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) { + VisitOMPLoopBasedDirective(D); + Code = serialization::STMT_OMP_UNROLL_DIRECTIVE; +} + void ASTStmtWriter::VisitOMPForDirective(OMPForDirective *D) { VisitOMPLoopDirective(D); Record.writeBool(D->hasCancel()); diff --git
a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp --- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -1294,6 +1294,7 @@ case Stmt::OMPTargetTeamsDistributeParallelForSimdDirectiveClass: case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass: case Stmt::OMPTileDirectiveClass: + case Stmt::OMPUnrollDirectiveClass: case Stmt::CapturedStmtClass: { const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState()); Engine.addAbortedBlock(node, currBldrCtx->getBlock()); diff --git a/clang/test/OpenMP/unroll_ast_print.cpp b/clang/test/OpenMP/unroll_ast_print.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/unroll_ast_print.cpp @@ -0,0 +1,107 @@ +// Check no warnings/errors +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -fsyntax-only -verify %s +// expected-no-diagnostics + +// Check AST and unparsing +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -ast-dump %s | FileCheck %s --check-prefix=DUMP +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -ast-print %s | FileCheck %s --check-prefix=PRINT --match-full-lines + +// Check same results after serialization round-trip +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -emit-pch -o %t %s +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -include-pch %t -ast-dump-all %s | FileCheck %s --check-prefix=DUMP +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -include-pch %t -ast-print %s | FileCheck %s --check-prefix=PRINT --match-full-lines + +#ifndef HEADER +#define HEADER + +// placeholder for loop body code. 
+extern "C" void body(...); + + + +// PRINT-LABEL: void func_unroll() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll +void func_unroll() { + // PRINT: #pragma omp unroll + // DUMP: OMPUnrollDirective + #pragma omp unroll + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_full() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_full +void func_unroll_full() { + // PRINT: #pragma omp unroll full + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPFullClause + #pragma omp unroll full + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_partial() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_partial +void func_unroll_partial() { + // PRINT: #pragma omp unroll partial + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + // DUMP-NEXT: <<<NULL>>> + #pragma omp unroll partial + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_partial_factor() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_partial_factor +void func_unroll_partial_factor() { + // PRINT: #pragma omp unroll partial(4) + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + // DUMP-NEXT: ConstantExpr + // DUMP-NEXT: value: Int 4 + // DUMP-NEXT: IntegerLiteral {{.*}} 4 + #pragma omp unroll partial(4) + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_partial_factor_for() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_partial_factor_for +void func_unroll_partial_factor_for() { + // PRINT:
#pragma omp for + // DUMP: OMPForDirective + #pragma omp for + // PRINT: #pragma omp unroll partial(2) + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + #pragma omp unroll partial(2) + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + +#endif diff --git a/clang/test/OpenMP/unroll_codegen.cpp b/clang/test/OpenMP/unroll_codegen.cpp new file mode 100644 --- /dev/null +++ b/clang/test/OpenMP/unroll_codegen.cpp @@ -0,0 +1,48 @@ +// Check code generation +// RUN: %clang_cc1 -verify -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -emit-llvm %s -o - | FileCheck %s --check-prefix=IR + +// Check same results after serialization round-trip +// RUN: %clang_cc1 -verify -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -emit-pch -o %t %s +// RUN: %clang_cc1 -verify -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -include-pch %t -emit-llvm %s -o - | FileCheck %s --check-prefix=IR +// expected-no-diagnostics + +#ifndef HEADER +#define HEADER + +// placeholder for loop body code. +extern "C" void body(...) 
{} + +#if 0 +void func_unroll(int n) { + #pragma omp unroll + for (int i = 7; i < n; i += 3) + body(i); +} + + + +void func_unroll_full() { + #pragma omp unroll full + for (int i = 7; i < 17; i += 3) + body(i); +} +#endif + +void func_unroll_partial() { + #pragma omp unroll partial + for (int i = 7; i < 789; i += 3) + body(i); +} +#if 0 + +void func_unroll_partial_factor() { + #pragma omp unroll partial(4) + for (int i = 7; i < 789; i += 3) + body(i); +} +#endif + + + + +#endif /* HEADER */ diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2045,6 +2045,7 @@ void VisitOMPParallelDirective(const OMPParallelDirective *D); void VisitOMPSimdDirective(const OMPSimdDirective *D); void VisitOMPTileDirective(const OMPTileDirective *D); + void VisitOMPUnrollDirective(const OMPUnrollDirective *D); void VisitOMPForDirective(const OMPForDirective *D); void VisitOMPForSimdDirective(const OMPForSimdDirective *D); void VisitOMPSectionsDirective(const OMPSectionsDirective *D); @@ -2219,10 +2220,16 @@ } void OMPClauseEnqueue::VisitOMPSizesClause(const OMPSizesClause *C) { - for (auto E : C->getSizesRefs()) + for (Expr *E : C->getSizesRefs()) Visitor->AddStmt(E); } +void OMPClauseEnqueue::VisitOMPFullClause(const OMPFullClause *C) {} + +void OMPClauseEnqueue::VisitOMPPartialClause(const OMPPartialClause *C) { + Visitor->AddStmt(C->getFactor()); +} + void OMPClauseEnqueue::VisitOMPAllocatorClause(const OMPAllocatorClause *C) { Visitor->AddStmt(C->getAllocator()); } @@ -2872,6 +2879,10 @@ VisitOMPLoopBasedDirective(D); } +void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) { + VisitOMPLoopBasedDirective(D); +} + void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) { VisitOMPLoopDirective(D); } @@ -5550,6 +5561,8 @@ return cxstring::createRef("OMPSimdDirective"); case CXCursor_OMPTileDirective: return
cxstring::createRef("OMPTileDirective"); + case CXCursor_OMPUnrollDirective: + return cxstring::createRef("OMPUnrollDirective"); case CXCursor_OMPForDirective: return cxstring::createRef("OMPForDirective"); case CXCursor_OMPForSimdDirective: diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp --- a/clang/tools/libclang/CXCursor.cpp +++ b/clang/tools/libclang/CXCursor.cpp @@ -651,6 +651,9 @@ case Stmt::OMPTileDirectiveClass: K = CXCursor_OMPTileDirective; break; + case Stmt::OMPUnrollDirectiveClass: + K = CXCursor_OMPUnrollDirective; + break; case Stmt::OMPForDirectiveClass: K = CXCursor_OMPForDirective; break; diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -69,6 +69,8 @@ let flangClass = "OmpObjectList"; } def OMPC_Sizes: Clause<"sizes"> { let clangClass = "OMPSizesClause"; } +def OMPC_Full: Clause<"full"> { let clangClass = "OMPFullClause"; } +def OMPC_Partial: Clause<"partial"> { let clangClass = "OMPPartialClause"; } def OMPC_FirstPrivate : Clause<"firstprivate"> { let clangClass = "OMPFirstprivateClause"; let flangClass = "OmpObjectList"; @@ -381,6 +383,12 @@ VersionedClause<OMPC_Sizes, 51>, ]; } +def OMP_Unroll : Directive<"unroll"> { + let allowedOnceClauses = [ + VersionedClause<OMPC_Full, 51>, + VersionedClause<OMPC_Partial, 51>, + ]; +} def OMP_For : Directive<"for"> { let allowedClauses = [ VersionedClause<OMPC_Private>,