diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -2266,6 +2266,50 @@
   }
 };
 
+/// This is a dummy clause that represents that both 'compare' and 'capture'
+/// clauses are present in the '#pragma omp atomic' directive. It is never
+/// instantiated; only its clause kind, OMPC_compare_capture, is used to tag
+/// the combined 'compare capture' form of the directive.
+///
+/// \code
+/// #pragma omp atomic compare capture
+/// \endcode
+/// In this example directive '#pragma omp atomic' has 'compare' and 'capture'
+/// clauses.
+class OMPCompareCaptureClause final : public OMPClause {
+public:
+  /// Build 'compare capture' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPCompareCaptureClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_compare_capture, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPCompareCaptureClause()
+      : OMPClause(llvm::omp::OMPC_compare_capture, SourceLocation(),
+                  SourceLocation()) {}
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_compare_capture;
+  }
+};
+
 /// This represents 'seq_cst' clause in the '#pragma omp atomic'
 /// directive.
 ///
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3239,6 +3239,12 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCompareCaptureClause(
+    OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -127,6 +127,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -219,6 +220,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -1798,6 +1800,10 @@
   OS << "compare";
 }
 
+void OMPClausePrinter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -553,6 +553,11 @@
 
 void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseProfiler::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -168,6 +168,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -434,6 +435,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/PrettyStackTrace.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@@ -6038,6 +6039,9 @@
     CGF.CGM.getDiags().Report(DiagID);
     break;
   }
+  case OMPC_compare_capture:
+    // Do nothing here as we already emit an error.
+    break;
   case OMPC_if:
   case OMPC_final:
   case OMPC_num_threads:
@@ -6148,19 +6152,24 @@
     AO = llvm::AtomicOrdering::Monotonic;
     MemOrderingSpecified = true;
   }
+  llvm::SmallSet<OpenMPClauseKind, 2> KindsEncountered;
   OpenMPClauseKind Kind = OMPC_unknown;
   for (const OMPClause *C : S.clauses()) {
-    // Find first clause (skip seq_cst|acq_rel|aqcuire|release|relaxed clause,
-    // if it is first).
-    if (C->getClauseKind() != OMPC_seq_cst &&
-        C->getClauseKind() != OMPC_acq_rel &&
-        C->getClauseKind() != OMPC_acquire &&
-        C->getClauseKind() != OMPC_release &&
-        C->getClauseKind() != OMPC_relaxed && C->getClauseKind() != OMPC_hint) {
-      Kind = C->getClauseKind();
-      break;
-    }
+    // Skip seq_cst|acq_rel|acquire|release|relaxed|hint clauses, keep the
+    // first remaining clause as the atomic kind, and record all of them.
+    OpenMPClauseKind K = C->getClauseKind();
+    if (K == OMPC_seq_cst || K == OMPC_acq_rel || K == OMPC_acquire ||
+        K == OMPC_release || K == OMPC_relaxed || K == OMPC_hint)
+      continue;
+    if (Kind == OMPC_unknown)
+      Kind = K;
+    KindsEncountered.insert(K);
   }
+  // 'compare' together with 'capture' maps to the dummy 'compare_capture'
+  // kind, whose codegen is skipped because Sema already reported an error.
+  if (KindsEncountered.contains(OMPC_compare) &&
+      KindsEncountered.contains(OMPC_capture))
+    Kind = OMPC_compare_capture;
   if (!MemOrderingSpecified) {
     llvm::AtomicOrdering DefaultOrder =
         CGM.getOpenMPRuntime().getDefaultMemoryOrdering();
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3222,6 +3222,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -35,6 +35,7 @@
 #include "llvm/ADT/IndexedMap.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Frontend/OpenMP/OMPAssume.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -6365,6 +6366,7 @@
       case OMPC_update:
       case OMPC_capture:
       case OMPC_compare:
+      case OMPC_compare_capture:
       case OMPC_seq_cst:
       case OMPC_acq_rel:
       case OMPC_acquire:
@@ -11315,14 +11317,18 @@
   SourceLocation AtomicKindLoc;
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
+  bool MutexClauseEncountered = false;
+  llvm::SmallSet<OpenMPClauseKind, 2> EncounteredAtomicKinds;
   for (const OMPClause *C : Clauses) {
     switch (C->getClauseKind()) {
     case OMPC_read:
     case OMPC_write:
     case OMPC_update:
+      MutexClauseEncountered = true;
+      LLVM_FALLTHROUGH;
     case OMPC_capture:
     case OMPC_compare: {
-      if (AtomicKind != OMPC_unknown) {
+      if (AtomicKind != OMPC_unknown && MutexClauseEncountered) {
         Diag(C->getBeginLoc(), diag::err_omp_atomic_several_clauses)
             << SourceRange(C->getBeginLoc(), C->getEndLoc());
         Diag(AtomicKindLoc, diag::note_omp_previous_mem_order_clause)
@@ -11330,6 +11336,7 @@
       } else {
         AtomicKind = C->getClauseKind();
         AtomicKindLoc = C->getBeginLoc();
+        EncounteredAtomicKinds.insert(C->getClauseKind());
       }
       break;
     }
@@ -11353,10 +11360,15 @@
     // The following clauses are allowed, but we don't need to do anything here.
     case OMPC_hint:
       break;
+    case OMPC_compare_capture:
+      llvm_unreachable("OMPC_compare_capture should never be reached");
     default:
       llvm_unreachable("unknown clause is encountered");
     }
   }
+  if (EncounteredAtomicKinds.contains(OMPC_compare) &&
+      EncounteredAtomicKinds.contains(OMPC_capture))
+    AtomicKind = OMPC_compare_capture;
   // OpenMP 5.0, 2.17.7 atomic Construct, Restrictions
   // If atomic-clause is read then memory-order-clause must not be acq_rel or
   // release.
@@ -11786,6 +11798,13 @@
     }
     // TODO: We don't set X, D, E, etc. here because in code gen we will emit
     // error directly.
+  } else if (AtomicKind == OMPC_compare_capture) {
+    // TODO: For now we emit an error here and in emitOMPAtomicExpr we ignore
+    // code gen.
+    unsigned DiagID = Diags.getCustomDiagID(
+        DiagnosticsEngine::Error,
+        "atomic compare capture is not supported for now");
+    Diag(AtomicKindLoc, DiagID);
   }
 
   setFunctionHasBranchProtectedScope();
@@ -13867,6 +13886,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14699,6 +14719,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15161,6 +15182,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15469,6 +15491,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15660,6 +15683,8 @@
   case OMPC_compare:
     Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
     break;
+  case OMPC_compare_capture:
+    llvm_unreachable("compare_capture is dummy node");
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -16280,6 +16305,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -9476,6 +9476,12 @@
   return C;
 }
 
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPCompareCaptureClause(
+    OMPCompareCaptureClause *C) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11786,6 +11786,8 @@
   case llvm::omp::OMPC_compare:
     C = new (Context) OMPCompareClause();
     break;
+  case llvm::omp::OMPC_compare_capture:
+    llvm_unreachable("OMPCompareCaptureClause should never be reached");
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -12146,6 +12148,10 @@
 
 void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseReader::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6295,6 +6295,10 @@
 
 void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseWriter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -958,3 +958,37 @@
 // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
   return mixed<int>();
 }
+
+#if _OPENMP >= 202011
+void compare() {
+  int a, b, c;
+// omp51-error@+1 {{atomic compare is not supported for now}}
+#pragma omp atomic compare
+  {
+    if (a == b)
+      a = c;
+  }
+}
+
+void compare_capture() {
+  int a, b, c, x;
+// omp51-error@+1 {{atomic compare capture is not supported for now}}
+#pragma omp atomic compare capture
+  {
+    x = a;
+    if (a == b)
+      a = c;
+  }
+}
+
+void capture_compare() {
+  int a, b, c, x;
+// omp51-error@+1 {{atomic compare capture is not supported for now}}
+#pragma omp atomic capture compare
+  {
+    x = a;
+    if (a == b)
+      a = c;
+  }
+}
+#endif
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -2277,6 +2277,11 @@
 
 void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseEnqueue::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -181,6 +181,10 @@
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
 def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
+// A dummy clause used when both 'compare' and 'capture' clauses are present.
+def OMPC_CompareCapture : Clause<"compare_capture"> {
+  let clangClass = "OMPCompareCaptureClause";
+}
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }