diff --git a/flang/include/flang/Evaluate/check-expression.h b/flang/include/flang/Evaluate/check-expression.h --- a/flang/include/flang/Evaluate/check-expression.h +++ b/flang/include/flang/Evaluate/check-expression.h @@ -108,5 +108,8 @@ template bool IsErrorExpr(const A &); extern template bool IsErrorExpr(const Expr &); +std::optional CheckStatementFunction( + const Symbol &, const Expr &, FoldingContext &); + } // namespace Fortran::evaluate #endif diff --git a/flang/include/flang/Semantics/expression.h b/flang/include/flang/Semantics/expression.h --- a/flang/include/flang/Semantics/expression.h +++ b/flang/include/flang/Semantics/expression.h @@ -237,6 +237,7 @@ MaybeExpr Analyze(const parser::StructureConstructor &); MaybeExpr Analyze(const parser::InitialDataTarget &); MaybeExpr Analyze(const parser::NullInit &); + MaybeExpr Analyze(const parser::StmtFunctionStmt &); void Analyze(const parser::CallStmt &); const Assignment *Analyze(const parser::AssignmentStmt &); @@ -385,6 +386,7 @@ bool useSavedTypedExprs_{true}; bool inWhereBody_{false}; bool inDataStmtConstant_{false}; + bool inStmtFunctionDefinition_{false}; friend class ArgumentAnalyzer; }; diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h --- a/flang/include/flang/Semantics/symbol.h +++ b/flang/include/flang/Semantics/symbol.h @@ -523,9 +523,9 @@ class Symbol { public: ENUM_CLASS(Flag, - Function, // symbol is a function + Function, // symbol is a function or statement function Subroutine, // symbol is a subroutine - StmtFunction, // symbol is a statement function (Function is set too) + StmtFunction, // symbol is a statement function or result Implicit, // symbol is implicitly typed ImplicitOrError, // symbol must be implicitly typed or it's an error ModFile, // symbol came from .mod file diff --git a/flang/lib/Evaluate/check-expression.cpp b/flang/lib/Evaluate/check-expression.cpp --- a/flang/lib/Evaluate/check-expression.cpp +++
b/flang/lib/Evaluate/check-expression.cpp @@ -871,4 +871,83 @@ template bool IsErrorExpr(const Expr &); +// C1577 +// TODO: Also check C1579 & C1582 here +class StmtFunctionChecker + : public AnyTraverse> { +public: + using Result = std::optional; + using Base = AnyTraverse; + StmtFunctionChecker(const Symbol &sf, FoldingContext &context) + : Base{*this}, sf_{sf}, context_{context} {} + using Base::operator(); + + template Result operator()(const ArrayConstructor &) const { + return parser::Message{sf_.name(), + "Statement function '%s' should not contain an array constructor"_port_en_US, + sf_.name()}; + } + Result operator()(const StructureConstructor &) const { + return parser::Message{sf_.name(), + "Statement function '%s' should not contain a structure constructor"_port_en_US, + sf_.name()}; + } + Result operator()(const TypeParamInquiry &) const { + return parser::Message{sf_.name(), + "Statement function '%s' should not contain a type parameter inquiry"_port_en_US, + sf_.name()}; + } + Result operator()(const ProcedureDesignator &proc) const { + if (const Symbol * symbol{proc.GetSymbol()}) { + const Symbol &ultimate{symbol->GetUltimate()}; + if (const auto *subp{ + ultimate.detailsIf()}) { + if (subp->stmtFunction() && &ultimate.owner() == &sf_.owner()) { + if (ultimate.name().begin() > sf_.name().begin()) { + return parser::Message{sf_.name(), + "Statement function '%s' may not reference another statement function '%s' that is defined later"_err_en_US, + sf_.name(), ultimate.name()}; + } + } + } + if (auto chars{ + characteristics::Procedure::Characterize(proc, context_)}) { + if (!chars->CanBeCalledViaImplicitInterface()) { + return parser::Message{sf_.name(), + "Statement function '%s' should not reference function '%s' that requires an explicit interface"_port_en_US, + sf_.name(), symbol->name()}; + } + } + } + if (proc.Rank() > 0) { + return parser::Message{sf_.name(), + "Statement function '%s' should not reference a function that returns an
array"_port_en_US, + sf_.name()}; + } + return std::nullopt; + } + Result operator()(const ActualArgument &arg) const { + if (const auto *expr{arg.UnwrapExpr()}) { + if (auto result{(*this)(*expr)}) { + return result; + } + if (expr->Rank() > 0 && !UnwrapWholeSymbolOrComponentDataRef(*expr)) { + return parser::Message{sf_.name(), + "Statement function '%s' should not pass an array argument that is not a whole array"_port_en_US, + sf_.name()}; + } + } + return std::nullopt; + } + +private: + const Symbol &sf_; + FoldingContext &context_; +}; + +std::optional CheckStatementFunction( + const Symbol &sf, const Expr &expr, FoldingContext &context) { + return StmtFunctionChecker{sf, context}(expr); +} + } // namespace Fortran::evaluate diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -1272,6 +1272,9 @@ // reference an IMPURE procedure or a VOLATILE variable if (const auto &expr{symbol.get().stmtFunction()}) { for (const SymbolRef &ref : evaluate::CollectSymbols(*expr)) { + if (&*ref == &symbol) { + return false; // error recovery, recursion is caught elsewhere + } if (IsFunction(*ref) && !IsPureProcedure(*ref)) { return false; } diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -965,6 +965,12 @@ } } } + if (const MaybeExpr & stmtFunction{details.stmtFunction()}) { + if (auto msg{evaluate::CheckStatementFunction( + symbol, *stmtFunction, context_.foldingContext())}) { + SayWithDeclaration(symbol, std::move(*msg)); + } + } if (IsElementalProcedure(symbol)) { // See comment on the similar check in CheckProcEntity() if (details.isDummy()) { diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -852,6 +852,12 @@
return std::nullopt; } +MaybeExpr ExpressionAnalyzer::Analyze( + const parser::StmtFunctionStmt &stmtFunc) { + auto restorer{common::ScopedSet(inStmtFunctionDefinition_, true)}; + return Analyze(std::get>(stmtFunc.t)); +} + MaybeExpr ExpressionAnalyzer::Analyze(const parser::InitialDataTarget &x) { return Analyze(x.value()); } @@ -2187,6 +2193,9 @@ context_.SetError(symbol); return false; } + } else if (inStmtFunctionDefinition_) { + semantics::ResolveSpecificationParts(context_, symbol); + CHECK(symbol.has()); } else { // 10.1.11 para 4 Say("The internal function '%s' may not be referenced in a specification expression"_err_en_US, symbol.name()); @@ -3076,7 +3085,9 @@ if (const Symbol *function{ semantics::IsFunctionResultWithSameNameAsFunction(*name->symbol)}) { auto &msg{context.Say(funcRef.v.source, - "Recursive call to '%s' requires a distinct RESULT in its declaration"_err_en_US, + function->flags().test(Symbol::Flag::StmtFunction) + ? "Recursive call to statement function '%s' is not allowed"_err_en_US + : "Recursive call to '%s' requires a distinct RESULT in its declaration"_err_en_US, name->source)}; AttachDeclaration(&msg, *function); name->symbol = const_cast(function); diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -3307,7 +3307,8 @@ // Look up name: provides return type or tells us if it's an array if (auto *symbol{FindSymbol(name)}) { auto *details{symbol->detailsIf()}; - if (!details) { + if (!details || symbol->has() || + symbol->has()) { badStmtFuncFound_ = true; return false; } @@ -3317,7 +3318,7 @@ } if (badStmtFuncFound_) { Say(name, "'%s' has not been declared as an array"_err_en_US); - return true; + return false; } auto &symbol{PushSubprogramScope(name, Symbol::Flag::Function)}; symbol.set(Symbol::Flag::StmtFunction); @@ -3342,10 +3343,9 @@ } resultDetails.set_funcResult(true); Symbol &result{MakeSymbol(name, std::move(resultDetails))}; +
result.flags().set(Symbol::Flag::StmtFunction); ApplyImplicitRules(result); details.set_result(result); - const auto &parsedExpr{std::get>(x.t)}; - Walk(parsedExpr); // The analysis of the expression that constitutes the body of the // statement function is deferred to FinishSpecificationPart() so that // all declarations and implicit typing are complete. @@ -7414,28 +7414,31 @@ // Analyze the bodies of statement functions now that the symbols in this // specification part have been fully declared and implicitly typed. +// (Statement function references are not allowed in specification +// expressions, so it's safe to defer processing their definitions.) void ResolveNamesVisitor::AnalyzeStmtFunctionStmt( const parser::StmtFunctionStmt &stmtFunc) { Symbol *symbol{std::get(stmtFunc.t).symbol}; - if (!symbol || !symbol->has()) { - return; - } - auto &details{symbol->get()}; - auto expr{AnalyzeExpr( - context(), std::get>(stmtFunc.t))}; - if (!expr) { - context().SetError(*symbol); + auto *details{symbol ? symbol->detailsIf() : nullptr}; + if (!details || !symbol->scope()) { return; } - if (auto type{evaluate::DynamicType::From(*symbol)}) { - auto converted{ConvertToType(*type, std::move(*expr))}; - if (!converted) { - context().SetError(*symbol); - return; + // Resolve the symbols on the RHS of the statement function.
+ PushScope(*symbol->scope()); + const auto &parsedExpr{std::get>(stmtFunc.t)}; + Walk(parsedExpr); + PopScope(); + if (auto expr{AnalyzeExpr(context(), stmtFunc)}) { + if (auto type{evaluate::DynamicType::From(*symbol)}) { + if (auto converted{ConvertToType(*type, std::move(*expr))}) { + details->set_stmtFunction(std::move(*converted)); + } + } else { + details->set_stmtFunction(std::move(*expr)); } - details.set_stmtFunction(std::move(*converted)); - } else { - details.set_stmtFunction(std::move(*expr)); + } + if (!details->stmtFunction()) { + context().SetError(*symbol); } } @@ -7825,6 +7828,7 @@ resolver_.CheckBindings(tbps); } } + bool Pre(const parser::StmtFunctionStmt &) { return false; } private: void Init(const parser::Name &name, diff --git a/flang/test/Semantics/stmt-func01.f90 b/flang/test/Semantics/stmt-func01.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/stmt-func01.f90 @@ -0,0 +1,44 @@ +! RUN: %python %S/test_errors.py %s %flang_fc1 +!
C1577 +program main + type t1(k,l) + integer, kind :: k = kind(1) + integer, len :: l = 666 + integer(k) n + end type t1 + interface + pure integer function ifunc() + end function + end interface + type(t1(k=4,l=ifunc())) x1 + !PORTABILITY: Statement function 'sf1' should not contain an array constructor + sf1(n) = sum([(j,j=1,n)]) + type(t1) sf2 + !PORTABILITY: Statement function 'sf2' should not contain a structure constructor + sf2(n) = t1(n) + !PORTABILITY: Statement function 'sf3' should not contain a type parameter inquiry + sf3(n) = x1%l + !ERROR: Recursive call to statement function 'sf4' is not allowed + sf4(n) = sf4(n) + !ERROR: Statement function 'sf5' may not reference another statement function 'sf6' that is defined later + sf5(n) = sf6(n) + real sf7 + !ERROR: Statement function 'sf6' may not reference another statement function 'sf7' that is defined later + sf6(n) = sf7(n) + !PORTABILITY: Statement function 'sf7' should not reference function 'explicit' that requires an explicit interface + sf7(n) = explicit(n) + real :: a(3) = [1., 2., 3.] + !PORTABILITY: Statement function 'sf8' should not pass an array argument that is not a whole array + sf8(n) = sum(a(1:2)) + sf8a(n) = sum(a) ! ok + contains + real function explicit(x,y) + integer, intent(in) :: x + integer, intent(in), optional :: y + explicit = x + end function + pure function arr() + real :: arr(2) + arr = [1., 2.] + end function +end