diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -3712,6 +3712,12 @@ std::tuple t; + enum class StmtType { + CaptureStmt, + WriteStmt, + UpdateStmt + } Stmt1Type, + Stmt2Type; }; // ATOMIC diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -238,6 +238,7 @@ void CheckAtomicUpdateStmt(const parser::AssignmentStmt &); void CheckAtomicCaptureStmt(const parser::AssignmentStmt &); void CheckAtomicWriteStmt(const parser::AssignmentStmt &); + void CheckAtomicCaptureConstruct(const parser::OmpAtomicCapture &); void CheckAtomicConstructStructure(const parser::OpenMPAtomicConstruct &); void CheckDistLinear(const parser::OpenMPLoopConstruct &x); void CheckSIMDNest(const parser::OpenMPConstruct &x); @@ -283,6 +284,10 @@ LastType }; int directiveNest_[LastType + 1] = {0}; + + // Flags used exclusively for atomic capture semantic checks + bool shouldSuppressOutput = false; // check semantics without reporting output + bool isValidConstruct = true; // false once the current atomic statement check has failed }; } // namespace Fortran::semantics #endif // FORTRAN_SEMANTICS_CHECK_OMP_STRUCTURE_H_ diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -1508,11 +1508,14 @@ std::get_if(&designator.value().u); const Fortran::parser::Name *name = dataRef ? 
std::get_if(&dataRef->u) : nullptr; - if (name) - context_.Say(name->source, - "%s must not have ALLOCATABLE " - "attribute"_err_en_US, - name->ToString()); + if (name) { + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(name->source, + "%s must not have ALLOCATABLE " + "attribute"_err_en_US, + name->ToString()); + } } } @@ -1526,11 +1529,13 @@ const Symbol &varSymbol = evaluate::GetSymbolVector(*v).front(); for (const Symbol &symbol : evaluate::GetSymbolVector(*e)) { if (varSymbol == symbol) { - context_.Say(expr.source, - "RHS expression " - "on atomic assignment statement" - " cannot access '%s'"_err_en_US, - var.GetSource().ToString()); + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(expr.source, + "RHS expression " + "on atomic assignment statement" + " cannot access '%s'"_err_en_US, + var.GetSource().ToString()); } } } @@ -1543,16 +1548,22 @@ const auto *e{GetExpr(context_, expr)}; const auto *v{GetExpr(context_, var)}; if (e && v) { - if (e->Rank() != 0) - context_.Say(expr.source, - "Expected scalar expression " - "on the RHS of atomic assignment " - "statement"_err_en_US); - if (v->Rank() != 0) - context_.Say(var.GetSource(), - "Expected scalar variable " - "on the LHS of atomic assignment " - "statement"_err_en_US); + if (e->Rank() != 0) { + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(expr.source, + "Expected scalar expression " + "on the RHS of atomic assignment " + "statement"_err_en_US); + } + if (v->Rank() != 0) { + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(var.GetSource(), + "Expected scalar variable " + "on the LHS of atomic assignment " + "statement"_err_en_US); + } } } @@ -1575,10 +1586,12 @@ const auto &exprRight{std::get<1>(node.t)}; if ((exprLeft.value().source.ToString() != variableName) && (exprRight.value().source.ToString() != variableName)) { - context_.Say(variable.GetSource(), - "Atomic update statement should be of form " - "`%s = %s 
operator expr` OR `%s = expr operator %s`"_err_en_US, - variableName, variableName, variableName, variableName); + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(variable.GetSource(), + "Atomic update statement should be of form " + "`%s = %s operator expr` OR `%s = expr operator %s`"_err_en_US, + variableName, variableName, variableName, variableName); } return common::HasMember; } @@ -1589,6 +1602,7 @@ const parser::AssignmentStmt &assignmentStmt) { const auto &var{std::get(assignmentStmt.t)}; const auto &expr{std::get(assignmentStmt.t)}; + isValidConstruct = true; common::visit( common::visitors{ [&](const common::Indirection &designator) { @@ -1597,18 +1611,23 @@ const auto *name = dataRef ? std::get_if(&dataRef->u) : nullptr; - if (name && IsAllocatable(*name->symbol)) - context_.Say(name->source, - "%s must not have ALLOCATABLE " - "attribute"_err_en_US, - name->ToString()); + if (name && IsAllocatable(*name->symbol)) { + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(name->source, + "%s must not have ALLOCATABLE " + "attribute"_err_en_US, + name->ToString()); + } }, [&](const auto &) { // Anything other than a `parser::Designator` is not allowed - context_.Say(expr.source, - "Expected scalar variable " - "of intrinsic type on RHS of atomic " - "assignment statement"_err_en_US); + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(expr.source, + "Expected scalar variable " + "of intrinsic type on RHS of atomic " + "assignment statement"_err_en_US); }}, expr.u); ErrIfLHSAndRHSSymbolsMatch(var, expr); @@ -1619,6 +1638,7 @@ const parser::AssignmentStmt &assignmentStmt) { const auto &var{std::get(assignmentStmt.t)}; const auto &expr{std::get(assignmentStmt.t)}; + isValidConstruct = true; ErrIfAllocatableVariable(var); ErrIfLHSAndRHSSymbolsMatch(var, expr); ErrIfNonScalarAssignmentStmt(var, expr); @@ -1628,6 +1648,7 @@ const parser::AssignmentStmt &assignment) { const auto 
&expr{std::get(assignment.t)}; const auto &var{std::get(assignment.t)}; + isValidConstruct = true; common::visit( common::visitors{ [&](const common::Indirection &x) { @@ -1639,9 +1660,11 @@ !(name->source == "max" || name->source == "min" || name->source == "iand" || name->source == "ior" || name->source == "ieor")) { - context_.Say(expr.source, - "Invalid intrinsic procedure name in " - "OpenMP ATOMIC (UPDATE) statement"_err_en_US); + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(expr.source, + "Invalid intrinsic procedure name in " + "OpenMP ATOMIC (UPDATE) statement"_err_en_US); } else if (name) { bool foundMatch{false}; if (auto varDesignatorIndirection = @@ -1666,17 +1689,21 @@ } } if (!foundMatch) { - context_.Say(expr.source, - "Atomic update variable '%s' not found in the " - "argument list of intrinsic procedure"_err_en_US, - var.GetSource().ToString()); + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(expr.source, + "Atomic update variable '%s' not found in the " + "argument list of intrinsic procedure"_err_en_US, + var.GetSource().ToString()); } } }, [&](const auto &x) { if (!IsOperatorValid(x, var)) { - context_.Say(expr.source, - "Invalid operator in OpenMP ATOMIC (UPDATE) statement"_err_en_US); + isValidConstruct = false; + if (!shouldSuppressOutput) + context_.Say(expr.source, + "Invalid operator in OpenMP ATOMIC (UPDATE) statement"_err_en_US); } }, }, @@ -1684,6 +1711,98 @@ ErrIfAllocatableVariable(var); }
+// Checks the two assignment statements of an ATOMIC CAPTURE construct. They
+// must form one of the pairs [update-stmt, capture-stmt],
+// [capture-stmt, update-stmt] or [capture-stmt, write-stmt], and the variable
+// read by the capture-stmt must be the one the other statement assigns.
+void OmpStructureChecker::CheckAtomicCaptureConstruct(
+    const parser::OmpAtomicCapture &atomicCaptureConstruct) {
+  const parser::AssignmentStmt &stmt1 =
+      std::get<3>(atomicCaptureConstruct.t).v.statement;
+  const parser::AssignmentStmt &stmt2 =
+      std::get<4>(atomicCaptureConstruct.t).v.statement;
+  const auto &stmt1Var{std::get<parser::Variable>(stmt1.t)};
+  const auto &stmt1Expr{std::get<parser::Expr>(stmt1.t)};
+  const auto &stmt2Var{std::get<parser::Variable>(stmt2.t)};
+  const auto &stmt2Expr{std::get<parser::Expr>(stmt2.t)};
+
+  enum class StmtType { Undefined, CaptureStmt, WriteStmt, UpdateStmt };
+  // Start from Undefined so a statement that no check accepts is never read
+  // as an indeterminate value.
+  StmtType stmt1Type{StmtType::Undefined};
+  StmtType stmt2Type{StmtType::Undefined};
+
+  // Classify both statements by running the per-statement checks with their
+  // diagnostics muted; each check leaves its verdict in isValidConstruct.
+  shouldSuppressOutput = true;
+  CheckAtomicCaptureStmt(stmt1);
+  if (isValidConstruct) {
+    stmt1Type = StmtType::CaptureStmt;
+    // Try [capture-stmt, update-stmt] first, then [capture-stmt, write-stmt]
+    CheckAtomicUpdateStmt(stmt2);
+    if (isValidConstruct) {
+      stmt2Type = StmtType::UpdateStmt;
+    } else {
+      CheckAtomicWriteStmt(stmt2);
+      if (isValidConstruct) {
+        stmt2Type = StmtType::WriteStmt;
+      }
+    }
+  } else {
+    CheckAtomicUpdateStmt(stmt1);
+    if (isValidConstruct) {
+      stmt1Type = StmtType::UpdateStmt;
+      // Only [update-stmt, capture-stmt] starts with an update-stmt
+      CheckAtomicCaptureStmt(stmt2);
+      if (isValidConstruct) {
+        stmt2Type = StmtType::CaptureStmt;
+      }
+    }
+  }
+  shouldSuppressOutput = false;
+
+  if (stmt1Type == StmtType::Undefined || stmt2Type == StmtType::Undefined) {
+    context_.Say(stmt1Expr.source,
+        "Invalid atomic capture construct statements. "
+        "Expected one of [update-stmt, capture-stmt], "
+        "[capture-stmt, update-stmt], or [capture-stmt, write-stmt]"_err_en_US);
+    return;
+  }
+
+  if (stmt1Type == StmtType::CaptureStmt) {
+    // Variable captured in stmt1 should be assigned in stmt2
+    const auto *e{GetExpr(context_, stmt1Expr)};
+    const auto *v{GetExpr(context_, stmt2Var)};
+    if (e && v) {
+      const Symbol &stmt1ExprSymbol = evaluate::GetSymbolVector(*e).front();
+      const Symbol &stmt2VarSymbol = evaluate::GetSymbolVector(*v).front();
+      if (stmt2VarSymbol != stmt1ExprSymbol) {
+        context_.Say(stmt1Expr.source,
+            "Captured variable %s "
+            "expected to be assigned in statement 2 of "
+            "atomic capture construct"_err_en_US,
+            stmt1ExprSymbol.name().ToString());
+      }
+    }
+  } else {
+    // stmt1 is an update-stmt: the variable it updates should be captured
+    // in stmt2
+    const auto *e{GetExpr(context_, stmt2Expr)};
+    const auto *v{GetExpr(context_, stmt1Var)};
+    if (e && v) {
+      const Symbol &stmt1VarSymbol = evaluate::GetSymbolVector(*v).front();
+      const Symbol &stmt2ExprSymbol = evaluate::GetSymbolVector(*e).front();
+      if (stmt1VarSymbol != stmt2ExprSymbol) {
+        context_.Say(stmt1Var.GetSource(),
+            "Updated variable %s "
+            "expected to be captured in statement 2 of "
+            "atomic capture construct"_err_en_US,
+            stmt1Var.GetSource().ToString());
+      }
+    }
+  }
+}
+
 void OmpStructureChecker::CheckAtomicMemoryOrderClause( const parser::OmpAtomicClauseList *leftHandClauseList, const parser::OmpAtomicClauseList *rightHandClauseList) { @@ -1767,15 +1886,15 @@ atomicWrite.t) .statement); }, - [&](const auto &atomicConstruct) { - const auto &dir{std::get(atomicConstruct.t)}; + [&](const parser::OmpAtomicCapture &atomicCapture) { + const auto &dir{std::get(atomicCapture.t)}; PushContextAndClauseSets( dir.source, llvm::omp::Directive::OMPD_atomic); - CheckAtomicMemoryOrderClause(&std::get<0>(atomicConstruct.t), - &std::get<2>(atomicConstruct.t)); + CheckAtomicMemoryOrderClause( + &std::get<0>(atomicCapture.t), 
&std::get<2>(atomicCapture.t)); CheckHintClause( - &std::get<0>(atomicConstruct.t), - &std::get<2>(atomicConstruct.t)); + &std::get<0>(atomicCapture.t), &std::get<2>(atomicCapture.t)); + CheckAtomicCaptureConstruct(atomicCapture); }, }, x.u); diff --git a/flang/test/Semantics/OpenMP/omp-atomic-assignment-stmt.f90 b/flang/test/Semantics/OpenMP/omp-atomic-assignment-stmt.f90 --- a/flang/test/Semantics/OpenMP/omp-atomic-assignment-stmt.f90 +++ b/flang/test/Semantics/OpenMP/omp-atomic-assignment-stmt.f90 @@ -3,7 +3,7 @@ program sample use omp_lib - integer :: x, v + integer :: x, v, b integer :: y(10) integer, allocatable :: k integer a(10) @@ -83,4 +83,44 @@ !$omp atomic write !ERROR: Expected scalar variable on the LHS of atomic assignment statement a = x + + !$omp atomic capture + v = x + x = x + 1 + !$omp end atomic + + !$omp atomic release capture + !ERROR: Invalid atomic capture construct statements. Expected one of [update-stmt, capture-stmt], [capture-stmt, update-stmt], or [capture-stmt, write-stmt] + v = x + x = b + (x * 1) + !$omp end atomic + + !$omp atomic capture hint(0) + v = x + x = 1 + !$omp end atomic + + !$omp atomic capture + !ERROR: Invalid atomic capture construct statements. Expected one of [update-stmt, capture-stmt], [capture-stmt, update-stmt], or [capture-stmt, write-stmt] + v = 1 + x = 4 + !$omp end atomic + + !$omp atomic capture + !ERROR: Captured variable x expected to be assigned in statement 2 of atomic capture construct + v = x + b = b + 1 + !$omp end atomic + + !$omp atomic capture + !ERROR: Captured variable x expected to be assigned in statement 2 of atomic capture construct + v = x + b = 10 + !$omp end atomic + + !$omp atomic capture + !ERROR: Updated variable x expected to be captured in statement 2 of atomic capture construct + x = x + 10 + v = b + !$omp end atomic end program