diff --git a/flang/lib/Semantics/check-acc-structure.cpp b/flang/lib/Semantics/check-acc-structure.cpp
--- a/flang/lib/Semantics/check-acc-structure.cpp
+++ b/flang/lib/Semantics/check-acc-structure.cpp
@@ -5,9 +5,11 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 
 #include "check-acc-structure.h"
+#include "flang/Common/template.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 
 #define CHECK_SIMPLE_CLAUSE(X, Y) \
@@ -45,12 +47,41 @@
     llvm::acc::Clause::ACCC_bind, llvm::acc::Clause::ACCC_gang,
     llvm::acc::Clause::ACCC_vector, llvm::acc::Clause::ACCC_worker};
 
+// Collects the symbols of the construct names of the DO constructs found
+// while walking a parse tree (here: the block of an OpenACC construct).
+class ConstructSymbolSet {
+public:
+  template <typename T> bool Pre(const T &) { return true; }
+  template <typename T> void Post(const T &) {}
+  bool Pre(const parser::DoConstruct &doConstruct) {
+    const auto &doStmt{
+        std::get<parser::Statement<parser::NonLabelDoStmt>>(doConstruct.t)};
+    if (const auto *constructName{MaybeGetStmtName(doStmt.statement)}) {
+      // The symbol may be null if name resolution failed; skip the name
+      // rather than dereference a null pointer.
+      if (constructName->symbol) {
+        labels_.insert(*constructName->symbol);
+      }
+    }
+    return true;
+  }
+  SymbolSet labels() const { return labels_; }
+
+private:
+  // Return the (possibly null) name of the statement
+  template <typename A>
+  static const parser::Name *MaybeGetStmtName(const A &a) {
+    return common::GetPtrFromOptional(std::get<0>(a.t));
+  }
+  SymbolSet labels_;
+};
+
 class NoBranchingEnforce {
 public:
-  NoBranchingEnforce(SemanticsContext &context,
+  NoBranchingEnforce(SemanticsContext &context, SymbolSet &&labels,
       parser::CharBlock sourcePosition, llvm::acc::Directive directive)
-      : context_{context}, sourcePosition_{sourcePosition}, currentDirective_{
-                                                               directive} {}
+      : context_{context}, labels_{std::move(labels)},
+        sourcePosition_{sourcePosition}, currentDirective_{directive} {}
   template <typename T> bool Pre(const T &) { return true; }
   template <typename T> void Post(const T &) {}
 
@@ -60,9 +91,32 @@
   }
 
+  // Track how many DO constructs nested in the OpenACC block enclose the
+  // statement currently being visited.
+  bool Pre(const parser::DoConstruct &) {
+    ++numDoConstruct_;
+    return true;
+  }
+  void Post(const parser::DoConstruct &) { --numDoConstruct_; }
+
   void Post(const parser::ReturnStmt &) { emitBranchOutError("RETURN"); }
-  void Post(const parser::ExitStmt &) { emitBranchOutError("EXIT"); }
+  void Post(const parser::ExitStmt &exitStmt) {
+    if (const auto *exitName{common::GetPtrFromOptional(exitStmt.v)}) {
+      // A named EXIT stays inside the construct when it names a DO construct
+      // nested in the block. Skip names whose symbol is null (name
+      // resolution failed). TODO: EXIT may also name other constructs (e.g.
+      // BLOCK, IF); only DO construct names are collected so far.
+      if (exitName->symbol &&
+          labels_.find(*exitName->symbol) == labels_.end()) {
+        emitBranchOutError("EXIT");
+      }
+    } else if (numDoConstruct_ == 0) {
+      // An unnamed EXIT leaves the innermost enclosing DO construct; with no
+      // such DO inside the block, control would leave the OpenACC construct.
+      emitBranchOutError("EXIT");
+    }
+  }
   void Post(const parser::StopStmt &) { emitBranchOutError("STOP"); }
 
 private:
   parser::MessageFixedText GetEnclosingMsg() {
     return "Enclosing block construct"_en_US;
@@ -78,6 +132,8 @@
   }
 
   SemanticsContext &context_;
+  SymbolSet labels_;
+  int numDoConstruct_{0};
   parser::CharBlock currentStatementSourcePosition_;
   parser::CharBlock sourcePosition_;
   llvm::acc::Directive currentDirective_;
@@ -138,7 +194,12 @@
 void AccStructureChecker::CheckNoBranching(const parser::Block &block,
     const llvm::acc::Directive directive,
     const parser::CharBlock &directiveSource) const {
-  NoBranchingEnforce noBranchingEnforce{context_, directiveSource, directive};
+  // Collect the construct names of DO constructs nested in the block; these
+  // are parser::Names (do-construct-name), not statement labels (R611).
+  ConstructSymbolSet doConstructLabelSet;
+  parser::Walk(block, doConstructLabelSet);
+  NoBranchingEnforce noBranchingEnforce{
+      context_, doConstructLabelSet.labels(), directiveSource, directive};
   parser::Walk(block, noBranchingEnforce);
 }
 
diff --git a/flang/test/Semantics/acc-branch.f90 b/flang/test/Semantics/acc-branch.f90
--- a/flang/test/Semantics/acc-branch.f90
+++ b/flang/test/Semantics/acc-branch.f90
@@ -25,7 +25,6 @@
   do i = 1, N
     a(i) = 3.14
     if(i == N-1) THEN
-      !ERROR: EXIT statement is not allowed in a PARALLEL construct
       exit
     end if
   end do
@@ -54,7 +53,6 @@
   do i = 1, N
     a(i) = 3.14
     if(i == N-1) THEN
-      !ERROR: EXIT statement is not allowed in a KERNELS construct
       exit
     end if
   end do
@@ -82,7 +80,6 @@
   do i = 1, N
     a(i) = 3.14
     if(i == N-1) THEN
-      !ERROR: EXIT statement is not allowed in a SERIAL construct
       exit
     end if
   end do