diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -10974,6 +10974,16 @@
     NotScalar,
     /// Not an integer.
     NotInteger,
+    /// Not an equality operator.
+    NotEQ,
+    /// Invalid assignment (not v = x).
+    InvalidAssignment,
+    /// Not an if statement.
+    NotIfStmt,
+    /// More than two statements in a compound statement.
+    MoreThanTwoStmts,
+    /// Not a compound statement.
+    NotCompoundStmt,
     /// No error.
     NoError,
   };
@@ -10997,7 +11007,7 @@
   Expr *getCond() const { return C; }
   bool isXBinopExpr() const { return IsXBinopExpr; }
 
-private:
+protected:
   /// Reference to ASTContext
   ASTContext &ContextRef;
   /// 'x' lvalue part of the source atomic expression.
@@ -11024,6 +11034,38 @@
 
   /// Check if all captured values have right type.
   bool checkType(ErrorInfoTy &ErrorInfo) const;
+
+  /// Check a single operand \a E: it must be an lvalue when \a ShouldBeLValue
+  /// is set and, unless it is instantiation dependent, of scalar integer type.
+  /// On failure the reason and location are stored in \a ErrorInfo.
+  static bool CheckValue(const Expr *E, ErrorInfoTy &ErrorInfo,
+                         bool ShouldBeLValue) {
+    if (ShouldBeLValue && !E->isLValue()) {
+      ErrorInfo.Error = ErrorTy::XNotLValue;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
+      return false;
+    }
+
+    if (!E->isInstantiationDependent()) {
+      QualType QTy = E->getType();
+      if (!QTy->isScalarType()) {
+        ErrorInfo.Error = ErrorTy::NotScalar;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
+        return false;
+      }
+
+      if (!QTy->isIntegerType()) {
+        ErrorInfo.Error = ErrorTy::NotInteger;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
+        return false;
+      }
+    }
+
+    return true;
+  }
 };
 
 bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
@@ -11206,41 +11248,13 @@
   // 'x' and 'e' cannot be nullptr
   assert(X && E && "X and E cannot be nullptr");
 
-  auto CheckValue = [&ErrorInfo](const Expr *E, bool ShouldBeLValue) {
-    if (ShouldBeLValue && !E->isLValue()) {
-      ErrorInfo.Error = ErrorTy::XNotLValue;
-      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
-      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
-      return false;
-    }
-
-    if (!E->isInstantiationDependent()) {
-      QualType QTy = E->getType();
-      if (!QTy->isScalarType()) {
-        ErrorInfo.Error = ErrorTy::NotScalar;
-        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
-        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
-        return false;
-      }
-
-      if (!QTy->isIntegerType()) {
-        ErrorInfo.Error = ErrorTy::NotInteger;
-        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
-        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
-        return false;
-      }
-    }
-
-    return true;
-  };
-
-  if (!CheckValue(X, true))
+  if (!CheckValue(X, ErrorInfo, true))
     return false;
 
-  if (!CheckValue(E, false))
+  if (!CheckValue(E, ErrorInfo, false))
     return false;
 
-  if (D && !CheckValue(D, false))
+  if (D && !CheckValue(D, ErrorInfo, false))
     return false;
 
   return true;
@@ -11288,6 +11302,280 @@
 
   return checkType(ErrorInfo);
 }
+
+/// Checker for the capture forms of atomic compare. In addition to what the
+/// base checker extracts, it records 'v' (the lvalue assigned from 'x').
+class OpenMPAtomicCompareCaptureChecker final
+    : public OpenMPAtomicCompareChecker {
+public:
+  OpenMPAtomicCompareCaptureChecker(Sema &S) : OpenMPAtomicCompareChecker(S) {}
+
+  Expr *getV() const { return V; }
+  Expr *getR() const { return R; }
+  bool isFailOnly() const { return IsFailOnly; }
+
+  /// Check if statement \a S is valid for atomic compare capture. On
+  /// failure, \a ErrorInfo is filled in and false is returned.
+  bool checkStmt(Stmt *S, ErrorInfoTy &ErrorInfo);
+
+private:
+  /// Run the base class type check, then check 'v' and, if set, 'r'.
+  bool checkType(ErrorInfoTy &ErrorInfo);
+
+  /// Check if it is valid 'if(x == e) { x = d; } else { v = x; }' (form 3)
+  bool checkForm3(IfStmt *S, ErrorInfoTy &ErrorInfo);
+
+  /// 'v' lvalue part of the source atomic expression.
+  Expr *V = nullptr;
+  /// 'r' lvalue part of the source atomic expression.
+  Expr *R = nullptr;
+  /// Set by checkForm3, the form in which 'v = x' is in the else branch.
+  bool IsFailOnly = false;
+};
+
+bool OpenMPAtomicCompareCaptureChecker::checkType(ErrorInfoTy &ErrorInfo) {
+  if (!OpenMPAtomicCompareChecker::checkType(ErrorInfo))
+    return false;
+
+  assert(V && "V cannot be nullptr");
+
+  if (!CheckValue(V, ErrorInfo, true))
+    return false;
+
+  if (R && !CheckValue(R, ErrorInfo, true))
+    return false;
+
+  return true;
+}
+
+bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
+                                                   ErrorInfoTy &ErrorInfo) {
+  IsFailOnly = true;
+
+  auto *Then = S->getThen();
+  if (auto *CS = dyn_cast<CompoundStmt>(Then)) {
+    if (CS->body_empty()) {
+      ErrorInfo.Error = ErrorTy::NoStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+    if (CS->size() > 1) {
+      ErrorInfo.Error = ErrorTy::MoreThanOneStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+      return false;
+    }
+    Then = CS->body_front();
+  }
+
+  auto *BO = dyn_cast<BinaryOperator>(Then);
+  if (!BO) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Then->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Then->getSourceRange();
+    return false;
+  }
+  if (BO->getOpcode() != BO_Assign) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = BO->getExprLoc();
+    ErrorInfo.NoteLoc = BO->getOperatorLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+    return false;
+  }
+
+  X = BO->getLHS();
+  D = BO->getRHS();
+
+  // dyn_cast yields null for a non-binary condition, so keep the original
+  // expression around for the diagnostic location.
+  Expr *CondExpr = S->getCond();
+  auto *Cond = dyn_cast<BinaryOperator>(CondExpr);
+  if (!Cond) {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CondExpr->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CondExpr->getSourceRange();
+    return false;
+  }
+
+  if (Cond->getOpcode() != BO_EQ) {
+    ErrorInfo.Error = ErrorTy::NotEQ;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    return false;
+  }
+
+  if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+    E = Cond->getRHS();
+  } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+    E = Cond->getLHS();
+  } else {
+    ErrorInfo.Error = ErrorTy::InvalidComparison;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    return false;
+  }
+
+  // Form 3 needs an else branch; getElse() is null for a bare 'if', and
+  // dyn_cast must not be handed a null pointer.
+  // NOTE(review): reuses NoStmt; a dedicated error kind would presumably
+  // also need an entry in the note_omp_atomic_compare diagnostic.
+  auto *Else = S->getElse();
+  if (!Else) {
+    ErrorInfo.Error = ErrorTy::NoStmt;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+    return false;
+  }
+  if (auto *CS = dyn_cast<CompoundStmt>(Else)) {
+    if (CS->body_empty()) {
+      ErrorInfo.Error = ErrorTy::NoStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+    if (CS->size() > 1) {
+      ErrorInfo.Error = ErrorTy::MoreThanOneStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+      return false;
+    }
+    Else = CS->body_front();
+  }
+
+  auto *ElseBO = dyn_cast<BinaryOperator>(Else);
+  if (!ElseBO) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Else->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Else->getSourceRange();
+    return false;
+  }
+  if (ElseBO->getOpcode() != BO_Assign) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ElseBO->getExprLoc();
+    ErrorInfo.NoteLoc = ElseBO->getOperatorLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = ElseBO->getSourceRange();
+    return false;
+  }
+
+  if (!checkIfTwoExprsAreSame(ContextRef, X, ElseBO->getRHS())) {
+    ErrorInfo.Error = ErrorTy::InvalidAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = ElseBO->getRHS()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        ElseBO->getRHS()->getSourceRange();
+    return false;
+  }
+
+  V = ElseBO->getLHS();
+
+  return checkType(ErrorInfo);
+}
+
+bool OpenMPAtomicCompareCaptureChecker::checkStmt(Stmt *S,
+                                                  ErrorInfoTy &ErrorInfo) {
+  // if(x == e) { x = d; } else { v = x; }
+  if (auto *IS = dyn_cast<IfStmt>(S))
+    return checkForm3(IS, ErrorInfo);
+
+  if (auto *CS = dyn_cast<CompoundStmt>(S)) {
+    if (CS->body_empty()) {
+      ErrorInfo.Error = ErrorTy::NoStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+
+    // { if(x == e) { x = d; } else { v = x; } }
+    if (CS->size() == 1) {
+      auto *IS = dyn_cast<IfStmt>(CS->body_front());
+      if (!IS) {
+        ErrorInfo.Error = ErrorTy::NotIfStmt;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+        return false;
+      }
+
+      return checkForm3(IS, ErrorInfo);
+    } else if (CS->size() == 2) {
+      auto *S1 = CS->body_front();
+      auto *S2 = CS->body_back();
+
+      Stmt *UpdateStmt = nullptr;
+      Stmt *CondUpdateStmt = nullptr;
+
+      if (isa<BinaryOperator>(S1)) {
+        // { v = x; cond-update-stmt }
+        UpdateStmt = S1;
+        CondUpdateStmt = S2;
+      } else {
+        // { cond-update-stmt v = x; }
+        UpdateStmt = S2;
+        CondUpdateStmt = S1;
+      }
+
+      auto CheckCondUpdateStmt = [this, &ErrorInfo](Stmt *CUS) {
+        auto *IS = dyn_cast<IfStmt>(CUS);
+        if (!IS) {
+          ErrorInfo.Error = ErrorTy::NotIfStmt;
+          ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CUS->getBeginLoc();
+          ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CUS->getSourceRange();
+          return false;
+        }
+
+        return checkCondUpdateStmt(IS, ErrorInfo);
+      };
+
+      // Compares against 'x', so it may only run once the cond-update-stmt
+      // has been checked (X is expected to be set by checkCondUpdateStmt).
+      auto CheckUpdateStmt = [this, &ErrorInfo](Stmt *US) {
+        auto *BO = dyn_cast<BinaryOperator>(US);
+        if (!BO) {
+          ErrorInfo.Error = ErrorTy::NotAnAssignment;
+          ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = US->getBeginLoc();
+          ErrorInfo.ErrorRange = ErrorInfo.NoteRange = US->getSourceRange();
+          return false;
+        }
+
+        if (BO->getOpcode() != BO_Assign) {
+          ErrorInfo.Error = ErrorTy::NotAnAssignment;
+          ErrorInfo.ErrorLoc = BO->getExprLoc();
+          ErrorInfo.NoteLoc = BO->getOperatorLoc();
+          ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+          return false;
+        }
+
+        if (!checkIfTwoExprsAreSame(ContextRef, this->X, BO->getRHS())) {
+          ErrorInfo.Error = ErrorTy::InvalidAssignment;
+          ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = BO->getRHS()->getExprLoc();
+          ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+              BO->getRHS()->getSourceRange();
+          return false;
+        }
+
+        this->V = BO->getLHS();
+
+        return true;
+      };
+
+      if (!CheckCondUpdateStmt(CondUpdateStmt))
+        return false;
+      if (!CheckUpdateStmt(UpdateStmt))
+        return false;
+    } else {
+      ErrorInfo.Error = ErrorTy::MoreThanTwoStmts;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+
+    return checkType(ErrorInfo);
+  }
+
+  ErrorInfo.Error = ErrorTy::NotCompoundStmt;
+  ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getBeginLoc();
+  ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+  return false;
+}
 } // namespace
 
 StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
@@ -11785,6 +12073,15 @@
       UE = V = E = X = nullptr;
   } else if (AtomicKind == OMPC_compare) {
     if (IsCompareCapture) {
+      OpenMPAtomicCompareCaptureChecker::ErrorInfoTy ErrorInfo;
+      OpenMPAtomicCompareCaptureChecker Checker(*this);
+      if (!Checker.checkStmt(Body, ErrorInfo)) {
+        Diag(ErrorInfo.ErrorLoc, diag::err_omp_atomic_compare)
+            << ErrorInfo.ErrorRange;
+        Diag(ErrorInfo.NoteLoc, diag::note_omp_atomic_compare)
+            << ErrorInfo.Error << ErrorInfo.NoteRange;
+        return StmtError();
+      }
       // TODO: We don't set X, D, E, etc. here because in code gen we will emit
       // error directly.
     } else {