diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10508,6 +10508,15 @@
   " where x is an lvalue expression with scalar type">;
 def note_omp_atomic_capture: Note<
   "%select{expected assignment expression|expected compound statement|expected exactly two expression statements|expected in right hand side of the first expression}0">;
+def err_omp_atomic_compare : Error<
+  "the statement for 'atomic compare' must be a compound statement of form '{x = expr ordop x ? expr : x;}', '{x = x ordop expr? expr : x;}',"
+  " '{x = x == e ? d : x;}', '{x = e == x ? d : x;}', or 'if(expr ordop x) {x = expr;}', 'if(x ordop expr) {x = expr;}', 'if(x == e) {x = d;}',"
+  " 'if(e == x) {x = d;}' where 'x' is an lvalue expression with scalar type, 'expr', 'e', and 'd' are expressions with scalar type,"
+  " and 'ordop' is one of '<' or '>'.">;
+def note_omp_atomic_compare: Note<
+  "%select{expected compound statement|expected exactly one expression statement|expected assignment statement|expected conditional operator|expect result value to be at false expression|"
+  "expect binary operator in conditional expression|expect '<', '>' or '==' as order operator|expect comparison in a form of 'x == e', 'e == x', 'x ordop expr', or 'expr ordop x'|"
+  "expect lvalue for result value|expect scalar integer value|unexpected 'else' statement}0">;
 def err_omp_atomic_several_clauses : Error<
   "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
 def err_omp_several_mem_order_clauses : Error<
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -10916,6 +10916,381 @@
   }
   return ErrorFound != NoError;
 }
+
+/// Get the node id of the fixed point of an expression \a S.
+llvm::FoldingSetNodeID getNodeId(ASTContext &Context, const Expr *S) {
+  llvm::FoldingSetNodeID Id;
+  S->IgnoreParenImpCasts()->Profile(Id, Context, true);
+  return Id;
+}
+
+/// Check if two expressions are same.
+bool checkIfTwoExprsAreSame(ASTContext &Context, const Expr *LHS,
+                            const Expr *RHS) {
+  return getNodeId(Context, LHS) == getNodeId(Context, RHS);
+}
+
+/// Normalize the comparison \a Cond so that 'x' (\a X) is its left operand.
+///
+/// Equality comparisons and comparisons that already have 'x' on the left are
+/// returned unchanged. For 'expr < x' and 'expr > x' a new comparison
+/// 'x > expr' or 'x < expr' is built. Returns nullptr if \a Cond is null or
+/// its operator is not one of '<', '>' or '=='.
+Expr *generateLHSCondExpr(Sema &SemaRef, BinaryOperator *Cond, Expr *X) {
+  if (!Cond)
+    return nullptr;
+  // '==' is symmetric, so there is nothing to normalize. This has to be
+  // tested before the '<'/'>' filter below, which would otherwise reject it.
+  if (Cond->getOpcode() == BO_EQ)
+    return Cond;
+  if (Cond->getOpcode() != BO_LT && Cond->getOpcode() != BO_GT)
+    return nullptr;
+  if (checkIfTwoExprsAreSame(SemaRef.getASTContext(), X, Cond->getLHS()))
+    return Cond;
+  // A builtin comparison cannot be built from type-dependent operands. Keep
+  // the original form; the directive is analyzed again on instantiation.
+  if (X->isTypeDependent() || Cond->getLHS()->isTypeDependent())
+    return Cond;
+
+  // 'expr < x' is equivalent to 'x > expr', and 'expr > x' to 'x < expr'.
+  BinaryOperatorKind Op = Cond->getOpcode() == BO_LT ? BO_GT : BO_LT;
+  ExprResult Result =
+      SemaRef.CreateBuiltinBinOp(Cond->getOperatorLoc(), Op, X, Cond->getLHS());
+  assert(!Result.isInvalid() && "cannot build the normalized comparison");
+  return Result.isInvalid() ? Cond : Result.get();
+}
+
+/// Checks that the statement associated with '#pragma omp atomic compare' is
+/// a cond-expr-stmt or a cond-update-stmt and records its 'x', 'e'/'expr',
+/// 'd' and comparison parts.
+class OpenMPAtomicCompareChecker {
+public:
+  /// All kinds of errors that can occur in `atomic compare`. The values are
+  /// used as the %select index of note_omp_atomic_compare, so the order must
+  /// match the note text.
+  enum ErrorTy {
+    /// Empty compound statement.
+    NoStmt = 0,
+    /// More than one statement in a compound statement.
+    MoreThanOneStmt,
+    /// Not an assignment binary operator.
+    NotAnAssignment,
+    /// Not a conditional operator.
+    NotCondOp,
+    /// Wrong false expr. According to the spec, 'x' should be at the false
+    /// expression of a conditional expression.
+    WrongFalseExpr,
+    /// The condition of a conditional expression is not a binary operator.
+    NotABinaryOp,
+    /// Invalid binary operator (not <, >, or ==).
+    InvalidBinaryOp,
+    /// Invalid comparison (not x == e, e == x, x ordop expr, or expr ordop x).
+    InvalidComparison,
+    /// X is not a lvalue.
+    XNotLValue,
+    /// Not a scalar integer.
+    NotScalarInteger,
+    /// Unexpected 'else' branch in a cond-update-stmt.
+    UnexpectedElse,
+    /// No error.
+    NoError,
+  };
+
+  struct ErrorInfoTy {
+    ErrorTy Error = NoError;
+    SourceLocation ErrorLoc;
+    SourceRange ErrorRange;
+    SourceLocation NoteLoc;
+    SourceRange NoteRange;
+  };
+
+  explicit OpenMPAtomicCompareChecker(Sema &S)
+      : SemaRef(S), ContextRef(S.getASTContext()) {}
+
+  /// Check if statement \a S is valid for atomic compare. On failure, fills
+  /// \a ErrorInfo and returns false.
+  bool checkStmt(Stmt *S, ErrorInfoTy &ErrorInfo);
+
+private:
+  /// Reference to Sema.
+  Sema &SemaRef;
+  /// Reference to ASTContext
+  ASTContext &ContextRef;
+  /// 'x' lvalue part of the source atomic expression.
+  Expr *X = nullptr;
+  /// 'expr' or 'e' rvalue part of the source atomic expression.
+  Expr *E = nullptr;
+  /// 'd' rvalue part of the source atomic expression.
+  Expr *D = nullptr;
+  /// 'cond' part of the source atomic expression. It is in one of the following
+  /// forms:
+  /// expr ordop x
+  /// x ordop expr
+  /// x == e
+  /// e == x
+  /// If 'x' is on the RHS of an 'ordop' comparison, a corresponding LHS
+  /// version is generated by \p generateLHSCondExpr.
+  /// 'ordop' can only be '<' or '>'.
+  Expr *C = nullptr;
+
+  /// Check if it is a valid conditional update statement (cond-update-stmt).
+  bool checkCondUpdateStmt(IfStmt *S, ErrorInfoTy &ErrorInfo);
+
+  /// Check if it is a valid conditional expression statement (cond-expr-stmt).
+  bool checkCondExprStmt(Stmt *S, ErrorInfoTy &ErrorInfo);
+
+  /// Match the comparison \a Cond against 'x' (already set) and the assigned
+  /// value \a Value, and set 'e'/'expr', 'd' and 'cond' accordingly.
+  bool checkCondition(BinaryOperator *Cond, Expr *Value,
+                      ErrorInfoTy &ErrorInfo);
+
+  /// Check if all captured values have right type.
+  bool checkType(ErrorInfoTy &ErrorInfo) const;
+
+  /// Check that \a Value has integer scalar type and, if \a ShouldBeLValue,
+  /// is an lvalue. Instantiation-dependent expressions are accepted as-is.
+  static bool checkValue(const Expr *Value, ErrorInfoTy &ErrorInfo,
+                         bool ShouldBeLValue);
+};
+
+bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
+                                                     ErrorInfoTy &ErrorInfo) {
+  auto *Then = S->getThen();
+  if (auto *CS = dyn_cast<CompoundStmt>(Then)) {
+    if (CS->size() == 0) {
+      ErrorInfo.Error = ErrorTy::NoStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+    if (CS->size() > 1) {
+      ErrorInfo.Error = ErrorTy::MoreThanOneStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+      return false;
+    }
+    Then = CS->body_front();
+  }
+
+  auto *BO = dyn_cast<BinaryOperator>(Then);
+  if (!BO) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Then->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Then->getSourceRange();
+    return false;
+  }
+  if (BO->getOpcode() != BO_Assign) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = BO->getExprLoc();
+    ErrorInfo.NoteLoc = BO->getOperatorLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+    return false;
+  }
+
+  X = BO->getLHS();
+
+  auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
+  if (!Cond) {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
+    return false;
+  }
+
+  // 'if (x == e) { x = d; }' or 'if (x ordop expr) { x = expr; }'.
+  if (!checkCondition(Cond, BO->getRHS(), ErrorInfo))
+    return false;
+
+  // A cond-update-stmt has no 'else' branch.
+  if (Stmt *Else = S->getElse()) {
+    ErrorInfo.Error = ErrorTy::UnexpectedElse;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Else->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Else->getSourceRange();
+    return false;
+  }
+
+  return true;
+}
+
+bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
+                                                   ErrorInfoTy &ErrorInfo) {
+  auto *BO = dyn_cast<BinaryOperator>(S);
+  if (!BO) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+    return false;
+  }
+  if (BO->getOpcode() != BO_Assign) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = BO->getExprLoc();
+    ErrorInfo.NoteLoc = BO->getOperatorLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+    return false;
+  }
+
+  X = BO->getLHS();
+
+  auto *CO = dyn_cast<ConditionalOperator>(BO->getRHS()->IgnoreParenImpCasts());
+  if (!CO) {
+    ErrorInfo.Error = ErrorTy::NotCondOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = BO->getRHS()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getRHS()->getSourceRange();
+    return false;
+  }
+
+  if (!checkIfTwoExprsAreSame(ContextRef, X, CO->getFalseExpr())) {
+    ErrorInfo.Error = ErrorTy::WrongFalseExpr;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getFalseExpr()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        CO->getFalseExpr()->getSourceRange();
+    return false;
+  }
+
+  auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
+  if (!Cond) {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        CO->getCond()->getSourceRange();
+    return false;
+  }
+
+  // 'x = x == e ? d : x' or 'x = x ordop expr ? expr : x'.
+  return checkCondition(Cond, CO->getTrueExpr(), ErrorInfo);
+}
+
+bool OpenMPAtomicCompareChecker::checkCondition(BinaryOperator *Cond,
+                                                Expr *Value,
+                                                ErrorInfoTy &ErrorInfo) {
+  switch (Cond->getOpcode()) {
+  case BO_EQ:
+    // 'x == e' or 'e == x': the assigned value is the new value 'd'.
+    C = Cond;
+    D = Value;
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+      E = Cond->getRHS();
+      return true;
+    }
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+      E = Cond->getLHS();
+      return true;
+    }
+    break;
+  case BO_LT:
+  case BO_GT:
+    // 'x ordop expr' or 'expr ordop x': the assigned value must be 'expr'.
+    E = Value;
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+        checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+      C = Cond;
+      return true;
+    }
+    if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+        checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+      // 'x' is on RHS. Create a LHS version.
+      C = generateLHSCondExpr(SemaRef, Cond, X);
+      return true;
+    }
+    break;
+  default:
+    ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    return false;
+  }
+
+  ErrorInfo.Error = ErrorTy::InvalidComparison;
+  ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+  ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+  return false;
+}
+
+bool OpenMPAtomicCompareChecker::checkValue(const Expr *Value,
+                                            ErrorInfoTy &ErrorInfo,
+                                            bool ShouldBeLValue) {
+  // Nothing can be checked yet for dependent expressions; the directive is
+  // analyzed again when the enclosing template is instantiated.
+  if (Value->isInstantiationDependent())
+    return true;
+
+  if (ShouldBeLValue && !Value->isLValue()) {
+    ErrorInfo.Error = ErrorTy::XNotLValue;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Value->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Value->getSourceRange();
+    return false;
+  }
+
+  QualType QTy = Value->getType();
+  if (!QTy->isScalarType() || !QTy->isIntegerType()) {
+    ErrorInfo.Error = ErrorTy::NotScalarInteger;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Value->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Value->getSourceRange();
+    return false;
+  }
+
+  return true;
+}
+
+bool OpenMPAtomicCompareChecker::checkType(ErrorInfoTy &ErrorInfo) const {
+  // 'x' and 'e' cannot be nullptr
+  assert(X && E && "X and E cannot be nullptr");
+
+  if (!checkValue(X, ErrorInfo, /*ShouldBeLValue=*/true))
+    return false;
+
+  if (!checkValue(E, ErrorInfo, /*ShouldBeLValue=*/false))
+    return false;
+
+  // 'd' only exists in the 'x == e' / 'e == x' forms.
+  if (D && !checkValue(D, ErrorInfo, /*ShouldBeLValue=*/false))
+    return false;
+
+  return true;
+}
+
+bool OpenMPAtomicCompareChecker::checkStmt(
+    Stmt *S, OpenMPAtomicCompareChecker::ErrorInfoTy &ErrorInfo) {
+  auto *CS = dyn_cast<CompoundStmt>(S);
+  if (CS) {
+    if (CS->size() == 0) {
+      ErrorInfo.Error = ErrorTy::NoStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+
+    if (CS->size() != 1) {
+      ErrorInfo.Error = ErrorTy::MoreThanOneStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+    S = CS->body_front();
+  }
+
+  bool Res = false;
+
+  if (auto *IS = dyn_cast<IfStmt>(S)) {
+    // Check if the statement is in one of the following forms
+    // (cond-update-stmt):
+    // if (expr ordop x) { x = expr; }
+    // if (x ordop expr) { x = expr; }
+    // if (x == e) { x = d; }
+    Res = checkCondUpdateStmt(IS, ErrorInfo);
+  } else {
+    // Check if the statement is in one of the following forms (cond-expr-stmt):
+    // x = expr ordop x ? expr : x;
+    // x = x ordop expr ? expr : x;
+    // x = x == e ? d : x;
+    Res = checkCondExprStmt(S, ErrorInfo);
+  }
+
+  if (!Res)
+    return false;
+
+  return checkType(ErrorInfo);
+}
 } // namespace
 
 StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
@@ -11396,6 +11771,15 @@
     if (CurContext->isDependentContext())
       UE = V = E = X = nullptr;
   } else if (AtomicKind == OMPC_compare) {
+    OpenMPAtomicCompareChecker::ErrorInfoTy ErrorInfo;
+    OpenMPAtomicCompareChecker Checker(*this);
+    if (!Checker.checkStmt(Body, ErrorInfo)) {
+      Diag(ErrorInfo.ErrorLoc, diag::err_omp_atomic_compare)
+          << ErrorInfo.ErrorRange;
+      Diag(ErrorInfo.NoteLoc, diag::note_omp_atomic_compare)
+          << ErrorInfo.Error << ErrorInfo.NoteRange;
+      return StmtError();
+    }
     // TODO: For now we emit an error here and in emitOMPAtomicExpr we ignore
     // code gen.
     unsigned DiagID = Diags.getCustomDiagID(