diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -2265,6 +2265,48 @@
   }
 };
 
+/// This is a dummy clause that represents the presence of both 'compare' and
+/// 'capture' clauses in the '#pragma omp atomic' directive.
+///
+/// \code
+/// #pragma omp atomic compare capture
+/// \endcode
+/// In this example directive '#pragma omp atomic' has 'compare' and 'capture'
+/// clauses.
+class OMPCompareCaptureClause : public OMPClause {
+public:
+  /// Build dummy 'compare capture' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPCompareCaptureClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_compare_capture, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPCompareCaptureClause()
+      : OMPClause(llvm::omp::OMPC_compare_capture, SourceLocation(),
+                  SourceLocation()) {}
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_compare_capture;
+  }
+};
+
 /// This represents 'seq_cst' clause in the '#pragma omp atomic'
 /// directive.
 ///
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3239,6 +3239,12 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCompareCaptureClause(
+    OMPCompareCaptureClause *) {
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11193,6 +11193,9 @@
   /// Called on well-formed 'compare' clause.
   OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'compare' and 'capture' clauses.
+  OMPClause *ActOnOpenMPCompareCaptureClause(SourceLocation StartLoc,
+                                             SourceLocation EndLoc);
   /// Called on well-formed 'seq_cst' clause.
   OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                      SourceLocation EndLoc);
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -127,6 +127,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -219,6 +220,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -1798,6 +1800,10 @@
   OS << "compare";
 }
 
+void OMPClausePrinter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  // Do nothing as it is a dummy clause.
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -553,6 +553,9 @@
 
 void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseProfiler::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -164,6 +164,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -430,6 +431,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/PrettyStackTrace.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@@ -5970,6 +5971,9 @@
   case OMPC_compare:
     // Do nothing here as we already emit an error.
     break;
+  case OMPC_compare_capture:
+    // Do nothing here as we already emit an error.
+    break;
   case OMPC_if:
   case OMPC_final:
   case OMPC_num_threads:
@@ -6079,19 +6083,21 @@
     AO = llvm::AtomicOrdering::Monotonic;
     MemOrderingSpecified = true;
   }
+  llvm::SmallSet<OpenMPClauseKind, 2> KindsEncountered;
   OpenMPClauseKind Kind = OMPC_unknown;
   for (const OMPClause *C : S.clauses()) {
-    // Find first clause (skip seq_cst|acq_rel|aqcuire|release|relaxed clause,
-    // if it is first).
-    if (C->getClauseKind() != OMPC_seq_cst &&
-        C->getClauseKind() != OMPC_acq_rel &&
-        C->getClauseKind() != OMPC_acquire &&
-        C->getClauseKind() != OMPC_release &&
-        C->getClauseKind() != OMPC_relaxed && C->getClauseKind() != OMPC_hint) {
-      Kind = C->getClauseKind();
-      break;
-    }
+    // Record every atomic-kind clause (skip seq_cst|acq_rel|acquire|release|
+    // relaxed|hint clauses); the last one encountered becomes Kind.
+    OpenMPClauseKind K = C->getClauseKind();
+    if (K == OMPC_seq_cst || K == OMPC_acq_rel || K == OMPC_acquire ||
+        K == OMPC_release || K == OMPC_relaxed || K == OMPC_hint)
+      continue;
+    Kind = K;
+    KindsEncountered.insert(K);
   }
+  if (KindsEncountered.contains(OMPC_compare) &&
+      KindsEncountered.contains(OMPC_capture))
+    Kind = OMPC_compare_capture;
   if (!MemOrderingSpecified) {
     llvm::AtomicOrdering DefaultOrder =
         CGM.getOpenMPRuntime().getDefaultMemoryOrdering();
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3193,6 +3193,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -35,6 +35,7 @@
 #include "llvm/ADT/IndexedMap.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include <set>
@@ -6355,6 +6356,7 @@
       case OMPC_update:
       case OMPC_capture:
       case OMPC_compare:
+      case OMPC_compare_capture:
       case OMPC_seq_cst:
       case OMPC_acq_rel:
       case OMPC_acquire:
@@ -10935,14 +10937,18 @@
   SourceLocation AtomicKindLoc;
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
+  bool MutexClauseEncountered = false;
+  llvm::SmallSet<OpenMPClauseKind, 2> EncounteredAtomicKinds;
   for (const OMPClause *C : Clauses) {
     switch (C->getClauseKind()) {
     case OMPC_read:
     case OMPC_write:
     case OMPC_update:
+      MutexClauseEncountered = true;
+      LLVM_FALLTHROUGH;
     case OMPC_capture:
     case OMPC_compare: {
-      if (AtomicKind != OMPC_unknown) {
+      if (AtomicKind != OMPC_unknown && MutexClauseEncountered) {
         Diag(C->getBeginLoc(), diag::err_omp_atomic_several_clauses)
             << SourceRange(C->getBeginLoc(), C->getEndLoc());
         Diag(AtomicKindLoc, diag::note_omp_previous_mem_order_clause)
@@ -10950,6 +10956,7 @@
       } else {
         AtomicKind = C->getClauseKind();
         AtomicKindLoc = C->getBeginLoc();
+        EncounteredAtomicKinds.insert(C->getClauseKind());
       }
       break;
     }
@@ -10973,10 +10980,15 @@
     // The following clauses are allowed, but we don't need to do anything here.
     case OMPC_hint:
       break;
+    case OMPC_compare_capture:
+      llvm_unreachable("OMPC_compare_capture should never be created");
     default:
       llvm_unreachable("unknown clause is encountered");
     }
   }
+  if (EncounteredAtomicKinds.contains(OMPC_compare) &&
+      EncounteredAtomicKinds.contains(OMPC_capture))
+    AtomicKind = OMPC_compare_capture;
   // OpenMP 5.0, 2.17.7 atomic Construct, Restrictions
   // If atomic-clause is read then memory-order-clause must not be acq_rel or
   // release.
@@ -11400,6 +11412,13 @@
     unsigned DiagID = Diags.getCustomDiagID(
         DiagnosticsEngine::Error, "atomic compare is not supported for now");
     Diag(AtomicKindLoc, DiagID);
+  } else if (AtomicKind == OMPC_compare_capture) {
+    // TODO: For now we emit an error here and in emitOMPAtomicExpr we ignore
+    // code gen.
+    unsigned DiagID = Diags.getCustomDiagID(
+        DiagnosticsEngine::Error,
+        "atomic compare capture is not supported for now");
+    Diag(AtomicKindLoc, DiagID);
   }
 
   setFunctionHasBranchProtectedScope();
@@ -13481,6 +13500,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14313,6 +14333,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14775,6 +14796,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15081,6 +15103,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15272,6 +15295,9 @@
   case OMPC_compare:
     Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
     break;
+  case OMPC_compare_capture:
+    Res = ActOnOpenMPCompareCaptureClause(StartLoc, EndLoc);
+    break;
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -15423,6 +15449,11 @@
   return new (Context) OMPCompareClause(StartLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPCompareCaptureClause(SourceLocation StartLoc,
+                                                 SourceLocation EndLoc) {
+  return new (Context) OMPCompareCaptureClause(StartLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                          SourceLocation EndLoc) {
   return new (Context) OMPSeqCstClause(StartLoc, EndLoc);
@@ -15892,6 +15923,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -9467,6 +9467,13 @@
   return C;
 }
 
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPCompareCaptureClause(
+    OMPCompareCaptureClause *C) {
+  // No need to rebuild this clause, no template-dependent parameters.
+  return C;
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11768,6 +11768,9 @@
   case llvm::omp::OMPC_compare:
     C = new (Context) OMPCompareClause();
     break;
+  case llvm::omp::OMPC_compare_capture:
+    C = new (Context) OMPCompareCaptureClause();
+    break;
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -12128,6 +12131,8 @@
 
 void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseReader::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6254,6 +6254,8 @@
 
 void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseWriter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -949,4 +949,15 @@
       a = c;
   }
 }
+
+int compare_capture() {
+  int a, b, c, x;
+// omp51-error@+1 {{atomic compare capture is not supported for now}}
+#pragma omp atomic compare capture
+  {
+    x = a;
+    if (a == b)
+      a = c;
+  }
+}
 #endif
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2277,6 +2277,9 @@
 
 void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseEnqueue::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -181,6 +181,10 @@
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
 def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
+// A dummy clause used when both compare and capture clauses are present.
+def OMPC_CompareCapture : Clause<"compare_capture"> {
+  let clangClass = "OMPCompareCaptureClause";
+}
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }