diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -2842,6 +2842,9 @@ /// This field is true for the first(postfix) form of the expression and false /// otherwise. bool IsPostfixUpdate = false; + /// True if 'v' is updated only when the comparison evaluates to false + /// ('atomic compare capture' only). + bool IsFailOnly = false; /// Build directive with the given start and end location. /// @@ -2865,6 +2868,7 @@ POS_UpdateExpr, POS_D, POS_Cond, + POS_R, }; /// Set 'x' part of the associated expression/statement. @@ -2877,6 +2881,8 @@ } /// Set 'v' part of the associated expression/statement. void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; } + /// Set 'r' part of the associated expression/statement. + void setR(Expr *R) { Data->getChildren()[DataPositionTy::POS_R] = R; } /// Set 'expr' part of the associated expression/statement. void setExpr(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; } /// Set 'd' part of the associated expression/statement. @@ -2896,6 +2902,7 @@ /// \param AssociatedStmt Statement, associated with the directive. /// \param X 'x' part of the associated expression/statement. /// \param V 'v' part of the associated expression/statement. + /// \param R 'r' part of the associated expression/statement. /// \param E 'expr' part of the associated expression/statement.
/// \param UE Helper expression of the form /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or @@ -2909,8 +2916,8 @@ static OMPAtomicDirective * Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V, - Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart, - bool IsPostfixUpdate); + Expr *R, Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart, + bool IsPostfixUpdate, bool IsFailOnly); /// Creates an empty directive with the place for \a NumClauses /// clauses. @@ -2943,6 +2950,9 @@ /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' and false if it has form /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'. bool isXLHSInRHSPart() const { return IsXLHSInRHSPart; } + /// Return true if 'v' is updated only when the condition is evaluated false + /// (compare capture only). + bool isFailOnly() const { return IsFailOnly; } /// Return true if 'v' expression must be updated to original value of /// 'x', false if 'v' must be updated to the new value of 'x'. bool isPostfixUpdate() const { return IsPostfixUpdate; } @@ -2953,6 +2963,13 @@ const Expr *getV() const { return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]); } + /// Get 'r' part of the associated expression/statement. + Expr *getR() { + return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_R]); + } + const Expr *getR() const { + return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_R]); + } /// Get 'expr' part of the associated expression/statement.
Expr *getExpr() { return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]); diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -863,22 +863,23 @@ !IsStandalone); } -OMPAtomicDirective * -OMPAtomicDirective::Create(const ASTContext &C, SourceLocation StartLoc, - SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, - Stmt *AssociatedStmt, Expr *X, Expr *V, Expr *E, - Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart, - bool IsPostfixUpdate) { +OMPAtomicDirective *OMPAtomicDirective::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V, + Expr *R, Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart, + bool IsPostfixUpdate, bool IsFailOnly) { auto *Dir = createDirective<OMPAtomicDirective>( - C, Clauses, AssociatedStmt, /*NumChildren=*/6, StartLoc, EndLoc); + C, Clauses, AssociatedStmt, /*NumChildren=*/7, StartLoc, EndLoc); Dir->setX(X); Dir->setV(V); + Dir->setR(R); Dir->setExpr(E); Dir->setUpdateExpr(UE); Dir->setD(D); Dir->setCond(Cond); Dir->IsXLHSInRHSPart = IsXLHSInRHSPart; Dir->IsPostfixUpdate = IsPostfixUpdate; + Dir->IsFailOnly = IsFailOnly; return Dir; } @@ -886,7 +887,7 @@ unsigned NumClauses, EmptyShell) { return createEmptyDirective<OMPAtomicDirective>( - C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/6); + C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/7); } OMPTargetDirective *OMPTargetDirective::Create(const ASTContext &C, diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -6019,8 +6019,10 @@ static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF, llvm::AtomicOrdering AO, const Expr *X, + const Expr *V, const Expr *R, const Expr *E, const Expr *D, const Expr *CE, bool IsXBinopExpr, + bool IsPostfixUpdate, bool IsFailOnly, SourceLocation Loc) { llvm::OpenMPIRBuilder
&OMPBuilder = CGF.CGM.getOpenMPRuntime().getOMPBuilder(); @@ -6050,17 +6052,35 @@ XPtr, XPtr->getType()->getPointerElementType(), X->getType().isVolatileQualified(), X->getType()->hasSignedIntegerRepresentation()}; + // 'v' and 'r' are only present for 'atomic compare capture'; otherwise the + // corresponding AtomicOpValue is left empty (Var == nullptr). + llvm::OpenMPIRBuilder::AtomicOpValue VOpVal{}, ROpVal{}; + if (V) { + llvm::Value *VPtr = CGF.EmitLValue(V).getPointer(CGF); + VOpVal.Var = VPtr; + VOpVal.ElemTy = VPtr->getType()->getPointerElementType(); + VOpVal.IsSigned = V->getType()->hasSignedIntegerRepresentation(); + VOpVal.IsVolatile = V->getType().isVolatileQualified(); + } + if (R) { + llvm::Value *RPtr = CGF.EmitLValue(R).getPointer(CGF); + ROpVal.Var = RPtr; + ROpVal.ElemTy = RPtr->getType()->getPointerElementType(); + ROpVal.IsSigned = R->getType()->hasSignedIntegerRepresentation(); + ROpVal.IsVolatile = R->getType().isVolatileQualified(); + } CGF.Builder.restoreIP(OMPBuilder.createAtomicCompare( - CGF.Builder, XOpVal, EVal, DVal, AO, Op, IsXBinopExpr)); + CGF.Builder, XOpVal, VOpVal, ROpVal, EVal, DVal, AO, Op, IsXBinopExpr, + IsPostfixUpdate, IsFailOnly)); } static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind, llvm::AtomicOrdering AO, bool IsPostfixUpdate, - const Expr *X, const Expr *V, const Expr *E, - const Expr *UE, const Expr *D, const Expr *CE, - bool IsXLHSInRHSPart, bool IsCompareCapture, - SourceLocation Loc) { + const Expr *X, const Expr *V, const Expr *R, + const Expr *E, const Expr *UE, const Expr *D, + const Expr *CE, bool IsXLHSInRHSPart, + bool IsFailOnly, SourceLocation Loc) { switch (Kind) { case OMPC_read: emitOMPAtomicReadExpr(CGF, AO, X, V, Loc); @@ -6077,15 +6097,8 @@ IsXLHSInRHSPart, Loc); break; case OMPC_compare: { - if (IsCompareCapture) { - // Emit an error here.
- unsigned DiagID = CGF.CGM.getDiags().getCustomDiagID( - DiagnosticsEngine::Error, - "'atomic compare capture' is not supported for now"); - CGF.CGM.getDiags().Report(DiagID); - } else { - emitOMPAtomicCompareExpr(CGF, AO, X, E, D, CE, IsXLHSInRHSPart, Loc); - } + emitOMPAtomicCompareExpr(CGF, AO, X, V, R, E, D, CE, IsXLHSInRHSPart, + IsPostfixUpdate, IsFailOnly, Loc); break; } case OMPC_if: @@ -6210,12 +6214,12 @@ Kind = K; KindsEncountered.insert(K); } - bool IsCompareCapture = false; + // Only Kind needs correcting here; no separate 'compare capture' flag is + // kept because that case is recognizable from the V and R expressions + // being non-null. if (KindsEncountered.contains(OMPC_compare) && - KindsEncountered.contains(OMPC_capture)) { - IsCompareCapture = true; + KindsEncountered.contains(OMPC_capture)) Kind = OMPC_compare; - } if (!MemOrderingSpecified) { llvm::AtomicOrdering DefaultOrder = CGM.getOpenMPRuntime().getDefaultMemoryOrdering(); @@ -6237,8 +6241,9 @@ LexicalScope Scope(*this, S.getSourceRange()); EmitStopPoint(S.getAssociatedStmt()); emitOMPAtomicExpr(*this, Kind, AO, S.isPostfixUpdate(), S.getX(), S.getV(), - S.getExpr(), S.getUpdateExpr(), S.getD(), S.getCondExpr(), - S.isXLHSInRHSPart(), IsCompareCapture, S.getBeginLoc()); + S.getR(), S.getExpr(), S.getUpdateExpr(), S.getD(), + S.getCondExpr(), S.isXLHSInRHSPart(), S.isFailOnly(), + S.getBeginLoc()); } static void emitCommonOMPTargetDirective(CodeGenFunction &CGF, diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -11845,8 +11845,10 @@ Expr *UE = nullptr; Expr *D = nullptr; Expr *CE = nullptr; + Expr *R = nullptr; bool IsXLHSInRHSPart = false; bool IsPostfixUpdate = false; + bool IsFailOnly = false; // OpenMP [2.12.6, atomic Construct] // In the next expressions: // * x and v (as applicable) are both l-value expressions with scalar type.
@@ -12242,8 +12244,17 @@ << ErrorInfo.Error << ErrorInfo.NoteRange; return StmtError(); } - // TODO: We don't set X, D, E, etc. here because in code gen we will emit - // error directly. + X = Checker.getX(); + E = Checker.getE(); + D = Checker.getD(); + CE = Checker.getCond(); + V = Checker.getV(); + R = Checker.getR(); + // We reuse IsXLHSInRHSPart to tell if it is in the form 'x ordop expr'. + IsXLHSInRHSPart = Checker.isXBinopExpr(); + IsFailOnly = Checker.isFailOnly(); + // 'v' gets the original value of 'x' if true, the updated one otherwise. + IsPostfixUpdate = Checker.isPostfixUpdate(); } else { OpenMPAtomicCompareChecker::ErrorInfoTy ErrorInfo; OpenMPAtomicCompareChecker Checker(*this); @@ -12266,8 +12277,8 @@ setFunctionHasBranchProtectedScope(); return OMPAtomicDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - X, V, E, UE, D, CE, IsXLHSInRHSPart, - IsPostfixUpdate); + X, V, R, E, UE, D, CE, IsXLHSInRHSPart, + IsPostfixUpdate, IsFailOnly); } StmtResult Sema::ActOnOpenMPTargetDirective(ArrayRef<OMPClause *> Clauses,