diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -10974,6 +10974,16 @@
NotScalar,
/// Not an integer.
NotInteger,
+ /// Not an equality operator.
+ NotEQ,
+ /// Invalid assignment (not of the form v = x).
+ InvalidAssignment,
+ /// Not an if statement.
+ NotIfStmt,
+ /// More than two statements in a compound statement.
+ MoreThanTwoStmts,
+ /// Not a compound statement.
+ NotCompoundStmt,
/// No error.
NoError,
};
@@ -10997,7 +11007,7 @@
Expr *getCond() const { return C; }
bool isXBinopExpr() const { return IsXBinopExpr; }
-private:
+protected: // Lets OpenMPAtomicCompareCaptureChecker reuse these members.
/// Reference to ASTContext
ASTContext &ContextRef;
/// 'x' lvalue part of the source atomic expression.
@@ -11024,6 +11034,35 @@
/// Check if all captured values have right type.
bool checkType(ErrorInfoTy &ErrorInfo) const;
+
+  /// Validate a single operand of the atomic construct.
+  ///
+  /// \param E the expression to validate.
+  /// \param ErrorInfo receives the error kind, location and range (used for
+  ///        both the error and the note) when validation fails.
+  /// \param ShouldBeLValue when true, \a E must be an lvalue.
+  /// \returns true if \a E passes; the scalar/integer type checks are skipped
+  ///          for instantiation-dependent expressions.
+  static bool CheckValue(const Expr *E, ErrorInfoTy &ErrorInfo,
+                         bool ShouldBeLValue) {
+    // Record Kind against E and report failure.
+    auto Fail = [E, &ErrorInfo](ErrorTy Kind) {
+      ErrorInfo.Error = Kind;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
+      return false;
+    };
+
+    if (ShouldBeLValue && !E->isLValue())
+      return Fail(ErrorTy::XNotLValue);
+
+    // The type of a dependent expression is not known until instantiation.
+    if (E->isInstantiationDependent())
+      return true;
+
+    const QualType QTy = E->getType();
+    if (!QTy->isScalarType())
+      return Fail(ErrorTy::NotScalar);
+    if (!QTy->isIntegerType())
+      return Fail(ErrorTy::NotInteger);
+
+    return true;
+  }
};
bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
@@ -11206,41 +11245,13 @@
// 'x' and 'e' cannot be nullptr
assert(X && E && "X and E cannot be nullptr");
- auto CheckValue = [&ErrorInfo](const Expr *E, bool ShouldBeLValue) {
- if (ShouldBeLValue && !E->isLValue()) {
- ErrorInfo.Error = ErrorTy::XNotLValue;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
- return false;
- }
-
- if (!E->isInstantiationDependent()) {
- QualType QTy = E->getType();
- if (!QTy->isScalarType()) {
- ErrorInfo.Error = ErrorTy::NotScalar;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
- return false;
- }
-
- if (!QTy->isIntegerType()) {
- ErrorInfo.Error = ErrorTy::NotInteger;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = E->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = E->getSourceRange();
- return false;
- }
- }
-
- return true;
- };
-
- if (!CheckValue(X, true))
+ if (!CheckValue(X, ErrorInfo, true)) // Only 'x' is required to be an lvalue.
return false;
- if (!CheckValue(E, false))
+ if (!CheckValue(E, ErrorInfo, false))
return false;
- if (D && !CheckValue(D, false))
+ if (D && !CheckValue(D, ErrorInfo, false))
return false;
return true;
@@ -11288,6 +11299,262 @@
return checkType(ErrorInfo);
}
+
+/// Checker for the statement associated with '#pragma omp atomic compare
+/// capture'. The shapes it is meant to accept are:
+///   if(x == e) { x = d; } else { v = x; }          (form 3)
+///   { if(x == e) { x = d; } else { v = x; } }
+///   { v = x; cond-update-stmt }
+///   { cond-update-stmt v = x; }
+/// On success the base-class accessors expose 'x', 'e' and 'd', and getV()
+/// exposes the capture target 'v'.
+class OpenMPAtomicCompareCaptureChecker final
+    : public OpenMPAtomicCompareChecker {
+public:
+  explicit OpenMPAtomicCompareCaptureChecker(Sema &S)
+      : OpenMPAtomicCompareChecker(S) {}
+
+  Expr *getV() const { return V; }
+  Expr *getR() const { return R; }
+  bool isFailOnly() const { return IsFailOnly; }
+
+  /// Check if statement \a S is valid for atomic compare capture. On failure
+  /// \a ErrorInfo describes the problem.
+  bool checkStmt(Stmt *S, ErrorInfoTy &ErrorInfo);
+
+private:
+  /// Check 'x', 'e' and 'd' via the base class, then the captured 'v' and,
+  /// if present, 'r'.
+  bool checkType(ErrorInfoTy &ErrorInfo);
+
+  /// Check if it is valid 'if(x == e) { x = d; } else { v = x; }' (form 3)
+  bool checkForm3(IfStmt *IS, ErrorInfoTy &ErrorInfo);
+
+  /// 'v' lvalue part of the source atomic expression.
+  Expr *V = nullptr;
+  /// 'r' lvalue part of the source atomic expression. None of the forms
+  /// handled by checkStmt assign it yet, so it stays nullptr.
+  Expr *R = nullptr;
+  /// True if 'v' is only written when the comparison fails, i.e. it is
+  /// assigned in the 'else' branch (form 3).
+  bool IsFailOnly = false;
+};
+
+/// Validate every operand of an 'atomic compare capture' statement: 'x', 'e'
+/// and 'd' through the base class, then the capture targets 'v' (mandatory)
+/// and 'r' (optional), both of which must be lvalues of scalar integer type.
+bool OpenMPAtomicCompareCaptureChecker::checkType(ErrorInfoTy &ErrorInfo) {
+  const bool BaseOK = OpenMPAtomicCompareChecker::checkType(ErrorInfo);
+  if (!BaseOK)
+    return false;
+
+  assert(V && "V cannot be nullptr");
+
+  // Short-circuiting keeps the original order: 'v' first, then 'r' if set.
+  return CheckValue(V, ErrorInfo, /*ShouldBeLValue=*/true) &&
+         (!R || CheckValue(R, ErrorInfo, /*ShouldBeLValue=*/true));
+}
+
+/// Check form 3 of 'atomic compare capture':
+///   if(x == e) { x = d; } else { v = x; }
+/// Braces around either branch are optional, but each branch must hold
+/// exactly one plain assignment. Records 'x', 'e', 'd' and 'v', and sets
+/// IsFailOnly because 'v' is only written when the comparison fails.
+bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
+                                                   ErrorInfoTy &ErrorInfo) {
+  IsFailOnly = true;
+
+  // Unwrap '{ stmt }' to 'stmt' and require 'stmt' to be an '=' assignment.
+  // Returns nullptr, with ErrorInfo filled in, on failure.
+  auto GetSingleAssignment = [&ErrorInfo,
+                              S](Stmt *Branch) -> BinaryOperator * {
+    if (auto *CS = dyn_cast<CompoundStmt>(Branch)) {
+      if (CS->body_empty()) {
+        ErrorInfo.Error = ErrorTy::NoStmt;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+        return nullptr;
+      }
+      if (CS->size() > 1) {
+        ErrorInfo.Error = ErrorTy::MoreThanOneStmt;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+        return nullptr;
+      }
+      Branch = CS->body_front();
+    }
+
+    auto *BO = dyn_cast<BinaryOperator>(Branch);
+    if (!BO) {
+      ErrorInfo.Error = ErrorTy::NotAnAssignment;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Branch->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Branch->getSourceRange();
+      return nullptr;
+    }
+    if (BO->getOpcode() != BO_Assign) {
+      ErrorInfo.Error = ErrorTy::NotAnAssignment;
+      ErrorInfo.ErrorLoc = BO->getExprLoc();
+      ErrorInfo.NoteLoc = BO->getOperatorLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+      return nullptr;
+    }
+    return BO;
+  };
+
+  // then-branch: 'x = d;'
+  BinaryOperator *ThenBO = GetSingleAssignment(S->getThen());
+  if (!ThenBO)
+    return false;
+
+  X = ThenBO->getLHS();
+  D = ThenBO->getRHS();
+
+  Expr *CondExpr = S->getCond();
+  auto *Cond = dyn_cast<BinaryOperator>(CondExpr);
+  if (!Cond) {
+    // Cond is null here; report the error on the original condition.
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CondExpr->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CondExpr->getSourceRange();
+    return false;
+  }
+
+  if (Cond->getOpcode() != BO_EQ) {
+    ErrorInfo.Error = ErrorTy::NotEQ;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    return false;
+  }
+
+  // Accept both 'x == e' and 'e == x'.
+  if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+    E = Cond->getRHS();
+  } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+    E = Cond->getLHS();
+  } else {
+    ErrorInfo.Error = ErrorTy::InvalidComparison;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    return false;
+  }
+
+  // Form 3 requires an 'else' branch holding 'v = x;'. Without this check a
+  // null statement would be handed to dyn_cast below.
+  Stmt *Else = S->getElse();
+  if (!Else) {
+    ErrorInfo.Error = ErrorTy::NoStmt;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+    return false;
+  }
+
+  BinaryOperator *ElseBO = GetSingleAssignment(Else);
+  if (!ElseBO)
+    return false;
+
+  if (!checkIfTwoExprsAreSame(ContextRef, X, ElseBO->getRHS())) {
+    ErrorInfo.Error = ErrorTy::InvalidAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = ElseBO->getRHS()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        ElseBO->getRHS()->getSourceRange();
+    return false;
+  }
+
+  V = ElseBO->getLHS();
+
+  return checkType(ErrorInfo);
+}
+
+/// Check if statement \a S is a valid 'atomic compare capture' statement.
+///
+/// Accepted shapes:
+///   if(x == e) { x = d; } else { v = x; }          (form 3)
+///   { if(x == e) { x = d; } else { v = x; } }
+///   { v = x; cond-update-stmt }
+///   { cond-update-stmt v = x; }
+/// On failure \a ErrorInfo describes the first problem found.
+bool OpenMPAtomicCompareCaptureChecker::checkStmt(Stmt *S,
+                                                  ErrorInfoTy &ErrorInfo) {
+  // if(x == e) { x = d; } else { v = x; }
+  if (auto *IS = dyn_cast<IfStmt>(S))
+    return checkForm3(IS, ErrorInfo);
+
+  auto *CS = dyn_cast<CompoundStmt>(S);
+  if (!CS) {
+    ErrorInfo.Error = ErrorTy::NotCompoundStmt;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+    return false;
+  }
+
+  if (CS->body_empty()) {
+    ErrorInfo.Error = ErrorTy::NoStmt;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+    return false;
+  }
+
+  // { if(x == e) { x = d; } else { v = x; } }
+  if (CS->size() == 1) {
+    auto *IS = dyn_cast<IfStmt>(CS->body_front());
+    if (!IS) {
+      ErrorInfo.Error = ErrorTy::NotIfStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+    return checkForm3(IS, ErrorInfo);
+  }
+
+  if (CS->size() != 2) {
+    ErrorInfo.Error = ErrorTy::MoreThanTwoStmts;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+    return false;
+  }
+
+  // Exactly two statements: 'v = x;' and a cond-update-stmt, in either
+  // order. A binary operator in front means { v = x; cond-update-stmt },
+  // otherwise { cond-update-stmt v = x; }.
+  Stmt *S1 = CS->body_front();
+  Stmt *S2 = CS->body_back();
+  const bool CaptureFirst = isa<BinaryOperator>(S1);
+  Stmt *UpdateStmt = CaptureFirst ? S1 : S2;
+  Stmt *CondUpdateStmt = CaptureFirst ? S2 : S1;
+
+  // Check the cond-update-stmt before 'v = x;': 'v = x;' has to be compared
+  // against 'x', and 'x' is only known once the base-class check has run.
+  // NOTE(review): assumes checkCondUpdateStmt records X on success (its body
+  // is not part of this hunk); the assert below makes that explicit.
+  auto *IS = dyn_cast<IfStmt>(CondUpdateStmt);
+  if (!IS) {
+    ErrorInfo.Error = ErrorTy::NotIfStmt;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CondUpdateStmt->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        CondUpdateStmt->getSourceRange();
+    return false;
+  }
+  if (!checkCondUpdateStmt(IS, ErrorInfo))
+    return false;
+  assert(X && "cond-update-stmt check should have set X");
+
+  auto *BO = dyn_cast<BinaryOperator>(UpdateStmt);
+  if (!BO) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = UpdateStmt->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = UpdateStmt->getSourceRange();
+    return false;
+  }
+  if (BO->getOpcode() != BO_Assign) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = BO->getExprLoc();
+    ErrorInfo.NoteLoc = BO->getOperatorLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+    return false;
+  }
+
+  // Compare structurally: 'x' is spelled twice in the source, so the two
+  // occurrences are distinct AST nodes and pointer equality never holds.
+  if (!checkIfTwoExprsAreSame(ContextRef, X, BO->getRHS())) {
+    ErrorInfo.Error = ErrorTy::InvalidAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = BO->getRHS()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getRHS()->getSourceRange();
+    return false;
+  }
+
+  V = BO->getLHS();
+
+  return checkType(ErrorInfo);
+}
} // namespace
StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef Clauses,
@@ -11785,6 +12052,15 @@
UE = V = E = X = nullptr;
} else if (AtomicKind == OMPC_compare) {
if (IsCompareCapture) {
+ OpenMPAtomicCompareCaptureChecker::ErrorInfoTy ErrorInfo;
+ OpenMPAtomicCompareCaptureChecker Checker(*this);
+ if (!Checker.checkStmt(Body, ErrorInfo)) { // Emit the error, then a note selected by ErrorInfo.Error.
+ Diag(ErrorInfo.ErrorLoc, diag::err_omp_atomic_compare)
+ << ErrorInfo.ErrorRange;
+ Diag(ErrorInfo.NoteLoc, diag::note_omp_atomic_compare)
+ << ErrorInfo.Error << ErrorInfo.NoteRange;
+ return StmtError();
+ }
// TODO: We don't set X, D, E, etc. here because in code gen we will emit
// error directly.
} else {