diff --git a/flang/include/flang/Evaluate/constant.h b/flang/include/flang/Evaluate/constant.h --- a/flang/include/flang/Evaluate/constant.h +++ b/flang/include/flang/Evaluate/constant.h @@ -195,8 +195,11 @@ }; class StructureConstructor; -using StructureConstructorValues = - std::map>>; +struct ComponentCompare { + bool operator()(SymbolRef x, SymbolRef y) const; +}; +using StructureConstructorValues = std::map>, ComponentCompare>; template <> class Constant diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -839,10 +839,12 @@ const Symbol *GetLastTarget(const SymbolVector &); // Collects all of the Symbols in an expression -template semantics::SymbolSet CollectSymbols(const A &); -extern template semantics::SymbolSet CollectSymbols(const Expr &); -extern template semantics::SymbolSet CollectSymbols(const Expr &); -extern template semantics::SymbolSet CollectSymbols( +template semantics::UnorderedSymbolSet CollectSymbols(const A &); +extern template semantics::UnorderedSymbolSet CollectSymbols( + const Expr &); +extern template semantics::UnorderedSymbolSet CollectSymbols( + const Expr &); +extern template semantics::UnorderedSymbolSet CollectSymbols( const Expr &); // Predicate: does a variable contain a vector-valued subscript (not a triplet)? diff --git a/flang/include/flang/Semantics/semantics.h b/flang/include/flang/Semantics/semantics.h --- a/flang/include/flang/Semantics/semantics.h +++ b/flang/include/flang/Semantics/semantics.h @@ -198,8 +198,9 @@ parser::CharBlock location; IndexVarKind kind; }; - std::map activeIndexVars_; - SymbolSet errorSymbols_; + std::map + activeIndexVars_; + UnorderedSymbolSet errorSymbols_; std::set tempNames_; }; diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h --- a/flang/include/flang/Semantics/symbol.h +++ b/flang/include/flang/Semantics/symbol.h @@ -596,13 +596,6 @@ bool operator==(const Symbol &that) const { return this == &that; } bool operator!=(const Symbol &that) const { return !(*this == that); } - // Symbol comparison is based on the order of cooked source - // stream creation and, when both are from the same cooked source, - // their positions in that cooked source stream. - // (This function is implemented in Evaluate/tools.cpp to - // satisfy complicated shared library interdependency.) - bool operator<(const Symbol &) const; - int Rank() const { return std::visit( common::visitors{ @@ -767,13 +760,40 @@ details_); } -inline bool operator<(SymbolRef x, SymbolRef y) { - return *x < *y; // name source position ordering -} -inline bool operator<(MutableSymbolRef x, MutableSymbolRef y) { - return *x < *y; // name source position ordering +// Sets and maps keyed by Symbols + +struct SymbolAddressCompare { + bool operator()(const SymbolRef &x, const SymbolRef &y) const { + return &*x < &*y; + } + bool operator()(const MutableSymbolRef &x, const MutableSymbolRef &y) const { + return &*x < &*y; + } +}; + +// Symbol comparison is based on the order of cooked source +// stream creation and, when both are from the same cooked source, +// their positions in that cooked source stream. +// Don't use this comparator or OrderedSymbolSet to hold +// Symbols that might be subject to ReplaceName(). +struct SymbolSourcePositionCompare { + // These functions are implemented in Evaluate/tools.cpp to + // satisfy complicated shared library interdependency. + bool operator()(const SymbolRef &, const SymbolRef &) const; + bool operator()(const MutableSymbolRef &, const MutableSymbolRef &) const; +}; + +using UnorderedSymbolSet = std::set; +using OrderedSymbolSet = std::set; + +template +OrderedSymbolSet OrderBySourcePosition(const A &container) { + OrderedSymbolSet result; + for (SymbolRef x : container) { + result.emplace(x); + } + return result; } -using SymbolSet = std::set; } // namespace Fortran::semantics diff --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp --- a/flang/lib/Evaluate/characteristics.cpp +++ b/flang/lib/Evaluate/characteristics.cpp @@ -343,30 +343,29 @@ procedure.value() == that.procedure.value(); } -static std::string GetSeenProcs(const semantics::SymbolSet &seenProcs) { +static std::string GetSeenProcs( + const semantics::UnorderedSymbolSet &seenProcs) { // Sort the symbols so that they appear in the same order on all platforms - std::vector sorter{seenProcs.begin(), seenProcs.end()}; - std::sort(sorter.begin(), sorter.end()); - + auto ordered{semantics::OrderBySourcePosition(seenProcs)}; std::string result; llvm::interleave( - sorter, + ordered, [&](const SymbolRef p) { result += '\'' + p->name().ToString() + '\''; }, [&]() { result += ", "; }); return result; } -// These functions with arguments of type SymbolSet are used with mutually -// recursive calls when characterizing a Procedure, a DummyArgument, or a -// DummyProcedure to detect circularly defined procedures as required by +// These functions with arguments of type UnorderedSymbolSet are used with +// mutually recursive calls when characterizing a Procedure, a DummyArgument, +// or a DummyProcedure to detect circularly defined procedures as required by // 15.4.3.6, paragraph 2. static std::optional CharacterizeDummyArgument( const semantics::Symbol &symbol, FoldingContext &context, - semantics::SymbolSet &seenProcs); + semantics::UnorderedSymbolSet &seenProcs); static std::optional CharacterizeProcedure( const semantics::Symbol &original, FoldingContext &context, - semantics::SymbolSet &seenProcs) { + semantics::UnorderedSymbolSet &seenProcs) { Procedure result; const auto &symbol{original.GetUltimate()}; if (seenProcs.find(symbol) != seenProcs.end()) { @@ -475,7 +474,7 @@ static std::optional CharacterizeDummyProcedure( const semantics::Symbol &symbol, FoldingContext &context, - semantics::SymbolSet &seenProcs) { + semantics::UnorderedSymbolSet &seenProcs) { if (auto procedure{CharacterizeProcedure(symbol, context, seenProcs)}) { // Dummy procedures may not be elemental. Elemental dummy procedure // interfaces are errors when the interface is not intrinsic, and that @@ -516,7 +515,7 @@ static std::optional CharacterizeDummyArgument( const semantics::Symbol &symbol, FoldingContext &context, - semantics::SymbolSet &seenProcs) { + semantics::UnorderedSymbolSet &seenProcs) { auto name{symbol.name().ToString()}; if (symbol.has()) { if (auto obj{DummyDataObject::Characterize(symbol, context)}) { @@ -779,7 +778,7 @@ std::optional Procedure::Characterize( const semantics::Symbol &original, FoldingContext &context) { - semantics::SymbolSet seenProcs; + semantics::UnorderedSymbolSet seenProcs; return CharacterizeProcedure(original, context, seenProcs); } diff --git a/flang/lib/Evaluate/constant.cpp b/flang/lib/Evaluate/constant.cpp --- a/flang/lib/Evaluate/constant.cpp +++ b/flang/lib/Evaluate/constant.cpp @@ -315,5 +315,9 @@ return Base::CopyFrom(source, count, resultSubscripts, dimOrder); } +bool ComponentCompare::operator()(SymbolRef x, SymbolRef y) const { + return semantics::SymbolSourcePositionCompare{}(x, y); +} + INSTANTIATE_CONSTANT_TEMPLATES } // namespace Fortran::evaluate diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -782,20 +782,22 @@ } struct CollectSymbolsHelper - : public SetTraverse { - using Base = SetTraverse; + : public SetTraverse { + using Base = SetTraverse; CollectSymbolsHelper() : Base{*this} {} using Base::operator(); - semantics::SymbolSet operator()(const Symbol &symbol) const { + semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const { return {symbol}; } }; -template semantics::SymbolSet CollectSymbols(const A &x) { +template semantics::UnorderedSymbolSet CollectSymbols(const A &x) { return CollectSymbolsHelper{}(x); } -template semantics::SymbolSet CollectSymbols(const Expr &); -template semantics::SymbolSet CollectSymbols(const Expr &); -template semantics::SymbolSet CollectSymbols(const Expr &); +template semantics::UnorderedSymbolSet CollectSymbols(const Expr &); +template semantics::UnorderedSymbolSet CollectSymbols( + const Expr &); +template semantics::UnorderedSymbolSet CollectSymbols( + const Expr &); // HasVectorSubscript() struct HasVectorSubscriptHelper : public AnyTraverse { @@ -1177,7 +1179,7 @@ } static const Symbol *FindFunctionResult( - const Symbol &original, SymbolSet &seen) { + const Symbol &original, UnorderedSymbolSet &seen) { const Symbol &root{GetAssociationRoot(original)}; ; if (!seen.insert(root).second) { @@ -1199,7 +1201,7 @@ } const Symbol *FindFunctionResult(const Symbol &symbol) { - SymbolSet seen; + UnorderedSymbolSet seen; return FindFunctionResult(symbol, seen); } @@ -1207,8 +1209,15 @@ // them; they cannot be defined in symbol.h due to the dependence // on Scope. -bool Symbol::operator<(const Symbol &that) const { - return GetSemanticsContext().allCookedSources().Precedes(name_, that.name_); +bool SymbolSourcePositionCompare::operator()( + const SymbolRef &x, const SymbolRef &y) const { + return x->GetSemanticsContext().allCookedSources().Precedes( + x->name(), y->name()); +} +bool SymbolSourcePositionCompare::operator()( + const MutableSymbolRef &x, const MutableSymbolRef &y) const { + return x->GetSemanticsContext().allCookedSources().Precedes( + x->name(), y->name()); } SemanticsContext &Symbol::GetSemanticsContext() const { diff --git a/flang/lib/Parser/provenance.cpp b/flang/lib/Parser/provenance.cpp --- a/flang/lib/Parser/provenance.cpp +++ b/flang/lib/Parser/provenance.cpp @@ -602,16 +602,15 @@ } bool AllCookedSources::Precedes(CharBlock x, CharBlock y) const { - const CookedSource *ySource{Find(y)}; if (const CookedSource * xSource{Find(x)}) { - if (ySource) { - int xNum{xSource->number()}; - int yNum{ySource->number()}; - return xNum < yNum || (xNum == yNum && x.begin() < y.begin()); + if (xSource->AsCharBlock().Contains(y)) { + return x.begin() < y.begin(); + } else if (const CookedSource * ySource{Find(y)}) { + return xSource->number() < ySource->number(); } else { return true; // by fiat, all cooked source < anything outside } - } else if (ySource) { + } else if (Find(y)) { return false; } else { // Both names are compiler-created (SaveTempName). diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -110,7 +110,8 @@ // that has a symbol. const Symbol *innermostSymbol_{nullptr}; // Cache of calls to Procedure::Characterize(Symbol) - std::map> characterizeCache_; + std::map, SymbolAddressCompare> + characterizeCache_; }; class DistinguishabilityHelper { diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp --- a/flang/lib/Semantics/check-do-forall.cpp +++ b/flang/lib/Semantics/check-do-forall.cpp @@ -548,9 +548,9 @@ // the names up in the scope that encloses the DO construct to avoid getting // the local versions of them. Then follow the host-, use-, and // construct-associations to get the root symbols - SymbolSet GatherLocals( + UnorderedSymbolSet GatherLocals( const std::list &localitySpecs) const { - SymbolSet symbols; + UnorderedSymbolSet symbols; const Scope &parentScope{ context_.FindScope(currentStatementSourcePosition_).parent()}; // Loop through the LocalitySpec::Local locality-specs @@ -568,8 +568,9 @@ return symbols; } - static SymbolSet GatherSymbolsFromExpression(const parser::Expr &expression) { - SymbolSet result; + static UnorderedSymbolSet GatherSymbolsFromExpression( + const parser::Expr &expression) { + UnorderedSymbolSet result; if (const auto *expr{GetExpr(expression)}) { for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) { result.insert(ResolveAssociations(symbol)); @@ -580,8 +581,9 @@ // C1121 - procedures in mask must be pure void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const { - SymbolSet references{GatherSymbolsFromExpression(mask.thing.thing.value())}; - for (const Symbol &ref : references) { + UnorderedSymbolSet references{ + GatherSymbolsFromExpression(mask.thing.thing.value())}; + for (const Symbol &ref : OrderBySourcePosition(references)) { if (IsProcedure(ref) && !IsPureProcedure(ref)) { context_.SayWithDecl(ref, parser::Unwrap(mask)->source, "%s mask expression may not reference impure procedure '%s'"_err_en_US, @@ -591,10 +593,10 @@ } } - void CheckNoCollisions(const SymbolSet &refs, const SymbolSet &uses, - parser::MessageFixedText &&errorMessage, + void CheckNoCollisions(const UnorderedSymbolSet &refs, + const UnorderedSymbolSet &uses, parser::MessageFixedText &&errorMessage, const parser::CharBlock &refPosition) const { - for (const Symbol &ref : refs) { + for (const Symbol &ref : OrderBySourcePosition(refs)) { if (uses.find(ref) != uses.end()) { context_.SayWithDecl(ref, refPosition, std::move(errorMessage), LoopKindName(), ref.name()); @@ -603,8 +605,8 @@ } } - void HasNoReferences( - const SymbolSet &indexNames, const parser::ScalarIntExpr &expr) const { + void HasNoReferences(const UnorderedSymbolSet &indexNames, + const parser::ScalarIntExpr &expr) const { CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), indexNames, "%s limit expression may not reference index variable '%s'"_err_en_US, @@ -612,8 +614,8 @@ } // C1129, names in local locality-specs can't be in mask expressions - void CheckMaskDoesNotReferenceLocal( - const parser::ScalarLogicalExpr &mask, const SymbolSet &localVars) const { + void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask, + const UnorderedSymbolSet &localVars) const { CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()), localVars, "%s mask expression references variable '%s'" @@ -623,8 +625,8 @@ // C1129, names in local locality-specs can't be in limit or step // expressions - void CheckExprDoesNotReferenceLocal( - const parser::ScalarIntExpr &expr, const SymbolSet &localVars) const { + void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr, + const UnorderedSymbolSet &localVars) const { CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), localVars, "%s expression references variable '%s'" @@ -663,7 +665,7 @@ CheckMaskIsPure(*mask); } auto &controls{std::get>(header.t)}; - SymbolSet indexNames; + UnorderedSymbolSet indexNames; for (const parser::ConcurrentControl &control : controls) { const auto &indexName{std::get(control.t)}; if (indexName.symbol) { @@ -697,7 +699,7 @@ const auto &localitySpecs{ std::get>(concurrent.t)}; if (!localitySpecs.empty()) { - const SymbolSet &localVars{GatherLocals(localitySpecs)}; + const UnorderedSymbolSet &localVars{GatherLocals(localitySpecs)}; for (const auto &c : GetControls(control)) { CheckExprDoesNotReferenceLocal(std::get<1>(c.t), localVars); CheckExprDoesNotReferenceLocal(std::get<2>(c.t), localVars); @@ -733,7 +735,7 @@ void CheckForallIndexesUsed(const evaluate::Assignment &assignment) { SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)}; if (!indexVars.empty()) { - SymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)}; + UnorderedSymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)}; std::visit( common::visitors{ [&](const evaluate::Assignment::BoundsSpec &spec) { diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -630,7 +630,7 @@ } } // A list-item cannot appear in more than one aligned clause - semantics::SymbolSet alignedVars; + semantics::UnorderedSymbolSet alignedVars; auto clauseAll = FindClauses(llvm::omp::Clause::OMPC_aligned); for (auto itr = clauseAll.first; itr != clauseAll.second; ++itr) { const auto &alignedClause{ diff --git a/flang/lib/Semantics/compute-offsets.cpp b/flang/lib/Semantics/compute-offsets.cpp --- a/flang/lib/Semantics/compute-offsets.cpp +++ b/flang/lib/Semantics/compute-offsets.cpp @@ -58,9 +58,10 @@ std::size_t offset_{0}; std::size_t alignment_{1}; // symbol -> symbol+offset that determines its location, from EQUIVALENCE - std::map dependents_; + std::map dependents_; // base symbol -> SizeAndAlignment for each distinct EQUIVALENCE block - std::map equivalenceBlock_; + std::map + equivalenceBlock_; }; void ComputeOffsetsHelper::Compute(Scope &scope) { diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -81,8 +81,8 @@ const Scope &scope_; bool isInterface_{false}; SymbolVector need_; // symbols that are needed - SymbolSet needSet_; // symbols already in need_ - SymbolSet useSet_; // use-associations that might be needed + UnorderedSymbolSet needSet_; // symbols already in need_ + UnorderedSymbolSet useSet_; // use-associations that might be needed std::set imports_; // imports from host that are needed void DoSymbol(const Symbol &); @@ -498,7 +498,8 @@ for (const auto &pair : scope.commonBlocks()) { sorted.push_back(*pair.second); } - std::sort(sorted.end() - commonSize, sorted.end()); + std::sort( + sorted.end() - commonSize, sorted.end(), SymbolSourcePositionCompare{}); } void PutEntity(llvm::raw_ostream &os, const Symbol &symbol) { diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -105,7 +105,7 @@ Symbol *DeclarePrivateAccessEntity(Symbol &, Symbol::Flag, Scope &); Symbol *DeclareOrMarkOtherAccessEntity(const parser::Name &, Symbol::Flag); - SymbolSet dataSharingAttributeObjects_; // on one directive + UnorderedSymbolSet dataSharingAttributeObjects_; // on one directive SemanticsContext &context_; std::vector dirContext_; // used as a stack }; @@ -452,8 +452,8 @@ Symbol::Flag::OmpCopyIn, Symbol::Flag::OmpCopyPrivate}; std::vector allocateNames_; // on one directive - SymbolSet privateDataSharingAttributeObjects_; // on one directive - SymbolSet stmtFunctionExprSymbols_; + UnorderedSymbolSet privateDataSharingAttributeObjects_; // on one directive + UnorderedSymbolSet stmtFunctionExprSymbols_; std::multimap>> sourceLabels_; diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -2690,7 +2690,7 @@ // this generic interface. Resolve those names to symbols. void InterfaceVisitor::ResolveSpecificsInGeneric(Symbol &generic) { auto &details{generic.get()}; - SymbolSet symbolsSeen; + UnorderedSymbolSet symbolsSeen; for (const Symbol &symbol : details.specificProcs()) { symbolsSeen.insert(symbol); } @@ -3651,7 +3651,7 @@ bool DeclarationVisitor::HasCycle( const Symbol &procSymbol, const ProcInterface &interface) { - SymbolSet procsInCycle; + OrderedSymbolSet procsInCycle; procsInCycle.insert(procSymbol); const ProcInterface *thisInterface{&interface}; bool haveInterface{true}; diff --git a/flang/lib/Semantics/scope.cpp b/flang/lib/Semantics/scope.cpp --- a/flang/lib/Semantics/scope.cpp +++ b/flang/lib/Semantics/scope.cpp @@ -61,7 +61,7 @@ for (auto &pair : symbols) { result.push_back(*pair.second); } - std::sort(result.begin(), result.end()); + std::sort(result.begin(), result.end(), SymbolSourcePositionCompare{}); return result; } diff --git a/flang/test/Semantics/resolve102.f90 b/flang/test/Semantics/resolve102.f90 --- a/flang/test/Semantics/resolve102.f90 +++ b/flang/test/Semantics/resolve102.f90 @@ -68,7 +68,6 @@ !ERROR: The interface for procedure 'p1' is recursively defined !ERROR: The interface for procedure 'p2' is recursively defined procedure(p1) p2 - !ERROR: 'p2' must be an abstract interface or a procedure with an explicit interface procedure(p2) p1 call p1 call p2 @@ -76,10 +75,8 @@ program threeCycle !ERROR: The interface for procedure 'p1' is recursively defined - !ERROR: 'p1' must be an abstract interface or a procedure with an explicit interface !ERROR: The interface for procedure 'p2' is recursively defined procedure(p1) p2 - !ERROR: 'p2' must be an abstract interface or a procedure with an explicit interface !ERROR: The interface for procedure 'p3' is recursively defined procedure(p2) p3 procedure(p3) p1