[flang] Guard Symbol::Rank()/GetType() and function-result characterization against symbol cycles

A function whose result is a procedure pointer declared with the function's
own interface (the new resolve102.f90 case: `procedure(foo), pointer :: r`
as the result of `foo`) presumably lets these queries walk a cycle of symbols:
Symbol::Rank(), Symbol::GetType() and FunctionResult::Characterize().

* symbol.h: Rank() and GetType() now delegate to RankImpl()/GetTypeImpl().
  These carry a VisitedSymbols set. A symbol that is already in the set
  yields 0 (rank) or nullptr (type) instead of being visited again.
* characteristics.cpp: FunctionResult::Characterize() now delegates to a new
  static CharacterizeFunctionResult(). That function passes the `seenProcs`
  set on to CharacterizeProcedure() when the result is not an object entity.
  As a result, the procedure-pointer result goes through the same `seenProcs`
  bookkeeping as other procedures. The SubprogramDetails path in
  CharacterizeProcedure() now calls CharacterizeFunctionResult() with its
  `seenProcs` instead of calling FunctionResult::Characterize().
* check-declarations.cpp: in addSpecifics, a non-generic symbol flagged as a
  Function is passed to Characterize(). The result is discarded, so this is
  presumably done for the diagnostics it emits. Confirm against
  CheckHelper::Characterize.
* resolve102.f90: the new test expects "recursively defined" errors for both
  'foo' and 'r'.

Two formatting-only hunks in characteristics.cpp have been dropped. They
rewrote `interface{proc.interface()}` as `interface { proc.interface() }`.
They are unrelated to the fix and depart from the brace-init style used in
the rest of the change. Both had a net line delta of zero, so the remaining
hunk headers are unaffected.

NOTE(review): this copy of the patch is damaged and cannot be applied as-is.
Its line breaks are collapsed, and angle-bracketed text (template arguments,
#include targets) appears to have been stripped, e.g. `llvm::SmallSet;` and
`#include #include #include`. Regenerate it from the source tree before use.

diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h --- a/flang/include/flang/Semantics/symbol.h +++ b/flang/include/flang/Semantics/symbol.h @@ -15,6 +15,8 @@ #include "flang/Common/reference.h" #include "flang/Common/visit.h" #include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/SmallSet.h" + #include #include #include @@ -638,36 +640,8 @@ bool operator!=(const Symbol &that) const { return !(*this == that); } int Rank() const { - return common::visit( - common::visitors{ - [](const SubprogramDetails &sd) { - return sd.isFunction() ? sd.result().Rank() : 0; - }, - [](const GenericDetails &) { - return 0; /*TODO*/ - }, - [](const ProcBindingDetails &x) { return x.symbol().Rank(); }, - [](const UseDetails &x) { return x.symbol().Rank(); }, - [](const HostAssocDetails &x) { return x.symbol().Rank(); }, - [](const ObjectEntityDetails &oed) { return oed.shape().Rank(); }, - [](const ProcEntityDetails &ped) { - const Symbol *iface{ped.interface().symbol()}; - return iface ? iface->Rank() : 0; - }, - [](const AssocEntityDetails &aed) { - if (const auto &expr{aed.expr()}) { - if (auto assocRank{aed.rank()}) { - return *assocRank; - } else { - return expr->Rank(); - } - } else { - return 0; - } - }, - [](const auto &) { return 0; }, - }, - details_); + VisitedSymbols visited; + return RankImpl(visited); } int Corank() const { @@ -718,6 +692,49 @@ friend llvm::raw_ostream &DumpForUnparse( llvm::raw_ostream &, const Symbol &, bool); + using VisitedSymbols = llvm::SmallSet; + + inline const DeclTypeSpec *GetTypeImpl(VisitedSymbols &visited) const; + inline int RankImpl(VisitedSymbols &visited) const { + if (visited.contains(this)) { + return 0; + } + visited.insert(this); + return common::visit( + common::visitors{ + [&](const SubprogramDetails &sd) { + return sd.isFunction() ? 
sd.result().RankImpl(visited) : 0; + }, + [](const GenericDetails &) { + return 0; /*TODO*/ + }, + [&](const ProcBindingDetails &x) { + return x.symbol().RankImpl(visited); + }, + [&](const UseDetails &x) { return x.symbol().RankImpl(visited); }, + [&](const HostAssocDetails &x) { + return x.symbol().RankImpl(visited); + }, + [](const ObjectEntityDetails &oed) { return oed.shape().Rank(); }, + [&](const ProcEntityDetails &ped) { + const Symbol *iface{ped.interface().symbol()}; + return iface ? iface->RankImpl(visited) : 0; + }, + [](const AssocEntityDetails &aed) { + if (const auto &expr{aed.expr()}) { + if (auto assocRank{aed.rank()}) { + return *assocRank; + } else { + return expr->Rank(); + } + } else { + return 0; + } + }, + [](const auto &) { return 0; }, + }, + details_); + } template friend class Symbols; template friend class std::array; }; @@ -786,28 +803,42 @@ return const_cast( const_cast(this)->GetType()); } -inline const DeclTypeSpec *Symbol::GetType() const { + +inline const DeclTypeSpec *Symbol::GetTypeImpl(VisitedSymbols &visited) const { + if (visited.contains(this)) { + return nullptr; + } + visited.insert(this); return common::visit( common::visitors{ [](const EntityDetails &x) { return x.type(); }, [](const ObjectEntityDetails &x) { return x.type(); }, [](const AssocEntityDetails &x) { return x.type(); }, - [](const SubprogramDetails &x) { - return x.isFunction() ? x.result().GetType() : nullptr; + [&](const SubprogramDetails &x) { + return x.isFunction() ? x.result().GetTypeImpl(visited) : nullptr; }, - [](const ProcEntityDetails &x) { + [&](const ProcEntityDetails &x) { const Symbol *symbol{x.interface().symbol()}; - return symbol ? symbol->GetType() : x.interface().type(); + return symbol ? 
symbol->GetTypeImpl(visited) : x.interface().type(); + }, + [&](const ProcBindingDetails &x) { + return x.symbol().GetTypeImpl(visited); }, - [](const ProcBindingDetails &x) { return x.symbol().GetType(); }, [](const TypeParamDetails &x) { return x.type(); }, - [](const UseDetails &x) { return x.symbol().GetType(); }, - [](const HostAssocDetails &x) { return x.symbol().GetType(); }, + [&](const UseDetails &x) { return x.symbol().GetTypeImpl(visited); }, + [&](const HostAssocDetails &x) { + return x.symbol().GetTypeImpl(visited); + }, [](const auto &) -> const DeclTypeSpec * { return nullptr; }, }, details_); } +inline const DeclTypeSpec *Symbol::GetType() const { + VisitedSymbols visited; + return GetTypeImpl(visited); +} + // Sets and maps keyed by Symbols struct SymbolAddressCompare { diff --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp --- a/flang/lib/Evaluate/characteristics.cpp +++ b/flang/lib/Evaluate/characteristics.cpp @@ -367,6 +367,9 @@ static std::optional CharacterizeDummyArgument( const semantics::Symbol &symbol, FoldingContext &context, semantics::UnorderedSymbolSet seenProcs); +static std::optional CharacterizeFunctionResult( + const semantics::Symbol &symbol, FoldingContext &context, + semantics::UnorderedSymbolSet seenProcs); static std::optional CharacterizeProcedure( const semantics::Symbol &original, FoldingContext &context, @@ -397,8 +400,8 @@ [&](const semantics::SubprogramDetails &subp) -> std::optional { if (subp.isFunction()) { - if (auto fr{ - FunctionResult::Characterize(subp.result(), context)}) { + if (auto fr{CharacterizeFunctionResult( + subp.result(), context, 
seenProcs)}) { result.functionResult = std::move(fr); } else { return std::nullopt; @@ -699,8 +702,9 @@ return attrs == that.attrs && u == that.u; } -std::optional FunctionResult::Characterize( - const Symbol &symbol, FoldingContext &context) { +static std::optional CharacterizeFunctionResult( + const semantics::Symbol &symbol, FoldingContext &context, + semantics::UnorderedSymbolSet seenProcs) { if (symbol.has()) { if (auto type{TypeAndShape::Characterize(symbol, context)}) { FunctionResult result{std::move(*type)}; @@ -712,7 +716,8 @@ }); return result; } - } else if (auto maybeProc{Procedure::Characterize(symbol, context)}) { + } else if (auto maybeProc{ + CharacterizeProcedure(symbol, context, seenProcs)}) { FunctionResult result{std::move(*maybeProc)}; result.attrs.set(FunctionResult::Attr::Pointer); return result; @@ -720,6 +725,12 @@ return std::nullopt; } +std::optional FunctionResult::Characterize( + const Symbol &symbol, FoldingContext &context) { + semantics::UnorderedSymbolSet seenProcs; + return CharacterizeFunctionResult(symbol, context, seenProcs); +} + bool FunctionResult::IsAssumedLengthCharacter() const { if (const auto *ts{std::get_if(&u)}) { return ts->type().IsAssumedLengthCharacter(); diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -1786,6 +1786,9 @@ auto addSpecifics{[&](const Symbol &generic) { const auto *details{generic.GetUltimate().detailsIf()}; if (!details) { + if (generic.test(Symbol::Flag::Function)) { + Characterize(generic); + } return; } GenericKind kind{details->kind()}; diff 
--git a/flang/test/Semantics/resolve102.f90 b/flang/test/Semantics/resolve102.f90 --- a/flang/test/Semantics/resolve102.f90 +++ b/flang/test/Semantics/resolve102.f90 @@ -20,6 +20,12 @@ end subroutine end subroutine circular +!ERROR: Procedure 'foo' is recursively defined. Procedures in the cycle: 'foo', 'r' +function foo() result(r) + !ERROR: Procedure 'r' is recursively defined. Procedures in the cycle: 'foo', 'r' + procedure(foo), pointer :: r +end function foo + program iface !ERROR: Procedure 'p' is recursively defined. Procedures in the cycle: 'p', 'sub', 'p2' procedure(sub) :: p