Review notes (free text ahead of the first "diff --git" header; git apply and
GNU patch skip leading text like this).

What this patch does:
- Replaces the member functions GenericSpecInfo::GetAllNames(SemanticsContext &)
  and GenericSpecInfo::FindInScope() with a free function
  GetAllNames(const SemanticsContext &, const SourceName &). For a name of the
  form "operator(...)" it returns every spelling that
  languageFeatures().GetNames() reports for the matching logical or relational
  operator; for any other name it returns just that name.
- ScopeHandler::FindInScope(const Scope &, const SourceName &) now tries every
  spelling from GetAllNames(), so "operator(/=)" finds an existing
  "operator(.ne.)". A FindInScope(name) template that searches currScope() is
  added, and the FindInScope(currScope(), ...) call sites now use it.
- When two use-associated generics are merged, the replacement symbol now takes
  localSymbol.attrs() rather than localUltimate.attrs().
- Adds operator<< for GenericSpecInfo (prints kind, parseName and symbolName).
- modfile07.f90: new modules m10a-m10d check that interfaces spelled .ne., <>
  and /= are merged, and that m10c/m10d write them as operator(.ne.).

Changes made in review (one added line replaced by one added line each, so
every hunk header is unchanged):
- The comment moved onto the free GetAllNames() declaration said "including
  symbolName", but the free function has no symbolName; it now says
  "including the given name".
- GetAllNames() called name.ToString() a second time although the same string
  is already held in the local `str`; it now uses str.rfind(operatorPrefix, 0).

Open items (not changed here):
- NOTE(review): IsLogicalConstant() (context lines only) compares against ".t"
  while its partner is ".f."; this looks like a typo for ".t." and should be
  fixed in a separate change.
- NOTE(review): every FindInScope(scope, SourceName) call now goes through
  GetAllNames(), which builds a std::string and a one-element list even for
  ordinary (non-operator) names; consider a fast path if name lookup shows up
  in profiles.
- NOTE(review): the comment explaining that <, <=, >, >=, == and /= always mean
  .LT., .LE., .GT., .GE., .EQ. and .NE. is deleted along with the old member
  function and is not carried over to the new GetAllNames().
- NOTE(review): in this copy the newlines are collapsed and text inside angle
  brackets is missing (e.g. "#include " and "std::forward_list GetAllNames("),
  so the patch cannot be applied as-is; regenerate it from the source tree.

diff --git a/flang/lib/Semantics/resolve-names-utils.h b/flang/lib/Semantics/resolve-names-utils.h --- a/flang/lib/Semantics/resolve-names-utils.h +++ b/flang/lib/Semantics/resolve-names-utils.h @@ -19,6 +19,7 @@ #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/type.h" +#include "llvm/Support/raw_ostream.h" #include namespace Fortran::parser { @@ -50,6 +51,11 @@ bool IsIntrinsicOperator(const SemanticsContext &, const SourceName &); bool IsLogicalConstant(const SemanticsContext &, const SourceName &); +// Some intrinsic operators have more than one name (e.g. `operator(.eq.)` and +// `operator(==)`). GetAllNames() returns them all, including the given name. +std::forward_list GetAllNames( + const SemanticsContext &, const SourceName &); + template MaybeIntExpr EvaluateIntExpr(SemanticsContext &context, const T &expr) { if (MaybeExpr maybeExpr{ @@ -75,13 +81,11 @@ GenericKind kind() const { return kind_; } const SourceName &symbolName() const { return symbolName_.value(); } - // Some intrinsic operators have more than one name (e.g. `operator(.eq.)` and - // `operator(==)`). GetAllNames() returns them all, including symbolName.
- std::forward_list GetAllNames(SemanticsContext &) const; // Set the GenericKind in this symbol and resolve the corresponding // name if there is one void Resolve(Symbol *) const; - Symbol *FindInScope(SemanticsContext &, const Scope &) const; + friend llvm::raw_ostream &operator<<( + llvm::raw_ostream &, const GenericSpecInfo &); private: GenericKind kind_; diff --git a/flang/lib/Semantics/resolve-names-utils.cpp b/flang/lib/Semantics/resolve-names-utils.cpp --- a/flang/lib/Semantics/resolve-names-utils.cpp +++ b/flang/lib/Semantics/resolve-names-utils.cpp @@ -29,6 +29,8 @@ using common::RelationalOperator; using IntrinsicOperator = parser::DefinedOperator::IntrinsicOperator; +static constexpr const char *operatorPrefix{"operator("}; + static GenericKind MapIntrinsicOperator(IntrinsicOperator); Symbol *Resolve(const parser::Name &name, Symbol *symbol) { @@ -65,6 +67,37 @@ return false; } +template +std::forward_list GetOperatorNames( + const SemanticsContext &context, E opr) { + std::forward_list result; + for (const char *name : context.languageFeatures().GetNames(opr)) { + result.emplace_front(std::string{operatorPrefix} + name + ')'); + } + return result; +} + +std::forward_list GetAllNames( + const SemanticsContext &context, const SourceName &name) { + std::string str{name.ToString()}; + if (!name.empty() && name.end()[-1] == ')' && + str.rfind(operatorPrefix, 0) == 0) { + for (int i{0}; i != common::LogicalOperator_enumSize; ++i) { + auto names{GetOperatorNames(context, LogicalOperator{i})}; + if (std::find(names.begin(), names.end(), str) != names.end()) { + return names; + } + } + for (int i{0}; i != common::RelationalOperator_enumSize; ++i) { + auto names{GetOperatorNames(context, RelationalOperator{i})}; + if (std::find(names.begin(), names.end(), str) != names.end()) { + return names; + } + } + } + return {str}; +} + bool IsLogicalConstant( const SemanticsContext &context, const SourceName &name) { std::string
str{name.ToString()}; @@ -73,37 +106,6 @@ (str == ".t" || str == ".f.")); } -// The operators <, <=, >, >=, ==, and /= always have the same interpretations -// as the operators .LT., .LE., .GT., .GE., .EQ., and .NE., respectively. -std::forward_list GenericSpecInfo::GetAllNames( - SemanticsContext &context) const { - auto getNames{[&](auto opr) { - std::forward_list result; - for (const char *name : context.languageFeatures().GetNames(opr)) { - result.emplace_front("operator("s + name + ')'); - } - return result; - }}; - return std::visit( - common::visitors{[&](const LogicalOperator &x) { return getNames(x); }, - [&](const RelationalOperator &x) { return getNames(x); }, - [&](const auto &) -> std::forward_list { - return {symbolName_.value().ToString()}; - }}, - kind_.u); -} - -Symbol *GenericSpecInfo::FindInScope( - SemanticsContext &context, const Scope &scope) const { - for (const auto &name : GetAllNames(context)) { - auto iter{scope.find(SourceName{name})}; - if (iter != scope.end()) { - return &*iter->second; - } - } - return nullptr; -} - void GenericSpecInfo::Resolve(Symbol *symbol) const { if (symbol) { if (auto *details{symbol->detailsIf()}) { @@ -162,6 +164,16 @@ x.u); } +llvm::raw_ostream &operator<<( + llvm::raw_ostream &os, const GenericSpecInfo &info) { + os << "GenericSpecInfo: kind=" << info.kind_.ToString(); + os << " parseName=" + << (info.parseName_ ? info.parseName_->ToString() : "null"); + os << " symbolName=" + << (info.symbolName_ ? info.symbolName_->ToString() : "null"); + return os; +} + // parser::DefinedOperator::IntrinsicOperator -> GenericKind static GenericKind MapIntrinsicOperator(IntrinsicOperator op) { switch (op) { diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -502,6 +502,9 @@ // Search for name only in scope, not in enclosing scopes.
Symbol *FindInScope(const Scope &, const parser::Name &); Symbol *FindInScope(const Scope &, const SourceName &); + template Symbol *FindInScope(const T &name) { + return FindInScope(currScope(), name); + } // Search for name in a derived type scope and its parents. Symbol *FindInTypeOrParents(const Scope &, const parser::Name &); Symbol *FindInTypeOrParents(const parser::Name &); @@ -533,7 +536,7 @@ const SourceName &name, const Attrs &attrs, D &&details) { // Note: don't use FindSymbol here. If this is a derived type scope, // we want to detect whether the name is already declared as a component. - auto *symbol{FindInScope(currScope(), name)}; + auto *symbol{FindInScope(name)}; if (!symbol) { symbol = &MakeSymbol(name, attrs); symbol->set_details(std::move(details)); @@ -2048,7 +2051,7 @@ return symbol; } Symbol &ScopeHandler::CopySymbol(const SourceName &name, const Symbol &symbol) { - CHECK(!FindInScope(currScope(), name)); + CHECK(!FindInScope(name)); return MakeSymbol(currScope(), name, symbol.attrs()); } @@ -2058,11 +2061,14 @@ return Resolve(name, FindInScope(scope, name.source)); } Symbol *ScopeHandler::FindInScope(const Scope &scope, const SourceName &name) { - if (auto it{scope.find(name)}; it != scope.end()) { - return &*it->second; - } else { - return nullptr; + // all variants of names, e.g. "operator(.ne.)" for "operator(/=)" + for (const std::string &n : GetAllNames(context(), name)) { + auto it{scope.find(SourceName{n})}; + if (it != scope.end()) { + return &*it->second; + } } + return nullptr; } // Find a component or type parameter by name in a derived type or its parents.
@@ -2318,7 +2324,7 @@ !symbol->attrs().test(Attr::INTRINSIC) && !symbol->has() && useNames.count(name) == 0) { SourceName location{x.moduleName.source}; - if (auto *localSymbol{FindInScope(currScope(), name)}) { + if (auto *localSymbol{FindInScope(name)}) { DoAddUse(location, localSymbol->name(), *localSymbol, *symbol); } else { DoAddUse(location, location, CopySymbol(name, *symbol), *symbol); @@ -2397,8 +2403,7 @@ generic1.CopyFrom(generic2); } EraseSymbol(localSymbol); - MakeSymbol( - localSymbol.name(), localUltimate.attrs(), std::move(generic1)); + MakeSymbol(localSymbol.name(), localSymbol.attrs(), std::move(generic1)); } else { ConvertToUseError(localSymbol, location, *useModuleScope_); } @@ -2435,8 +2440,7 @@ void ModuleVisitor::AddUse(const GenericSpecInfo &info) { if (useModuleScope_) { const auto &name{info.symbolName()}; - auto rename{ - AddUse(name, name, info.FindInScope(context(), *useModuleScope_))}; + auto rename{AddUse(name, name, FindInScope(*useModuleScope_, name))}; info.Resolve(rename.use); } } @@ -2523,7 +2527,7 @@ // Create a symbol in genericSymbol_ for this GenericSpec.
bool InterfaceVisitor::Pre(const parser::GenericSpec &x) { - if (auto *symbol{GenericSpecInfo{x}.FindInScope(context(), currScope())}) { + if (auto *symbol{FindInScope(GenericSpecInfo{x}.symbolName())}) { SetGenericSymbol(*symbol); } return false; @@ -3402,7 +3406,7 @@ if (attr == Attr::INTRINSIC && !IsIntrinsic(name.source, std::nullopt)) { Say(name.source, "'%s' is not a known intrinsic procedure"_err_en_US); } - auto *symbol{FindInScope(currScope(), name)}; + auto *symbol{FindInScope(name)}; if (attr == Attr::ASYNCHRONOUS || attr == Attr::VOLATILE) { // these can be set on a symbol that is host-assoc or use-assoc if (!symbol && @@ -4065,7 +4069,7 @@ CHECK(currScope().IsDerivedType()); for (auto &declaration : tbps.declarations) { auto &bindingName{std::get(declaration.t)}; - if (Symbol * binding{FindInScope(currScope(), bindingName)}) { + if (Symbol * binding{FindInScope(bindingName)}) { if (auto *details{binding->detailsIf()}) { const Symbol *procedure{FindSubprogram(details->symbol())}; if (!CanBeTypeBoundProc(procedure)) { @@ -4134,7 +4138,7 @@ SourceName symbolName{info.symbolName()}; bool isPrivate{accessSpec ?
accessSpec->v == parser::AccessSpec::Kind::Private : derivedTypeInfo_.privateBindings}; - auto *genericSymbol{info.FindInScope(context(), currScope())}; + auto *genericSymbol{FindInScope(symbolName)}; if (genericSymbol) { if (!genericSymbol->has()) { genericSymbol = nullptr; // MakeTypeSymbol will report the error below @@ -4142,7 +4146,7 @@ } else { // look in parent types: Symbol *inheritedSymbol{nullptr}; - for (const auto &name : info.GetAllNames(context())) { + for (const auto &name : GetAllNames(context(), symbolName)) { inheritedSymbol = currScope().FindComponent(SourceName{name}); if (inheritedSymbol) { break; @@ -4298,7 +4302,7 @@ } const auto &groupName{std::get(x.t)}; - auto *groupSymbol{FindInScope(currScope(), groupName)}; + auto *groupSymbol{FindInScope(groupName)}; if (!groupSymbol || !groupSymbol->has()) { groupSymbol = &MakeSymbol(groupName, std::move(details)); groupSymbol->ReplaceName(groupName.source); @@ -4397,7 +4401,7 @@ void DeclarationVisitor::CheckSaveStmts() { for (const SourceName &name : saveInfo_.entities) { - auto *symbol{FindInScope(currScope(), name)}; + auto *symbol{FindInScope(name)}; if (!symbol) { // error was reported } else if (saveInfo_.saveAll) { @@ -5159,7 +5163,7 @@ void ConstructVisitor::Post(const parser::CoarrayAssociation &x) { const auto &decl{std::get(x.t)}; const auto &name{std::get(decl.t)}; - if (auto *symbol{FindInScope(currScope(), name)}) { + if (auto *symbol{FindInScope(name)}) { const auto &selector{std::get(x.t)}; if (auto sel{ResolveSelector(selector)}) { const Symbol *whole{UnwrapWholeSymbolDataRef(sel.expr)}; @@ -5962,7 +5966,7 @@ [=](const Indirection &y) { auto info{GenericSpecInfo{y.value()}}; const auto &symbolName{info.symbolName()}; - if (auto *symbol{info.FindInScope(context(), currScope())}) { + if (auto *symbol{FindInScope(symbolName)}) { info.Resolve(&SetAccess(symbolName, accessAttr, symbol)); } else if (info.kind().IsName()) { info.Resolve(&SetAccess(symbolName, accessAttr)); @@ -6084,7
+6088,7 @@ return; } GenericDetails genericDetails; - if (Symbol * existing{info.FindInScope(context(), currScope())}) { + if (Symbol * existing{FindInScope(symbolName)}) { if (existing->has()) { info.Resolve(existing); return; // already have generic, add to it @@ -6204,7 +6208,7 @@ void ResolveNamesVisitor::CheckImport( const SourceName &location, const SourceName &name) { - if (auto *symbol{FindInScope(currScope(), name)}) { + if (auto *symbol{FindInScope(name)}) { Say(location, "'%s' from host is not accessible"_err_en_US, name) .Attach(symbol->name(), "'%s' is hidden by this entity"_en_US, symbol->name()); diff --git a/flang/test/Semantics/modfile07.f90 b/flang/test/Semantics/modfile07.f90 --- a/flang/test/Semantics/modfile07.f90 +++ b/flang/test/Semantics/modfile07.f90 @@ -549,3 +549,52 @@ ! end !end +! Verify that equivalent names are used when generic operators are merged + +module m10a + interface operator(.ne.) + end interface +end +!Expect: m10a.mod +!module m10a +! interface operator(.ne.) +! end interface +!end + +module m10b + interface operator(<>) + end interface +end +!Expect: m10b.mod +!module m10b +! interface operator(<>) +! end interface +!end + +module m10c + use m10a + use m10b + interface operator(/=) + end interface +end +!Expect: m10c.mod +!module m10c +! use m10b,only:operator(.ne.) +! use m10a,only:operator(.ne.) +! interface operator(.ne.) +! end interface +!end + +module m10d + use m10a + use m10c + private :: operator(<>) +end +!Expect: m10d.mod +!module m10d +! use m10c,only:operator(.ne.) +! use m10a,only:operator(.ne.) +! interface operator(.ne.) +! end interface +! private::operator(.ne.) +!end