diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h --- a/flang/include/flang/Semantics/tools.h +++ b/flang/include/flang/Semantics/tools.h @@ -255,22 +255,25 @@ bool ExprTypeKindIsDefault( const SomeExpr &expr, const SemanticsContext &context); -struct GetExprHelper { +class GetExprHelper { +public: + explicit GetExprHelper(SemanticsContext *context) : context_{context} {} + GetExprHelper() : crashIfNoExpr_{true} {} + // Specializations for parse tree nodes that have a typedExpr member. - static const SomeExpr *Get(const parser::Expr &); - static const SomeExpr *Get(const parser::Variable &); - static const SomeExpr *Get(const parser::DataStmtConstant &); - static const SomeExpr *Get(const parser::AllocateObject &); - static const SomeExpr *Get(const parser::PointerObject &); - - template - static const SomeExpr *Get(const common::Indirection &x) { + const SomeExpr *Get(const parser::Expr &); + const SomeExpr *Get(const parser::Variable &); + const SomeExpr *Get(const parser::DataStmtConstant &); + const SomeExpr *Get(const parser::AllocateObject &); + const SomeExpr *Get(const parser::PointerObject &); + + template const SomeExpr *Get(const common::Indirection &x) { return Get(x.value()); } - template static const SomeExpr *Get(const std::optional &x) { + template const SomeExpr *Get(const std::optional &x) { return x ? Get(*x) : nullptr; } - template static const SomeExpr *Get(const T &x) { + template const SomeExpr *Get(const T &x) { static_assert( !parser::HasTypedExpr::value, "explicit Get overload must be added"); if constexpr (ConstraintTrait) { @@ -281,8 +284,25 @@ return nullptr; } } + +private: + SemanticsContext *context_{nullptr}; + const bool crashIfNoExpr_{false}; }; +// If a SemanticsContext is passed, even if null, it is possible for a null +// pointer to be returned in the event of an expression that had fatal errors. +// Use these first two forms in semantics checks for best error recovery. +// If a SemanticsContext is not passed, a missing expression will +// cause a crash. +template +const SomeExpr *GetExpr(SemanticsContext *context, const T &x) { + return GetExprHelper{context}.Get(x); +} +template +const SomeExpr *GetExpr(SemanticsContext &context, const T &x) { + return GetExprHelper{&context}.Get(x); +} template const SomeExpr *GetExpr(const T &x) { return GetExprHelper{}.Get(x); } @@ -292,7 +312,7 @@ const parser::PointerAssignmentStmt &); template std::optional GetIntValue(const T &x) { - if (const auto *expr{GetExpr(x)}) { + if (const auto *expr{GetExpr(nullptr, x)}) { return evaluate::ToInt64(*expr); } else { return std::nullopt; diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp --- a/flang/lib/Semantics/assignment.cpp +++ b/flang/lib/Semantics/assignment.cpp @@ -246,7 +246,7 @@ template void AssignmentContext::PushWhereContext(const A &x) { const auto &expr{std::get(x.t)}; - CheckShape(expr.thing.value().source, GetExpr(expr)); + CheckShape(expr.thing.value().source, GetExpr(context_, expr)); ++whereDepth_; } diff --git a/flang/lib/Semantics/check-allocate.cpp b/flang/lib/Semantics/check-allocate.cpp --- a/flang/lib/Semantics/check-allocate.cpp +++ b/flang/lib/Semantics/check-allocate.cpp @@ -187,7 +187,7 @@ } if (info.gotSource || info.gotMold) { - if (const auto *expr{GetExpr(DEREF(parserSourceExpr))}) { + if (const auto *expr{GetExpr(context, DEREF(parserSourceExpr))}) { parser::CharBlock at{parserSourceExpr->source}; info.sourceExprType = expr->GetType(); if (!info.sourceExprType) { diff --git a/flang/lib/Semantics/check-arithmeticif.cpp b/flang/lib/Semantics/check-arithmeticif.cpp --- a/flang/lib/Semantics/check-arithmeticif.cpp +++ b/flang/lib/Semantics/check-arithmeticif.cpp @@ -25,7 +25,7 @@ // R853 Check for a scalar-numeric-expr // C849 that shall not be of type complex. auto &parsedExpr{std::get(arithmeticIfStmt.t)}; - if (const auto *expr{GetExpr(parsedExpr)}) { + if (const auto *expr{GetExpr(context_, parsedExpr)}) { if (expr->Rank() > 0) { context_.Say(parsedExpr.source, "ARITHMETIC IF expression must be a scalar expression"_err_en_US); diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp --- a/flang/lib/Semantics/check-case.cpp +++ b/flang/lib/Semantics/check-case.cpp @@ -240,7 +240,7 @@ const auto &selectCase{selectCaseStmt.statement}; const auto &selectExpr{ std::get>(selectCase.t).thing}; - const auto *x{GetExpr(selectExpr)}; + const auto *x{GetExpr(context_, selectExpr)}; if (!x) { return; // expression semantics failed } diff --git a/flang/lib/Semantics/check-coarray.cpp b/flang/lib/Semantics/check-coarray.cpp --- a/flang/lib/Semantics/check-coarray.cpp +++ b/flang/lib/Semantics/check-coarray.cpp @@ -64,7 +64,7 @@ template static void CheckTeamType(SemanticsContext &context, const T &x) { - if (const auto *expr{GetExpr(x)}) { + if (const auto *expr{GetExpr(context, x)}) { if (!IsTeamType(evaluate::GetDerivedTypeSpec(expr->GetType()))) { context.Say(parser::FindSourceLocation(x), // C1114 "Team value must be of type TEAM_TYPE from module ISO_FORTRAN_ENV"_err_en_US); diff --git a/flang/lib/Semantics/check-deallocate.cpp b/flang/lib/Semantics/check-deallocate.cpp --- a/flang/lib/Semantics/check-deallocate.cpp +++ b/flang/lib/Semantics/check-deallocate.cpp @@ -36,7 +36,7 @@ [&](const parser::StructureComponent &structureComponent) { // Only perform structureComponent checks it was successfully // analyzed in expression analysis. - if (GetExpr(allocateObject)) { + if (GetExpr(context_, allocateObject)) { if (!IsAllocatableOrPointer( *structureComponent.component.symbol)) { // C932 context_.Say(structureComponent.component.source, diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp --- a/flang/lib/Semantics/check-do-forall.cpp +++ b/flang/lib/Semantics/check-do-forall.cpp @@ -501,7 +501,7 @@ // Semantic checks for the limit and step expressions void CheckDoExpression(const parser::ScalarExpr &scalarExpression) { - if (const SomeExpr * expr{GetExpr(scalarExpression)}) { + if (const SomeExpr * expr{GetExpr(context_, scalarExpression)}) { if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) { // No warnings or errors for type INTEGER const parser::CharBlock &loc{scalarExpression.thing.value().source}; @@ -569,10 +569,10 @@ return symbols; } - static UnorderedSymbolSet GatherSymbolsFromExpression( - const parser::Expr &expression) { + UnorderedSymbolSet GatherSymbolsFromExpression( + const parser::Expr &expression) const { UnorderedSymbolSet result; - if (const auto *expr{GetExpr(expression)}) { + if (const auto *expr{GetExpr(context_, expression)}) { for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) { result.insert(ResolveAssociations(symbol)); } @@ -1022,7 +1022,7 @@ void DoForallChecker::Leave(const parser::Expr &parsedExpr) { CHECK(exprDepth_ > 0); if (--exprDepth_ == 0) { // Only check top level expressions - if (const SomeExpr * expr{GetExpr(parsedExpr)}) { + if (const SomeExpr * expr{GetExpr(context_, parsedExpr)}) { ActualArgumentSet argSet{CollectActualArguments(*expr)}; for (const evaluate::ActualArgumentRef &argRef : argSet) { CheckIfArgIsDoVar(*argRef, parsedExpr.source, context_); diff --git a/flang/lib/Semantics/check-io.h b/flang/lib/Semantics/check-io.h --- a/flang/lib/Semantics/check-io.h +++ b/flang/lib/Semantics/check-io.h @@ -87,7 +87,7 @@ template std::optional GetConstExpr(const T &x) { using DefaultCharConstantType = evaluate::Ascii; - if (const SomeExpr * expr{GetExpr(x)}) { + if (const SomeExpr * expr{GetExpr(context_, x)}) { const auto foldExpr{ evaluate::Fold(context_.foldingContext(), common::Clone(*expr))}; if constexpr (std::is_same_v) { diff --git a/flang/lib/Semantics/check-io.cpp b/flang/lib/Semantics/check-io.cpp --- a/flang/lib/Semantics/check-io.cpp +++ b/flang/lib/Semantics/check-io.cpp @@ -209,7 +209,7 @@ [&](const parser::Label &) { flags_.set(Flag::LabelFmt); }, [&](const parser::Star &) { flags_.set(Flag::StarFmt); }, [&](const parser::Expr &format) { - const SomeExpr *expr{GetExpr(format)}; + const SomeExpr *expr{GetExpr(context_, format)}; if (!expr) { return; } @@ -299,7 +299,7 @@ void IoChecker::Enter(const parser::IdVariable &spec) { SetSpecifier(IoSpecKind::Id); - const auto *expr{GetExpr(spec)}; + const auto *expr{GetExpr(context_, spec)}; if (!expr || !expr->GetType()) { return; } @@ -546,7 +546,7 @@ if (stmt_ == IoStmtKind::Write) { CheckForDefinableVariable(*var, "Internal file"); } - if (const auto *expr{GetExpr(*var)}) { + if (const auto *expr{GetExpr(context_, *var)}) { if (HasVectorSubscript(*expr)) { context_.Say(parser::FindSourceLocation(*var), // C1201 "Internal file must not have a vector subscript"_err_en_US); @@ -577,7 +577,7 @@ void IoChecker::Enter(const parser::OutputItem &item) { flags_.set(Flag::DataList); if (const auto *x{std::get_if(&item.u)}) { - if (const auto *expr{GetExpr(*x)}) { + if (const auto *expr{GetExpr(context_, *x)}) { if (evaluate::IsBOZLiteral(*expr)) { context_.Say(parser::FindSourceLocation(*x), // C7109 "Output item must not be a BOZ literal constant"_err_en_US); diff --git a/flang/lib/Semantics/check-nullify.cpp b/flang/lib/Semantics/check-nullify.cpp --- a/flang/lib/Semantics/check-nullify.cpp +++ b/flang/lib/Semantics/check-nullify.cpp @@ -40,7 +40,7 @@ } }, [&](const parser::StructureComponent &structureComponent) { - if (const auto *checkedExpr{GetExpr(pointerObject)}) { + if (const auto *checkedExpr{GetExpr(context_, pointerObject)}) { if (!IsPointer(*structureComponent.component.symbol)) { // C951 messages.Say(structureComponent.component.source, "component in NULLIFY statement must have the POINTER attribute"_err_en_US); diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -50,8 +50,8 @@ bool Pre(const parser::AssignmentStmt &assignment) { const auto &var{std::get(assignment.t)}; const auto &expr{std::get(assignment.t)}; - const auto *lhs{GetExpr(var)}; - const auto *rhs{GetExpr(expr)}; + const auto *lhs{GetExpr(context_, var)}; + const auto *rhs{GetExpr(context_, expr)}; if (lhs && rhs) { Tristate isDefined{semantics::IsDefinedAssignment( lhs->GetType(), lhs->Rank(), rhs->GetType(), rhs->Rank())}; @@ -65,7 +65,7 @@ } bool Pre(const parser::Expr &expr) { - if (const auto *e{GetExpr(expr)}) { + if (const auto *e{GetExpr(context_, expr)}) { for (const Symbol &symbol : evaluate::CollectSymbols(*e)) { const Symbol &root{GetAssociationRoot(symbol)}; if (IsFunction(root) && @@ -1467,7 +1467,7 @@ if (const auto *name = std::get_if(&dataRef->u)) { const auto &varSymbol = *name->symbol; - if (const auto *e{GetExpr(expr)}) { + if (const auto *e{GetExpr(context_, expr)}) { for (const Symbol &symbol : evaluate::CollectSymbols(*e)) { if (symbol == varSymbol) { diff --git a/flang/lib/Semantics/check-stop.cpp b/flang/lib/Semantics/check-stop.cpp --- a/flang/lib/Semantics/check-stop.cpp +++ b/flang/lib/Semantics/check-stop.cpp @@ -18,7 +18,7 @@ void StopChecker::Enter(const parser::StopStmt &stmt) { const auto &stopCode{std::get>(stmt.t)}; - if (const auto *expr{GetExpr(stopCode)}) { + if (const auto *expr{GetExpr(context_, stopCode)}) { const parser::CharBlock &source{parser::FindSourceLocation(stopCode)}; if (ExprHasTypeCategory(*expr, common::TypeCategory::Integer)) { // C1171 default kind diff --git a/flang/lib/Semantics/data-to-inits.cpp b/flang/lib/Semantics/data-to-inits.cpp --- a/flang/lib/Semantics/data-to-inits.cpp +++ b/flang/lib/Semantics/data-to-inits.cpp @@ -36,13 +36,13 @@ // repetition. template class ValueListIterator { public: - explicit ValueListIterator(const std::list &list) - : end_{list.end()}, at_{list.begin()} { + ValueListIterator(SemanticsContext &context, const std::list &list) + : context_{context}, end_{list.end()}, at_{list.begin()} { SetRepetitionCount(); } bool hasFatalError() const { return hasFatalError_; } bool IsAtEnd() const { return at_ == end_; } - const SomeExpr *operator*() const { return GetExpr(GetConstant()); } + const SomeExpr *operator*() const { return GetExpr(context_, GetConstant()); } parser::CharBlock LocateSource() const { return GetConstant().source; } ValueListIterator &operator++() { if (repetitionsRemaining_ > 0) { @@ -64,6 +64,7 @@ return std::get(GetValue().t); } + SemanticsContext &context_; listIterator end_, at_; ConstantSubscript repetitionsRemaining_{0}; bool hasFatalError_{false}; @@ -93,7 +94,7 @@ public: DataInitializationCompiler(DataInitializations &inits, evaluate::ExpressionAnalyzer &a, const std::list &list) - : inits_{inits}, exprAnalyzer_{a}, values_{list} {} + : inits_{inits}, exprAnalyzer_{a}, values_{a.context(), list} {} const DataInitializations &inits() const { return inits_; } bool HasSurplusValues() const { return !values_.IsAtEnd(); } bool Scan(const parser::DataStmtObject &); @@ -134,7 +135,7 @@ template bool DataInitializationCompiler::Scan(const parser::Variable &var) { - if (const auto *expr{GetExpr(var)}) { + if (const auto *expr{GetExpr(exprAnalyzer_.context(), var)}) { exprAnalyzer_.GetFoldingContext().messages().SetLocation(var.GetSource()); if (InitDesignator(*expr)) { return true; @@ -160,10 +161,13 @@ bool DataInitializationCompiler::Scan(const parser::DataImpliedDo &ido) { const auto &bounds{std::get(ido.t)}; auto name{bounds.name.thing.thing}; - const auto *lowerExpr{GetExpr(bounds.lower.thing.thing)}; - const auto *upperExpr{GetExpr(bounds.upper.thing.thing)}; - const auto *stepExpr{ - bounds.step ? GetExpr(bounds.step->thing.thing) : nullptr}; + const auto *lowerExpr{ + GetExpr(exprAnalyzer_.context(), bounds.lower.thing.thing)}; + const auto *upperExpr{ + GetExpr(exprAnalyzer_.context(), bounds.upper.thing.thing)}; + const auto *stepExpr{bounds.step + ? GetExpr(exprAnalyzer_.context(), bounds.step->thing.thing) + : nullptr}; if (lowerExpr && upperExpr) { // Fold the bounds expressions (again) in case any of them depend // on outer implied DO loops. diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -279,6 +279,18 @@ return true; } + bool Pre(const parser::StmtFunctionStmt &x) { + const auto &parsedExpr{std::get>(x.t)}; + if (const auto *expr{GetExpr(context_, parsedExpr)}) { + for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) { + if (!IsStmtFunctionDummy(symbol)) { + stmtFunctionExprSymbols_.insert(symbol.GetUltimate()); + } + } + } + return true; + } + bool Pre(const parser::OpenMPBlockConstruct &); void Post(const parser::OpenMPBlockConstruct &); diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -385,8 +385,9 @@ // If an analyzed expr or assignment is missing, dump the node and die. template -static void CheckMissingAnalysis(bool absent, const T &x) { - if (absent) { +static void CheckMissingAnalysis( + bool crash, SemanticsContext *context, const T &x) { + if (crash && !(context && context->AnyFatalError())) { std::string buf; llvm::raw_string_ostream ss{buf}; ss << "node has not been analyzed:\n"; @@ -395,34 +396,35 @@ } } -template static const SomeExpr *GetTypedExpr(const T &x) { - CheckMissingAnalysis(!x.typedExpr, x); - return common::GetPtrFromOptional(x.typedExpr->v); -} const SomeExpr *GetExprHelper::Get(const parser::Expr &x) { - return GetTypedExpr(x); + CheckMissingAnalysis(crashIfNoExpr_ && !x.typedExpr, context_, x); + return x.typedExpr ? common::GetPtrFromOptional(x.typedExpr->v) : nullptr; } const SomeExpr *GetExprHelper::Get(const parser::Variable &x) { - return GetTypedExpr(x); + CheckMissingAnalysis(crashIfNoExpr_ && !x.typedExpr, context_, x); + return x.typedExpr ? common::GetPtrFromOptional(x.typedExpr->v) : nullptr; } const SomeExpr *GetExprHelper::Get(const parser::DataStmtConstant &x) { - return GetTypedExpr(x); + CheckMissingAnalysis(crashIfNoExpr_ && !x.typedExpr, context_, x); + return x.typedExpr ? common::GetPtrFromOptional(x.typedExpr->v) : nullptr; } const SomeExpr *GetExprHelper::Get(const parser::AllocateObject &x) { - return GetTypedExpr(x); + CheckMissingAnalysis(crashIfNoExpr_ && !x.typedExpr, context_, x); + return x.typedExpr ? common::GetPtrFromOptional(x.typedExpr->v) : nullptr; } const SomeExpr *GetExprHelper::Get(const parser::PointerObject &x) { - return GetTypedExpr(x); + CheckMissingAnalysis(crashIfNoExpr_ && !x.typedExpr, context_, x); + return x.typedExpr ? common::GetPtrFromOptional(x.typedExpr->v) : nullptr; } const evaluate::Assignment *GetAssignment(const parser::AssignmentStmt &x) { - CheckMissingAnalysis(!x.typedAssignment, x); - return common::GetPtrFromOptional(x.typedAssignment->v); + return x.typedAssignment ? common::GetPtrFromOptional(x.typedAssignment->v) + : nullptr; } const evaluate::Assignment *GetAssignment( const parser::PointerAssignmentStmt &x) { - CheckMissingAnalysis(!x.typedAssignment, x); - return common::GetPtrFromOptional(x.typedAssignment->v); + return x.typedAssignment ? common::GetPtrFromOptional(x.typedAssignment->v) + : nullptr; } const Symbol *FindInterface(const Symbol &symbol) { @@ -998,7 +1000,7 @@ } bool HasCoarray(const parser::Expr &expression) { - if (const auto *expr{GetExpr(expression)}) { + if (const auto *expr{GetExpr(nullptr, expression)}) { for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) { if (evaluate::IsCoarray(symbol)) { return true;