diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -889,22 +889,23 @@ /// Calls the specified callback function for all the loops in \p CurStmt, /// from the outermost to the innermost. - static bool doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, - unsigned NumLoops, - llvm::function_ref Callback, - llvm::function_ref - OnTransformationCallback); + static bool + doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, + unsigned NumLoops, + llvm::function_ref Callback, + llvm::function_ref + OnTransformationCallback); static bool doForAllLoops(const Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref Callback, - llvm::function_ref + llvm::function_ref OnTransformationCallback) { auto &&NewCallback = [Callback](unsigned Cnt, Stmt *CurStmt) { return Callback(Cnt, CurStmt); }; auto &&NewTransformCb = - [OnTransformationCallback](OMPLoopBasedDirective *A) { + [OnTransformationCallback](OMPLoopTransformationDirective *A) { OnTransformationCallback(A); }; return doForAllLoops(const_cast(CurStmt), TryImperfectlyNestedLoops, @@ -917,7 +918,7 @@ doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref Callback) { - auto &&TransformCb = [](OMPLoopBasedDirective *) {}; + auto &&TransformCb = [](OMPLoopTransformationDirective *) {}; return doForAllLoops(CurStmt, TryImperfectlyNestedLoops, NumLoops, Callback, TransformCb); } @@ -954,6 +955,38 @@ } }; +/// The base class for all loop transformation directives. 
+class OMPLoopTransformationDirective : public OMPLoopBasedDirective { + friend class ASTStmtReader; + +protected: + explicit OMPLoopTransformationDirective(StmtClass SC, + OpenMPDirectiveKind Kind, + SourceLocation StartLoc, + SourceLocation EndLoc, + unsigned NumAssociatedLoops) + : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} + +public: + /// Return the number of associated (consumed) loops. + unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } + + /// Get the de-sugared statements after the loop transformation. + /// + /// Might be nullptr if either the directive generates no loops and is handled + /// directly in CodeGen, or resolving a template-dependence context is + /// required. + Stmt *getTransformedStmt() const; + + /// Return preinits statement. + Stmt *getPreInits() const; + + static bool classof(const Stmt *T) { + return T->getStmtClass() == OMPTileDirectiveClass || + T->getStmtClass() == OMPUnrollDirectiveClass; + } +}; + /// This is a common base class for loop directives ('omp simd', 'omp /// for', 'omp for simd' etc.). It is responsible for the loop code generation. /// @@ -5011,7 +5044,7 @@ }; /// This represents the '#pragma omp tile' loop transformation directive.
-class OMPTileDirective final : public OMPLoopBasedDirective { +class OMPTileDirective final : public OMPLoopTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5023,8 +5056,9 @@ explicit OMPTileDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopBasedDirective(OMPTileDirectiveClass, llvm::omp::OMPD_tile, - StartLoc, EndLoc, NumLoops) {} + : OMPLoopTransformationDirective(OMPTileDirectiveClass, + llvm::omp::OMPD_tile, StartLoc, EndLoc, + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5061,8 +5095,6 @@ static OMPTileDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned NumLoops); - unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } - /// Gets/sets the associated loops after tiling. /// /// This is in de-sugared format stored as a CompoundStmt. @@ -5092,7 +5124,7 @@ /// #pragma omp unroll /// for (int i = 0; i < 64; ++i) /// \endcode -class OMPUnrollDirective final : public OMPLoopBasedDirective { +class OMPUnrollDirective final : public OMPLoopTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5103,8 +5135,9 @@ }; explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc) - : OMPLoopBasedDirective(OMPUnrollDirectiveClass, llvm::omp::OMPD_unroll, - StartLoc, EndLoc, 1) {} + : OMPLoopTransformationDirective(OMPUnrollDirectiveClass, + llvm::omp::OMPD_unroll, StartLoc, EndLoc, + 1) {} /// Set the pre-init statements. 
void setPreInits(Stmt *PreInits) { diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -224,8 +224,9 @@ def OMPLoopDirective : StmtNode; def OMPParallelDirective : StmtNode; def OMPSimdDirective : StmtNode; -def OMPTileDirective : StmtNode; -def OMPUnrollDirective : StmtNode; +def OMPLoopTransformationDirective : StmtNode; +def OMPTileDirective : StmtNode; +def OMPUnrollDirective : StmtNode; def OMPForDirective : StmtNode; def OMPForSimdDirective : StmtNode; def OMPSectionsDirective : StmtNode; diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -125,28 +125,25 @@ bool OMPLoopBasedDirective::doForAllLoops( Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref Callback, - llvm::function_ref + llvm::function_ref OnTransformationCallback) { CurStmt = CurStmt->IgnoreContainers(); for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) { while (true) { - auto *OrigStmt = CurStmt; - if (auto *Dir = dyn_cast(OrigStmt)) { - OnTransformationCallback(Dir); - CurStmt = Dir->getTransformedStmt(); - } else if (auto *Dir = dyn_cast(OrigStmt)) { - OnTransformationCallback(Dir); - CurStmt = Dir->getTransformedStmt(); - } else { + auto *Dir = dyn_cast(CurStmt); + if (!Dir) break; - } - if (!CurStmt) { - // May happen if the loop transformation does not result in a generated - // loop (such as full unrolling). - CurStmt = OrigStmt; + OnTransformationCallback(Dir); + + Stmt *TransformedStmt = Dir->getTransformedStmt(); + if (!TransformedStmt) { + // May happen if the loop transformation does not result in a + // generated loop (such as full unrolling). 
break; } + + CurStmt = TransformedStmt; } if (auto *CanonLoop = dyn_cast(CurStmt)) CurStmt = CanonLoop->getLoopStmt(); @@ -363,6 +360,32 @@ return Dir; } +Stmt *OMPLoopTransformationDirective::getTransformedStmt() const { + switch (getStmtClass()) { +#define STMT(CLASS, PARENT) +#define ABSTRACT_STMT(CLASS) +#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ + case Stmt::CLASS##Class: \ + return static_cast(this)->getTransformedStmt(); +#include "clang/AST/StmtNodes.inc" + default: + llvm_unreachable("Not a loop transformation"); + } +} + +Stmt *OMPLoopTransformationDirective::getPreInits() const { + switch (getStmtClass()) { +#define STMT(CLASS, PARENT) +#define ABSTRACT_STMT(CLASS) +#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ + case Stmt::CLASS##Class: \ + return static_cast(this)->getPreInits(); +#include "clang/AST/StmtNodes.inc" + default: + llvm_unreachable("Not a loop transformation"); + } +} + OMPForDirective *OMPForDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum, diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -915,12 +915,17 @@ VisitOMPLoopDirective(S); } -void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) { +void StmtProfiler::VisitOMPLoopTransformationDirective( + const OMPLoopTransformationDirective *S) { VisitOMPLoopBasedDirective(S); } +void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) { + VisitOMPLoopTransformationDirective(S); +} + void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) { - VisitOMPLoopBasedDirective(S); + VisitOMPLoopTransformationDirective(S); } void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) { diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -1829,9 +1829,7 @@ return; } if (SimplifiedS 
== NextLoop) { - if (auto *Dir = dyn_cast(SimplifiedS)) - SimplifiedS = Dir->getTransformedStmt(); - if (auto *Dir = dyn_cast(SimplifiedS)) + if (auto *Dir = dyn_cast(SimplifiedS)) SimplifiedS = Dir->getTransformedStmt(); if (const auto *CanonLoop = dyn_cast(SimplifiedS)) SimplifiedS = CanonLoop->getLoopStmt(); diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -3823,13 +3823,8 @@ VisitSubCaptures(S); } - void VisitOMPTileDirective(OMPTileDirective *S) { - // #pragma omp tile does not introduce data sharing. - VisitStmt(S); - } - - void VisitOMPUnrollDirective(OMPUnrollDirective *S) { - // #pragma omp unroll does not introduce data sharing. + void VisitOMPLoopTransformationDirective(OMPLoopTransformationDirective *S) { + // Loop transformation directives do not introduce data sharing VisitStmt(S); } @@ -9050,15 +9045,8 @@ } return false; }, - [&SemaRef, &Captures](OMPLoopBasedDirective *Transform) { - Stmt *DependentPreInits; - if (auto *Dir = dyn_cast(Transform)) { - DependentPreInits = Dir->getPreInits(); - } else if (auto *Dir = dyn_cast(Transform)) { - DependentPreInits = Dir->getPreInits(); - } else { - llvm_unreachable("Unexpected loop transformation"); - } + [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) { + Stmt *DependentPreInits = Transform->getPreInits(); if (!DependentPreInits) return; for (Decl *C : cast(DependentPreInits)->getDeclGroup()) { diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2324,12 +2324,17 @@ VisitOMPLoopDirective(D); } -void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { +void ASTStmtReader::VisitOMPLoopTransformationDirective( + OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); } +void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective 
*D) { + VisitOMPLoopTransformationDirective(D); +} + void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) { - VisitOMPLoopBasedDirective(D); + VisitOMPLoopTransformationDirective(D); } void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2223,13 +2223,18 @@ Code = serialization::STMT_OMP_SIMD_DIRECTIVE; } -void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { +void ASTStmtWriter::VisitOMPLoopTransformationDirective( + OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); +} + +void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { + VisitOMPLoopTransformationDirective(D); Code = serialization::STMT_OMP_TILE_DIRECTIVE; } void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) { - VisitOMPLoopBasedDirective(D); + VisitOMPLoopTransformationDirective(D); Code = serialization::STMT_OMP_UNROLL_DIRECTIVE; } diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2046,6 +2046,8 @@ void VisitOMPLoopDirective(const OMPLoopDirective *D); void VisitOMPParallelDirective(const OMPParallelDirective *D); void VisitOMPSimdDirective(const OMPSimdDirective *D); + void + VisitOMPLoopTransformationDirective(const OMPLoopTransformationDirective *D); void VisitOMPTileDirective(const OMPTileDirective *D); void VisitOMPUnrollDirective(const OMPUnrollDirective *D); void VisitOMPForDirective(const OMPForDirective *D); @@ -2901,12 +2903,17 @@ VisitOMPLoopDirective(D); } -void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) { +void EnqueueVisitor::VisitOMPLoopTransformationDirective( + const OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); } +void EnqueueVisitor::VisitOMPTileDirective(const 
OMPTileDirective *D) { + VisitOMPLoopTransformationDirective(D); +} + void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) { - VisitOMPLoopBasedDirective(D); + VisitOMPLoopTransformationDirective(D); } void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) {