Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -2513,6 +2513,95 @@
   }
 };
 
+/// This represents 'fail' clause in the '#pragma omp atomic'
+/// directive.
+///
+/// \code
+/// #pragma omp atomic compare fail(seq_cst)
+/// \endcode
+/// In this example directive '#pragma omp atomic compare' has 'fail' clause
+/// with the memory order parameter 'seq_cst'.
+class OMPFailClause final : public OMPClause {
+
+  // FailParameter is a memory-order-clause. Storing the ClauseKind is
+  // sufficient for our purpose.
+  OpenMPClauseKind FailParameter = llvm::omp::Clause::OMPC_unknown;
+  SourceLocation FailParameterLoc;
+  SourceLocation LParenLoc;
+
+  friend class OMPClauseReader;
+
+  /// Sets the location of '(' in fail clause.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+  /// Sets the location of memoryOrder clause argument in fail clause.
+  void setFailParameterLoc(SourceLocation Loc) { FailParameterLoc = Loc; }
+
+  /// Sets the mem_order clause for 'atomic compare fail' directive.
+  void setFailParameter(OpenMPClauseKind FailParameter) {
+    // Check the kind before storing it; see checkFailClauseParameter.
+    assert(checkFailClauseParameter(FailParameter) &&
+           "Invalid fail clause parameter");
+    this->FailParameter = FailParameter;
+  }
+
+public:
+  /// Build 'fail' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPFailClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_fail, StartLoc, EndLoc) {}
+
+  /// Build 'fail' clause with its memory order parameter.
+  ///
+  /// \param FailParameter Clause kind of the memory order parameter.
+  /// \param FailParameterLoc Location of the memory order parameter.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPFailClause(OpenMPClauseKind FailParameter, SourceLocation FailParameterLoc,
+                SourceLocation StartLoc, SourceLocation LParenLoc,
+                SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_fail, StartLoc, EndLoc),
+        FailParameterLoc(FailParameterLoc), LParenLoc(LParenLoc) {
+    setFailParameter(FailParameter);
+  }
+
+  /// Build an empty clause.
+  OMPFailClause()
+      : OMPClause(llvm::omp::OMPC_fail, SourceLocation(), SourceLocation()) {}
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_fail;
+  }
+
+  /// Gets the location of '(' (for the parameter) in fail clause.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Gets the location of Fail Parameter (type memory-order-clause) in
+  /// fail clause.
+  SourceLocation getFailParameterLoc() const { return FailParameterLoc; }
+
+  /// Gets the parameter (type memory-order-clause) in Fail clause.
+  OpenMPClauseKind getFailParameter() const { return FailParameter; }
+};
+
 /// This represents clause 'private' in the '#pragma omp ...' directives.
/// /// \code Index: clang/include/clang/AST/RecursiveASTVisitor.h =================================================================== --- clang/include/clang/AST/RecursiveASTVisitor.h +++ clang/include/clang/AST/RecursiveASTVisitor.h @@ -3398,6 +3398,11 @@ return true; } +template +bool RecursiveASTVisitor::VisitOMPFailClause(OMPFailClause *) { + return true; +} + template bool RecursiveASTVisitor::VisitOMPSeqCstClause(OMPSeqCstClause *) { return true; Index: clang/include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- clang/include/clang/Basic/DiagnosticSemaKinds.td +++ clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10967,6 +10967,8 @@ "expect binary operator in conditional expression|expect '<', '>' or '==' as order operator|expect comparison in a form of 'x == e', 'e == x', 'x ordop expr', or 'expr ordop x'|" "expect lvalue for result value|expect scalar value|expect integer value|unexpected 'else' statement|expect '==' operator|expect an assignment statement 'v = x'|" "expect a 'if' statement|expect no more than two statements|expect a compound statement|expect 'else' statement|expect a form 'r = x == e; if (r) ...'}0">; +def err_omp_atomic_fail_wrong_or_no_clauses : Error<"expected a memory order clause">; +def err_omp_atomic_fail_no_compare : Error<"expected 'compare' clause with the 'fail' modifier">; def err_omp_atomic_several_clauses : Error< "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">; def err_omp_several_mem_order_clauses : Error< Index: clang/include/clang/Basic/OpenMPKinds.h =================================================================== --- clang/include/clang/Basic/OpenMPKinds.h +++ clang/include/clang/Basic/OpenMPKinds.h @@ -363,6 +363,11 @@ /// \return true - if the above condition is met for this directive /// otherwise - false. 
bool needsTaskBasedThreadLimit(OpenMPDirectiveKind DKind); + +/// Checks if the parameter to the fail clause in "#pragma atomic compare fail" +/// is restricted only to memory order clauses of "OMPC_acquire", +/// "OMPC_relaxed" and "OMPC_seq_cst". +bool checkFailClauseParameter(OpenMPClauseKind FailClauseParameter); } #endif Index: clang/include/clang/Basic/OpenMPKinds.def =================================================================== --- clang/include/clang/Basic/OpenMPKinds.def +++ clang/include/clang/Basic/OpenMPKinds.def @@ -41,6 +41,9 @@ #ifndef OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND #define OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(Name) #endif +#ifndef OPENMP_ATOMIC_FAIL_MODIFIER +#define OPENMP_ATOMIC_FAIL_MODIFIER(Name) +#endif #ifndef OPENMP_AT_KIND #define OPENMP_AT_KIND(Name) #endif @@ -138,6 +141,11 @@ OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(acq_rel) OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(relaxed) +// Modifiers for atomic 'fail' clause. +OPENMP_ATOMIC_FAIL_MODIFIER(seq_cst) +OPENMP_ATOMIC_FAIL_MODIFIER(acquire) +OPENMP_ATOMIC_FAIL_MODIFIER(relaxed) + // Modifiers for 'at' clause. OPENMP_AT_KIND(compilation) OPENMP_AT_KIND(execution) @@ -226,6 +234,7 @@ #undef OPENMP_SCHEDULE_MODIFIER #undef OPENMP_SCHEDULE_KIND #undef OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND +#undef OPENMP_ATOMIC_FAIL_MODIFIER #undef OPENMP_AT_KIND #undef OPENMP_SEVERITY_KIND #undef OPENMP_MAP_KIND Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -12198,6 +12198,13 @@ /// Called on well-formed 'compare' clause. OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc, SourceLocation EndLoc); + /// Called on well-formed 'fail' clause. 
+ OMPClause *ActOnOpenMPFailClause(SourceLocation StartLoc, + SourceLocation EndLoc); + OMPClause *ActOnOpenMPFailClause( + OpenMPClauseKind Kind, SourceLocation KindLoc, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); + /// Called on well-formed 'seq_cst' clause. OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc, SourceLocation EndLoc); Index: clang/lib/AST/OpenMPClause.cpp =================================================================== --- clang/lib/AST/OpenMPClause.cpp +++ clang/lib/AST/OpenMPClause.cpp @@ -130,6 +130,7 @@ case OMPC_update: case OMPC_capture: case OMPC_compare: + case OMPC_fail: case OMPC_seq_cst: case OMPC_acq_rel: case OMPC_acquire: @@ -227,6 +228,7 @@ case OMPC_update: case OMPC_capture: case OMPC_compare: + case OMPC_fail: case OMPC_seq_cst: case OMPC_acq_rel: case OMPC_acquire: @@ -1925,6 +1927,16 @@ OS << "compare"; } +void OMPClausePrinter::VisitOMPFailClause(OMPFailClause *Node) { + OS << "fail"; + if (Node) { + OS << "("; + OS << getOpenMPSimpleClauseTypeName( + Node->getClauseKind(), static_cast(Node->getFailParameter())); + OS << ")"; + } +} + void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) { OS << "seq_cst"; } Index: clang/lib/AST/StmtProfile.cpp =================================================================== --- clang/lib/AST/StmtProfile.cpp +++ clang/lib/AST/StmtProfile.cpp @@ -582,6 +582,8 @@ void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {} +void OMPClauseProfiler::VisitOMPFailClause(const OMPFailClause *) {} + void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {} void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {} Index: clang/lib/Basic/CMakeLists.txt =================================================================== --- clang/lib/Basic/CMakeLists.txt +++ clang/lib/Basic/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_LINK_COMPONENTS Support TargetParser + FrontendOpenMP ) 
find_first_existing_vc_file("${LLVM_MAIN_SRC_DIR}" llvm_vc) Index: clang/lib/Basic/OpenMPKinds.cpp =================================================================== --- clang/lib/Basic/OpenMPKinds.cpp +++ clang/lib/Basic/OpenMPKinds.cpp @@ -104,6 +104,11 @@ .Case(#Name, OMPC_ATOMIC_DEFAULT_MEM_ORDER_##Name) #include "clang/Basic/OpenMPKinds.def" .Default(OMPC_ATOMIC_DEFAULT_MEM_ORDER_unknown); + case OMPC_fail: + return static_cast(llvm::StringSwitch(Str) +#define OPENMP_ATOMIC_FAIL_MODIFIER(Name) .Case(#Name, OMPC_##Name) +#include "clang/Basic/OpenMPKinds.def" + .Default(OMPC_unknown)); case OMPC_device_type: return llvm::StringSwitch(Str) #define OPENMP_DEVICE_TYPE_KIND(Name) .Case(#Name, OMPC_DEVICE_TYPE_##Name) @@ -434,6 +439,11 @@ #include "clang/Basic/OpenMPKinds.def" } llvm_unreachable("Invalid OpenMP 'depend' clause type"); + case OMPC_fail: { + OpenMPClauseKind CK = static_cast(Type); + return getOpenMPClauseName(CK).data(); + llvm_unreachable("Invalid OpenMP 'fail' clause modifier"); + } case OMPC_device: switch (Type) { case OMPC_DEVICE_unknown: @@ -889,3 +899,10 @@ llvm_unreachable("Unknown OpenMP directive"); } } + +bool clang::checkFailClauseParameter(OpenMPClauseKind FailClauseParameter) { + return FailClauseParameter == llvm::omp::OMPC_acquire || + FailClauseParameter == llvm::omp::OMPC_relaxed || + FailClauseParameter == llvm::omp::OMPC_seq_cst; +} + Index: clang/lib/CodeGen/CGStmtOpenMP.cpp =================================================================== --- clang/lib/CodeGen/CGStmtOpenMP.cpp +++ clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -6516,6 +6516,10 @@ IsPostfixUpdate, IsFailOnly, Loc); break; } + case OMPC_fail: { + //TODO + break; + } default: llvm_unreachable("Clause is not allowed in 'omp atomic'."); } Index: clang/lib/Parse/ParseOpenMP.cpp =================================================================== --- clang/lib/Parse/ParseOpenMP.cpp +++ clang/lib/Parse/ParseOpenMP.cpp @@ -3248,6 +3248,7 @@ else Clause = 
ParseOpenMPSingleExprClause(CKind, WrongDirective); break; + case OMPC_fail: case OMPC_default: case OMPC_proc_bind: case OMPC_atomic_default_mem_order: Index: clang/lib/Sema/SemaOpenMP.cpp =================================================================== --- clang/lib/Sema/SemaOpenMP.cpp +++ clang/lib/Sema/SemaOpenMP.cpp @@ -12682,6 +12682,14 @@ } break; } + case OMPC_fail: { + if (AtomicKind != OMPC_compare) { + Diag(C->getBeginLoc(), diag::err_omp_atomic_fail_no_compare) + << SourceRange(C->getBeginLoc(), C->getEndLoc()); + return StmtError(); + } + break; + } case OMPC_seq_cst: case OMPC_acq_rel: case OMPC_acquire: @@ -16883,6 +16891,11 @@ static_cast(Argument), ArgumentLoc, StartLoc, LParenLoc, EndLoc); break; + case OMPC_fail: + Res = ActOnOpenMPFailClause( + static_cast(Argument), + ArgumentLoc, StartLoc, LParenLoc, EndLoc); + break; case OMPC_update: Res = ActOnOpenMPUpdateClause(static_cast(Argument), ArgumentLoc, StartLoc, LParenLoc, EndLoc); @@ -17523,6 +17536,9 @@ case OMPC_compare: Res = ActOnOpenMPCompareClause(StartLoc, EndLoc); break; + case OMPC_fail: + Res = ActOnOpenMPFailClause(StartLoc, EndLoc); + break; case OMPC_seq_cst: Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc); break; @@ -17683,6 +17699,24 @@ return new (Context) OMPCompareClause(StartLoc, EndLoc); } +OMPClause *Sema::ActOnOpenMPFailClause(SourceLocation StartLoc, + SourceLocation EndLoc) { + return new (Context) OMPFailClause(StartLoc, EndLoc); +} + +OMPClause *Sema::ActOnOpenMPFailClause( + OpenMPClauseKind Parameter, SourceLocation KindLoc, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc) { + + if (!checkFailClauseParameter(Parameter)) { + Diag(KindLoc, diag::err_omp_atomic_fail_wrong_or_no_clauses); + return nullptr; + } + return new (Context) + OMPFailClause(Parameter, KindLoc, StartLoc, LParenLoc, EndLoc); +} + OMPClause *Sema::ActOnOpenMPSeqCstClause(SourceLocation StartLoc, SourceLocation EndLoc) { return new (Context) OMPSeqCstClause(StartLoc, 
EndLoc); Index: clang/lib/Sema/TreeTransform.h =================================================================== --- clang/lib/Sema/TreeTransform.h +++ clang/lib/Sema/TreeTransform.h @@ -9868,6 +9868,12 @@ return C; } +template +OMPClause *TreeTransform::TransformOMPFailClause(OMPFailClause *C) { + // No need to rebuild this clause, no template-dependent parameters. + return C; +} + template OMPClause * TreeTransform::TransformOMPSeqCstClause(OMPSeqCstClause *C) { Index: clang/lib/Serialization/ASTReader.cpp =================================================================== --- clang/lib/Serialization/ASTReader.cpp +++ clang/lib/Serialization/ASTReader.cpp @@ -10276,6 +10276,9 @@ case llvm::omp::OMPC_compare: C = new (Context) OMPCompareClause(); break; + case llvm::omp::OMPC_fail: + C = new (Context) OMPFailClause(); + break; case llvm::omp::OMPC_seq_cst: C = new (Context) OMPSeqCstClause(); break; @@ -10669,6 +10672,16 @@ void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {} +// Read the parameter of fail clause. This will have been saved when +// OMPClauseWriter is called. +void OMPClauseReader::VisitOMPFailClause(OMPFailClause *C) { + C->setLParenLoc(Record.readSourceLocation()); + SourceLocation FailParameterLoc = Record.readSourceLocation(); + C->setFailParameterLoc(FailParameterLoc); + OpenMPClauseKind CKind = Record.readEnum(); + C->setFailParameter(CKind); +} + void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {} void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {} Index: clang/lib/Serialization/ASTWriter.cpp =================================================================== --- clang/lib/Serialization/ASTWriter.cpp +++ clang/lib/Serialization/ASTWriter.cpp @@ -6622,6 +6622,13 @@ void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {} +// Save the parameter of fail clause. 
+void OMPClauseWriter::VisitOMPFailClause(OMPFailClause *C) { + Record.AddSourceLocation(C->getLParenLoc()); + Record.AddSourceLocation(C->getFailParameterLoc()); + Record.writeEnum(C->getFailParameter()); +} + void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {} void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {} Index: clang/test/OpenMP/atomic_ast_print.cpp =================================================================== --- clang/test/OpenMP/atomic_ast_print.cpp +++ clang/test/OpenMP/atomic_ast_print.cpp @@ -226,6 +226,12 @@ { v = a; if (a < b) { a = b; } } #pragma omp atomic compare capture hint(6) { v = a == b; if (v) a = c; } +#pragma omp atomic compare fail(acquire) + { if (a < c) { a = c; } } +#pragma omp atomic compare fail(relaxed) + { if (a < c) { a = c; } } +#pragma omp atomic compare fail(seq_cst) + { if (a < c) { a = c; } } #endif return T(); } @@ -1099,6 +1105,12 @@ { v = a; if (a < b) { a = b; } } #pragma omp atomic compare capture hint(6) { v = a == b; if (v) a = c; } +#pragma omp atomic compare fail(acquire) + if(a < b) { a = b; } +#pragma omp atomic compare fail(relaxed) + if(a < b) { a = b; } +#pragma omp atomic compare fail(seq_cst) + if(a < b) { a = b; } #endif // CHECK-NEXT: #pragma omp atomic // CHECK-NEXT: a++; @@ -1429,6 +1441,18 @@ // CHECK-51-NEXT: if (v) // CHECK-51-NEXT: a = c; // CHECK-51-NEXT: } + // CHECK-51-NEXT: #pragma omp atomic compare fail(acquire) + // CHECK-51-NEXT: if (a < b) { + // CHECK-51-NEXT: a = b; + // CHECK-51-NEXT: } + // CHECK-51-NEXT: #pragma omp atomic compare fail(relaxed) + // CHECK-51-NEXT: if (a < b) { + // CHECK-51-NEXT: a = b; + // CHECK-51-NEXT: } + // CHECK-51-NEXT: #pragma omp atomic compare fail(seq_cst) + // CHECK-51-NEXT: if (a < b) { + // CHECK-51-NEXT: a = b; + // CHECK-51-NEXT: } // expect-note@+1 {{in instantiation of function template specialization 'foo' requested here}} return foo(a); } Index: clang/test/OpenMP/atomic_messages.cpp 
=================================================================== --- clang/test/OpenMP/atomic_messages.cpp +++ clang/test/OpenMP/atomic_messages.cpp @@ -958,6 +958,24 @@ // expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'capture' clause}} #pragma omp atomic compare compare capture capture { v = a; if (a > b) a = b; } +// expected-error@+1 {{expected 'compare' clause with the 'fail' modifier}} +#pragma omp atomic fail(seq_cst) + if(v == a) { v = a; } +// expected-error@+1 {{expected '(' after 'fail'}} +#pragma omp atomic compare fail + if(v < a) { v = a; } +// expected-error@+1 {{expected a memory order clause}} +#pragma omp atomic compare fail(capture) + if(v < a) { v = a; } + // expected-error@+2 {{expected ')'}} + // expected-note@+1 {{to match this '('}} +#pragma omp atomic compare fail(seq_cst | acquire) + if(v < a) { v = a; } +// expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'fail' clause}} +#pragma omp atomic compare fail(relaxed) fail(seq_cst) + if(v < a) { v = a; } + + #endif // expected-note@+1 {{in instantiation of function template specialization 'mixed' requested here}} return mixed(); Index: clang/tools/libclang/CIndex.cpp =================================================================== --- clang/tools/libclang/CIndex.cpp +++ clang/tools/libclang/CIndex.cpp @@ -2402,6 +2402,8 @@ void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {} +void OMPClauseEnqueue::VisitOMPFailClause(const OMPFailClause *) {} + void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {} void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {} Index: flang/lib/Semantics/check-omp-structure.cpp =================================================================== --- flang/lib/Semantics/check-omp-structure.cpp +++ flang/lib/Semantics/check-omp-structure.cpp @@ -2254,6 +2254,7 @@ CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute) CHECK_SIMPLE_CLAUSE(OmpxBare, 
OMPC_ompx_bare) CHECK_SIMPLE_CLAUSE(Enter, OMPC_enter) +CHECK_SIMPLE_CLAUSE(Fail, OMPC_fail) CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize) CHECK_REQ_SCALAR_INT_CLAUSE(NumTasks, OMPC_num_tasks) Index: llvm/include/llvm/Frontend/OpenMP/OMP.td =================================================================== --- llvm/include/llvm/Frontend/OpenMP/OMP.td +++ llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -209,6 +209,7 @@ def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; } def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; } def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; } +def OMPC_Fail : Clause<"fail"> { let clangClass = "OMPFailClause"; } def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; } def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; } def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; } @@ -640,7 +641,8 @@ VersionedClause, VersionedClause, VersionedClause, - VersionedClause + VersionedClause, + VersionedClause ]; } def OMP_Target : Directive<"target"> {