[flang] Replace FindProgramUnitContaining with GetProgramUnitContaining

Reviewer notes. This is patch commentary: git apply and patch skip everything
above the first "diff --git" line.

* tools.h / tools.cpp: the pointer-returning FindProgramUnitContaining
  overloads (Scope and Symbol) are replaced by reference-returning
  GetProgramUnitContaining overloads, and new GetTopLevelUnitContaining
  overloads are added. Per the header note, a "program unit" here includes
  internal subprograms, while a "top-level unit" (the Fortran standard's
  program-unit) does not. GetTopLevelUnitContaining selects the containing
  scope whose parent is the global scope. The Symbol overloads of both
  functions simply forward symbol.owner().
* Both new Scope overloads begin with CHECK(!start.IsGlobal()) and wrap the
  FindScopeContaining result in DEREF. The removed caller-side null checks
  (scope &&, owner &&, subprogram &&) show that the old function could
  return nullptr. Callers now rely on a containing unit always existing.
  NOTE(review): presumably CHECK/DEREF abort on failure. If so, passing the
  global scope, or a symbol owned by it, is now a hard error rather than a
  null result. This matters most for IsUseAssociated (owner &&) and
  IsHostAssociated (subprogram &&), which previously short-circuited to
  false in the null case -- confirm no caller can reach them that way.
* Updated callers:
    - FindContainingSubprogram (check-return.cpp)
    - DeclarationVisitor::IsUplevelReference (resolve-names.cpp), which no
      longer needs its DEREF
    - FindPureProcedureContaining, IsUseAssociated and IsHostAssociated
      (tools.cpp)
  Callers that still produce a pointer take the address of the returned
  reference (&scope, &GetProgramUnitContaining(...)).
* IsUplevelReference and IsUseAssociated now compare Scope references with
  == and != instead of comparing pointers. NOTE(review): this assumes
  Scope's comparison operators test object identity -- confirm in scope.h.

diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -30,9 +30,14 @@
 class Scope;
 class Symbol;
 
+// Note: Here ProgramUnit includes internal subprograms while TopLevelUnit
+// does not. "program-unit" in the Fortran standard matches TopLevelUnit.
+const Scope &GetTopLevelUnitContaining(const Scope &);
+const Scope &GetTopLevelUnitContaining(const Symbol &);
+const Scope &GetProgramUnitContaining(const Scope &);
+const Scope &GetProgramUnitContaining(const Symbol &);
+
 const Scope *FindModuleContaining(const Scope &);
-const Scope *FindProgramUnitContaining(const Scope &);
-const Scope *FindProgramUnitContaining(const Symbol &);
 const Scope *FindPureProcedureContaining(const Scope &);
 const Scope *FindPureProcedureContaining(const Symbol &);
 const Scope *FindPointerComponent(const Scope &);
diff --git a/flang/lib/Semantics/check-return.cpp b/flang/lib/Semantics/check-return.cpp
--- a/flang/lib/Semantics/check-return.cpp
+++ b/flang/lib/Semantics/check-return.cpp
@@ -16,11 +16,10 @@
 namespace Fortran::semantics {
 
 static const Scope *FindContainingSubprogram(const Scope &start) {
-  const Scope *scope{FindProgramUnitContaining(start)};
-  return scope &&
-          (scope->kind() == Scope::Kind::MainProgram ||
-              scope->kind() == Scope::Kind::Subprogram)
-      ? scope
+  const Scope &scope{GetProgramUnitContaining(start)};
+  return scope.kind() == Scope::Kind::MainProgram ||
+          scope.kind() == Scope::Kind::Subprogram
+      ? &scope
       : nullptr;
 }
 
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -5602,11 +5602,11 @@
 }
 
 bool DeclarationVisitor::IsUplevelReference(const Symbol &symbol) {
-  const Scope *symbolUnit{FindProgramUnitContaining(symbol)};
-  if (symbolUnit == FindProgramUnitContaining(currScope())) {
+  const Scope &symbolUnit{GetProgramUnitContaining(symbol)};
+  if (symbolUnit == GetProgramUnitContaining(currScope())) {
     return false;
   } else {
-    Scope::Kind kind{DEREF(symbolUnit).kind()};
+    Scope::Kind kind{symbolUnit.kind()};
     return kind == Scope::Kind::Subprogram || kind == Scope::Kind::MainProgram;
   }
 }
diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp
--- a/flang/lib/Semantics/tools.cpp
+++ b/flang/lib/Semantics/tools.cpp
@@ -37,13 +37,24 @@
   }
 }
 
+const Scope &GetTopLevelUnitContaining(const Scope &start) {
+  CHECK(!start.IsGlobal());
+  return DEREF(FindScopeContaining(
+      start, [](const Scope &scope) { return scope.parent().IsGlobal(); }));
+}
+
+const Scope &GetTopLevelUnitContaining(const Symbol &symbol) {
+  return GetTopLevelUnitContaining(symbol.owner());
+}
+
 const Scope *FindModuleContaining(const Scope &start) {
   return FindScopeContaining(
       start, [](const Scope &scope) { return scope.IsModule(); });
 }
 
-const Scope *FindProgramUnitContaining(const Scope &start) {
-  return FindScopeContaining(start, [](const Scope &scope) {
+const Scope &GetProgramUnitContaining(const Scope &start) {
+  CHECK(!start.IsGlobal());
+  return DEREF(FindScopeContaining(start, [](const Scope &scope) {
     switch (scope.kind()) {
     case Scope::Kind::Module:
     case Scope::Kind::MainProgram:
@@ -53,23 +64,19 @@
     default:
       return false;
     }
-  });
+  }));
 }
 
-const Scope *FindProgramUnitContaining(const Symbol &symbol) {
-  return FindProgramUnitContaining(symbol.owner());
+const Scope &GetProgramUnitContaining(const Symbol &symbol) {
+  return GetProgramUnitContaining(symbol.owner());
 }
 
 const Scope *FindPureProcedureContaining(const Scope &start) {
   // N.B. We only need to examine the innermost containing program unit
   // because an internal subprogram of a pure subprogram must also
   // be pure (C1592).
-  if (const Scope * scope{FindProgramUnitContaining(start)}) {
-    if (IsPureProcedure(*scope)) {
-      return scope;
-    }
-  }
-  return nullptr;
+  const Scope &scope{GetProgramUnitContaining(start)};
+  return IsPureProcedure(scope) ? &scope : nullptr;
 }
 
 Tristate IsDefinedAssignment(
@@ -176,9 +183,9 @@
 }
 
 bool IsUseAssociated(const Symbol &symbol, const Scope &scope) {
-  const Scope *owner{FindProgramUnitContaining(symbol.GetUltimate().owner())};
-  return owner && owner->kind() == Scope::Kind::Module &&
-      owner != FindProgramUnitContaining(scope);
+  const Scope &owner{GetProgramUnitContaining(symbol.GetUltimate().owner())};
+  return owner.kind() == Scope::Kind::Module &&
+      owner != GetProgramUnitContaining(scope);
 }
 
 bool DoesScopeContain(
@@ -203,10 +210,9 @@
 }
 
 bool IsHostAssociated(const Symbol &symbol, const Scope &scope) {
-  const Scope *subprogram{FindProgramUnitContaining(scope)};
-  return subprogram &&
-      DoesScopeContain(
-          FindProgramUnitContaining(FollowHostAssoc(symbol)), *subprogram);
+  const Scope &subprogram{GetProgramUnitContaining(scope)};
+  return DoesScopeContain(
+      &GetProgramUnitContaining(FollowHostAssoc(symbol)), subprogram);
 }
 
 bool IsInStmtFunction(const Symbol &symbol) {