diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -2265,6 +2265,48 @@
   }
 };
 
+/// This is a dummy clause that represents that both the 'compare' and the
+/// 'capture' clauses are present in the '#pragma omp atomic' directive.
+///
+/// \code
+/// #pragma omp atomic compare capture
+/// \endcode
+/// In this example directive '#pragma omp atomic' has 'compare' and 'capture'
+/// clauses.
+class OMPCompareCaptureClause : public OMPClause {
+public:
+  /// Build 'compare_capture' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPCompareCaptureClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_compare_capture, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPCompareCaptureClause()
+      : OMPClause(llvm::omp::OMPC_compare_capture, SourceLocation(),
+                  SourceLocation()) {}
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_compare_capture;
+  }
+};
+
 /// This represents 'seq_cst' clause in the '#pragma omp atomic'
 /// directive.
 ///
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3239,6 +3239,12 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCompareCaptureClause(
+    OMPCompareCaptureClause *) {
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11193,6 +11193,9 @@
   /// Called on well-formed 'compare' clause.
   OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'compare' and 'capture' clauses.
+  OMPClause *ActOnOpenMPCompareCaptureClause(SourceLocation StartLoc,
+                                             SourceLocation EndLoc);
   /// Called on well-formed 'seq_cst' clause.
   OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                      SourceLocation EndLoc);
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -127,6 +127,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -219,6 +220,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -1798,6 +1800,10 @@
   OS << "compare";
 }
 
+void OMPClausePrinter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  // Nothing to print: this is a dummy clause.
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -553,6 +553,9 @@
 
 void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseProfiler::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -164,6 +164,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -430,6 +431,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -5970,6 +5970,9 @@
   case OMPC_compare:
     // Do nothing here as we already emit an error.
     break;
+  case OMPC_compare_capture:
+    // Do nothing here as we already emit an error.
+    break;
   case OMPC_if:
   case OMPC_final:
   case OMPC_num_threads:
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3193,6 +3193,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -35,6 +35,7 @@
 #include "llvm/ADT/IndexedMap.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include <set>
@@ -6355,6 +6356,7 @@
       case OMPC_update:
       case OMPC_capture:
       case OMPC_compare:
+      case OMPC_compare_capture:
       case OMPC_seq_cst:
       case OMPC_acq_rel:
       case OMPC_acquire:
@@ -10935,14 +10937,20 @@
   SourceLocation AtomicKindLoc;
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
+  // 'read', 'write' and 'update' cannot be combined with any other atomic
+  // kind, while 'compare' and 'capture' may be combined with each other.
+  bool MutexClauseEncountered = false;
+  llvm::SmallSet<OpenMPClauseKind, 2> EncounteredAtomicKinds;
   for (const OMPClause *C : Clauses) {
     switch (C->getClauseKind()) {
     case OMPC_read:
     case OMPC_write:
     case OMPC_update:
+      MutexClauseEncountered = true;
+      LLVM_FALLTHROUGH;
     case OMPC_capture:
     case OMPC_compare: {
-      if (AtomicKind != OMPC_unknown) {
+      if (AtomicKind != OMPC_unknown && MutexClauseEncountered) {
         Diag(C->getBeginLoc(), diag::err_omp_atomic_several_clauses)
             << SourceRange(C->getBeginLoc(), C->getEndLoc());
         Diag(AtomicKindLoc, diag::note_omp_previous_mem_order_clause)
@@ -10950,6 +10958,7 @@
       } else {
         AtomicKind = C->getClauseKind();
         AtomicKindLoc = C->getBeginLoc();
+        EncounteredAtomicKinds.insert(C->getClauseKind());
       }
       break;
     }
@@ -10973,10 +10982,16 @@
     // The following clauses are allowed, but we don't need to do anything here.
     case OMPC_hint:
       break;
+    case OMPC_compare_capture:
+      llvm_unreachable("OMPC_compare_capture should never be created");
     default:
       llvm_unreachable("unknown clause is encountered");
     }
   }
+  // 'compare' together with 'capture' is treated as one combined atomic kind.
+  if (EncounteredAtomicKinds.contains(OMPC_compare) &&
+      EncounteredAtomicKinds.contains(OMPC_capture))
+    AtomicKind = OMPC_compare_capture;
   // OpenMP 5.0, 2.17.7 atomic Construct, Restrictions
   // If atomic-clause is read then memory-order-clause must not be acq_rel or
   // release.
@@ -13481,6 +13496,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14313,6 +14329,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14775,6 +14792,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15081,6 +15099,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15272,6 +15291,9 @@
   case OMPC_compare:
     Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
     break;
+  case OMPC_compare_capture:
+    Res = ActOnOpenMPCompareCaptureClause(StartLoc, EndLoc);
+    break;
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -15423,6 +15445,11 @@
   return new (Context) OMPCompareClause(StartLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPCompareCaptureClause(SourceLocation StartLoc,
+                                                 SourceLocation EndLoc) {
+  return new (Context) OMPCompareCaptureClause(StartLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                          SourceLocation EndLoc) {
   return new (Context) OMPSeqCstClause(StartLoc, EndLoc);
@@ -15892,6 +15919,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -9467,6 +9467,13 @@
   return C;
 }
 
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPCompareCaptureClause(
+    OMPCompareCaptureClause *C) {
+  // No need to rebuild this clause, no template-dependent parameters.
+  return C;
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11768,6 +11768,9 @@
   case llvm::omp::OMPC_compare:
     C = new (Context) OMPCompareClause();
     break;
+  case llvm::omp::OMPC_compare_capture:
+    C = new (Context) OMPCompareCaptureClause();
+    break;
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -12128,6 +12131,8 @@
 
 void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseReader::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6254,6 +6254,8 @@
 
 void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseWriter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2277,6 +2277,9 @@
 
 void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseEnqueue::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -181,6 +181,10 @@
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
 def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
+// A dummy clause used when both 'compare' and 'capture' clauses are present.
+def OMPC_CompareCapture : Clause<"compare_capture"> {
+  let clangClass = "OMPCompareCaptureClause";
+}
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }