diff --git a/flang/include/flang/Evaluate/characteristics.h b/flang/include/flang/Evaluate/characteristics.h --- a/flang/include/flang/Evaluate/characteristics.h +++ b/flang/include/flang/Evaluate/characteristics.h @@ -81,9 +81,9 @@ static std::optional Characterize( const semantics::ObjectEntityDetails &, FoldingContext &); static std::optional Characterize( - const semantics::ProcInterface &); + const semantics::ProcInterface &, FoldingContext &); static std::optional Characterize( - const semantics::DeclTypeSpec &); + const semantics::DeclTypeSpec &, FoldingContext &); static std::optional Characterize( const ActualArgument &, FoldingContext &); @@ -101,15 +101,16 @@ if (type->category() == TypeCategory::Character) { if (const auto *chExpr{UnwrapExpr>(x)}) { if (auto length{chExpr->LEN()}) { - result.set_LEN(Fold(context, std::move(*length))); + result.set_LEN(std::move(*length)); } } } - return result; + return std::move(result.Rewrite(context)); } } return std::nullopt; } + template static std::optional Characterize( const std::optional &x, FoldingContext &context) { @@ -121,9 +122,9 @@ } template static std::optional Characterize( - const A *x, FoldingContext &context) { - if (x) { - return Characterize(*x, context); + const A *p, FoldingContext &context) { + if (p) { + return Characterize(*p, context); } else { return std::nullopt; } @@ -151,14 +152,17 @@ std::optional> MeasureSizeInBytes( FoldingContext &) const; + // called by Fold() to rewrite in place + TypeAndShape &Rewrite(FoldingContext &); + llvm::raw_ostream &Dump(llvm::raw_ostream &) const; private: static std::optional Characterize( const semantics::AssocEntityDetails &, FoldingContext &); static std::optional Characterize( - const semantics::ProcEntityDetails &); - void AcquireShape(const semantics::ObjectEntityDetails &, FoldingContext &); + const semantics::ProcEntityDetails &, FoldingContext &); + void AcquireShape(const semantics::ObjectEntityDetails &); void AcquireLEN(); protected: @@ -325,6 +329,5 @@ private: Procedure() {} }; - } // namespace Fortran::evaluate::characteristics #endif // FORTRAN_EVALUATE_CHARACTERISTICS_H_ diff --git a/flang/include/flang/Evaluate/fold.h b/flang/include/flang/Evaluate/fold.h --- a/flang/include/flang/Evaluate/fold.h +++ b/flang/include/flang/Evaluate/fold.h @@ -19,6 +19,10 @@ #include "type.h" #include +namespace Fortran::evaluate::characteristics { +class TypeAndShape; +} + namespace Fortran::evaluate { using namespace Fortran::parser::literals; @@ -32,11 +36,13 @@ return Expr::Rewrite(context, std::move(expr)); } -template -std::optional> Fold( - FoldingContext &context, std::optional> &&expr) { - if (expr) { - return Fold(context, std::move(*expr)); +characteristics::TypeAndShape Fold( + FoldingContext &, characteristics::TypeAndShape &&); + +template +std::optional Fold(FoldingContext &context, std::optional &&x) { + if (x) { + return Fold(context, std::move(*x)); } else { return std::nullopt; } @@ -96,5 +102,13 @@ return std::nullopt; } } + +template std::optional ToInt64(const A *p) { + if (p) { + return ToInt64(*p); + } else { + return std::nullopt; + } +} } // namespace Fortran::evaluate #endif // FORTRAN_EVALUATE_FOLD_H_ diff --git a/flang/include/flang/Evaluate/shape.h b/flang/include/flang/Evaluate/shape.h --- a/flang/include/flang/Evaluate/shape.h +++ b/flang/include/flang/Evaluate/shape.h @@ -38,9 +38,6 @@ bool IsExplicitShape(const Symbol &); // Conversions between various representations of shapes. -Shape AsShape(const Constant &); -std::optional AsShape(FoldingContext &, ExtentExpr &&); - std::optional AsExtentArrayExpr(const Shape &); std::optional> AsConstantShape( @@ -53,29 +50,41 @@ inline int GetRank(const Shape &s) { return static_cast(s.size()); } +Shape Fold(FoldingContext &, Shape &&); +std::optional Fold(FoldingContext &, std::optional &&); + template std::optional GetShape(FoldingContext &, const A &); +template std::optional GetShape(const A &); // The dimension argument to these inquiries is zero-based, // unlike the DIM= arguments to many intrinsics. +ExtentExpr GetLowerBound(const NamedEntity &, int dimension); ExtentExpr GetLowerBound(FoldingContext &, const NamedEntity &, int dimension); +MaybeExtentExpr GetUpperBound(const NamedEntity &, int dimension); MaybeExtentExpr GetUpperBound( FoldingContext &, const NamedEntity &, int dimension); +MaybeExtentExpr ComputeUpperBound(ExtentExpr &&lower, MaybeExtentExpr &&extent); MaybeExtentExpr ComputeUpperBound( FoldingContext &, ExtentExpr &&lower, MaybeExtentExpr &&extent); +Shape GetLowerBounds(const NamedEntity &); Shape GetLowerBounds(FoldingContext &, const NamedEntity &); +Shape GetUpperBounds(const NamedEntity &); Shape GetUpperBounds(FoldingContext &, const NamedEntity &); +MaybeExtentExpr GetExtent(const NamedEntity &, int dimension); MaybeExtentExpr GetExtent(FoldingContext &, const NamedEntity &, int dimension); MaybeExtentExpr GetExtent( + const Subscript &, const NamedEntity &, int dimension); +MaybeExtentExpr GetExtent( FoldingContext &, const Subscript &, const NamedEntity &, int dimension); // Compute an element count for a triplet or trip count for a DO. -ExtentExpr CountTrips(FoldingContext &, ExtentExpr &&lower, ExtentExpr &&upper, - ExtentExpr &&stride); -ExtentExpr CountTrips(FoldingContext &, const ExtentExpr &lower, - const ExtentExpr &upper, const ExtentExpr &stride); -MaybeExtentExpr CountTrips(FoldingContext &, MaybeExtentExpr &&lower, - MaybeExtentExpr &&upper, MaybeExtentExpr &&stride); +ExtentExpr CountTrips( + ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride); +ExtentExpr CountTrips( + const ExtentExpr &lower, const ExtentExpr &upper, const ExtentExpr &stride); +MaybeExtentExpr CountTrips( + MaybeExtentExpr &&lower, MaybeExtentExpr &&upper, MaybeExtentExpr &&stride); // Computes SIZE() == PRODUCT(shape) MaybeExtentExpr GetSize(Shape &&); @@ -89,19 +98,22 @@ using Result = std::optional; using Base = AnyTraverse; using Base::operator(); - explicit GetShapeHelper(FoldingContext &c) : Base{*this}, context_{c} {} + GetShapeHelper() : Base{*this} {} + explicit GetShapeHelper(FoldingContext &c) : Base{*this}, context_{&c} {} - Result operator()(const ImpliedDoIndex &) const { return Scalar(); } - Result operator()(const DescriptorInquiry &) const { return Scalar(); } - Result operator()(const TypeParamInquiry &) const { return Scalar(); } - Result operator()(const BOZLiteralConstant &) const { return Scalar(); } + Result operator()(const ImpliedDoIndex &) const { return ScalarShape(); } + Result operator()(const DescriptorInquiry &) const { return ScalarShape(); } + Result operator()(const TypeParamInquiry &) const { return ScalarShape(); } + Result operator()(const BOZLiteralConstant &) const { return ScalarShape(); } Result operator()(const StaticDataObject::Pointer &) const { - return Scalar(); + return ScalarShape(); + } + Result operator()(const StructureConstructor &) const { + return ScalarShape(); } - Result operator()(const StructureConstructor &) const { return Scalar(); } template Result operator()(const Constant &c) const { - return AsShape(c.SHAPE()); + return ConstantShape(c.SHAPE()); } Result operator()(const Symbol &) const; @@ -125,21 +137,19 @@ } private: - static Result Scalar() { return Shape{}; } - Shape CreateShape(int rank, NamedEntity &base) const { - Shape shape; - for (int dimension{0}; dimension < rank; ++dimension) { - shape.emplace_back(GetExtent(context_, base, dimension)); - } - return shape; - } + static Result ScalarShape() { return Shape{}; } + static Shape ConstantShape(const Constant &); + Result AsShape(ExtentExpr &&) const; + static Shape CreateShape(int rank, NamedEntity &); + template MaybeExtentExpr GetArrayConstructorValueExtent( const ArrayConstructorValue &value) const { return std::visit( common::visitors{ [&](const Expr &x) -> MaybeExtentExpr { - if (std::optional xShape{GetShape(context_, x)}) { + if (auto xShape{ + context_ ? GetShape(*context_, x) : GetShape(x)}) { // Array values in array constructors get linearized. return GetSize(std::move(*xShape)); } else { @@ -154,8 +164,7 @@ !ContainsAnyImpliedDoIndex(ido.stride())) { if (auto nValues{GetArrayConstructorExtent(ido.values())}) { return std::move(*nValues) * - CountTrips( - context_, ido.lower(), ido.upper(), ido.stride()); + CountTrips(ido.lower(), ido.upper(), ido.stride()); } } return std::nullopt; @@ -178,12 +187,29 @@ return result; } - FoldingContext &context_; + FoldingContext *context_{nullptr}; }; template std::optional GetShape(FoldingContext &context, const A &x) { - return GetShapeHelper{context}(x); + if (auto shape{GetShapeHelper{context}(x)}) { + return Fold(context, std::move(shape)); + } else { + return std::nullopt; + } +} + +template std::optional GetShape(const A &x) { + return GetShapeHelper{}(x); +} + +template +std::optional GetShape(FoldingContext *context, const A &x) { + if (context) { + return GetShape(*context, x); + } else { + return GetShapeHelper{}(x); + } } template diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -223,6 +223,14 @@ return std::nullopt; } } +template +std::optional ExtractDataRef(const A *p, bool intoSubstring = false) { + if (p) { + return ExtractDataRef(*p, intoSubstring); + } else { + return std::nullopt; + } +} std::optional ExtractSubstringBase(const Substring &); // Predicate: is an expression is an array element reference? @@ -807,9 +815,6 @@ // when none is found. const Symbol *GetLastTarget(const SymbolVector &); -// Resolves any whole ASSOCIATE(B=>A) associations, then returns GetUltimate() -const Symbol &ResolveAssociations(const Symbol &); - // Collects all of the Symbols in an expression template semantics::SymbolSet CollectSymbols(const A &); extern template semantics::SymbolSet CollectSymbols(const Expr &); @@ -904,6 +909,7 @@ // These functions are used in Evaluate so they are defined here rather than in // Semantics to avoid a link-time dependency on Semantics. +// All of these apply GetUltimate() or ResolveAssociations() to their arguments. bool IsVariableName(const Symbol &); bool IsPureProcedure(const Symbol &); @@ -917,9 +923,18 @@ bool IsKindTypeParameter(const Symbol &); bool IsLenTypeParameter(const Symbol &); -// Follow use, host, and construct assocations to a variable, if any. -const Symbol *GetAssociationRoot(const Symbol &); -Symbol *GetAssociationRoot(Symbol &); +// ResolveAssociations() traverses use associations and host associations +// like GetUltimate(), but also resolves through whole variable associations +// with ASSOCIATE(x => y) and related constructs. GetAssociationRoot() +// applies ResolveAssociations() and then, in the case of resolution to +// a construct association with part of a variable that does not involve a +// vector subscript, returns the first symbol of that variable instead +// of the construct entity. +// (E.g., for ASSOCIATE(x => y%z), ResolveAssociations(x) returns x, +// while GetAssociationRoot(x) returns y.) +const Symbol &ResolveAssociations(const Symbol &); +const Symbol &GetAssociationRoot(const Symbol &); + const Symbol *FindCommonBlockContaining(const Symbol &); int CountLenParameters(const DerivedTypeSpec &); int CountNonConstantLenParameters(const DerivedTypeSpec &); diff --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h --- a/flang/include/flang/Evaluate/type.h +++ b/flang/include/flang/Evaluate/type.h @@ -340,6 +340,9 @@ template struct SomeKind { static constexpr TypeCategory category{CATEGORY}; constexpr bool operator==(const SomeKind &) const { return true; } + static std::string AsFortran() { + return "Some"s + common::EnumToString(category); + } }; using NumericCategoryTypes = std::tuple, @@ -350,7 +353,9 @@ // Represents a completely generic type (or, for Expr, a typeless // value like a BOZ literal or NULL() pointer). -struct SomeType {}; +struct SomeType { + static std::string AsFortran() { return "SomeType"s; } +}; class StructureConstructor; diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h --- a/flang/include/flang/Semantics/tools.h +++ b/flang/include/flang/Semantics/tools.h @@ -102,6 +102,13 @@ bool IsEventTypeOrLockType(const DerivedTypeSpec *); bool IsOrContainsEventOrLockComponent(const Symbol &); bool CanBeTypeBoundProc(const Symbol *); +// Does a non-PARAMETER symbol have explicit initialization with =value or +// =>target in its declaration, or optionally in a DATA statement? (Being +// ALLOCATABLE or having a derived type with default component initialization +// doesn't count; it must be a variable initialization that implies the SAVE +// attribute, or a derived type component default value.) +bool IsStaticallyInitialized(const Symbol &, bool ignoreDATAstatements = false); +// Is the symbol explicitly or implicitly initialized in any way? bool IsInitialized(const Symbol &, bool ignoreDATAstatements = false, const Symbol *derivedType = nullptr); bool HasIntrinsicTypeName(const Symbol &); diff --git a/flang/lib/Evaluate/characteristics.cpp b/flang/lib/Evaluate/characteristics.cpp --- a/flang/lib/Evaluate/characteristics.cpp +++ b/flang/lib/Evaluate/characteristics.cpp @@ -60,6 +60,12 @@ attrs_ == that.attrs_ && corank_ == that.corank_; } +TypeAndShape &TypeAndShape::Rewrite(FoldingContext &context) { + LEN_ = Fold(context, std::move(LEN_)); + shape_ = Fold(context, std::move(shape_)); + return *this; +} + std::optional TypeAndShape::Characterize( const semantics::Symbol &symbol, FoldingContext &context) { return std::visit( @@ -77,7 +83,7 @@ [&](const semantics::ProcEntityDetails &proc) { const semantics::ProcInterface &interface{proc.interface()}; if (interface.type()) { - return Characterize(*interface.type()); + return Characterize(*interface.type(), context); } else if (interface.symbol()) { return Characterize(*interface.symbol(), context); } else { @@ -91,26 +97,23 @@ return std::optional{}; } }, - [&](const semantics::UseDetails &use) { - return Characterize(use.symbol(), context); - }, - [&](const semantics::HostAssocDetails &assoc) { - return Characterize(assoc.symbol(), context); - }, [&](const semantics::AssocEntityDetails &assoc) { return Characterize(assoc, context); }, [](const auto &) { return std::optional{}; }, }, - symbol.details()); + // GetUltimate() used here, not ResolveAssociations(), because + // we need the type/rank of an associate entity from TYPE IS, + // CLASS IS, or RANK statement. + symbol.GetUltimate().details()); } std::optional TypeAndShape::Characterize( const semantics::ObjectEntityDetails &object, FoldingContext &context) { if (auto type{DynamicType::From(object.type())}) { TypeAndShape result{std::move(*type)}; - result.AcquireShape(object, context); - return result; + result.AcquireShape(object); + return Fold(context, std::move(result)); } else { return std::nullopt; } @@ -118,26 +121,30 @@ std::optional TypeAndShape::Characterize( const semantics::AssocEntityDetails &assoc, FoldingContext &context) { + std::optional result; if (auto type{DynamicType::From(assoc.type())}) { - if (auto shape{GetShape(context, assoc.expr())}) { - TypeAndShape result{std::move(*type), std::move(*shape)}; - if (type->category() == TypeCategory::Character) { - if (const auto *chExpr{UnwrapExpr>(assoc.expr())}) { - if (auto len{chExpr->LEN()}) { - result.set_LEN(Fold(context, std::move(*len))); - } + if (auto rank{assoc.rank()}) { + if (*rank >= 0 && *rank <= common::maxRank) { + result = TypeAndShape{std::move(*type), Shape(*rank)}; + } + } else if (auto shape{GetShape(context, assoc.expr())}) { + result = TypeAndShape{std::move(*type), std::move(*shape)}; + } + if (result && type->category() == TypeCategory::Character) { + if (const auto *chExpr{UnwrapExpr>(assoc.expr())}) { + if (auto len{chExpr->LEN()}) { + result->set_LEN(std::move(*len)); } } - return std::move(result); } } - return std::nullopt; + return Fold(context, std::move(result)); } std::optional TypeAndShape::Characterize( - const semantics::DeclTypeSpec &spec) { + const semantics::DeclTypeSpec &spec, FoldingContext &context) { if (auto type{DynamicType::From(spec)}) { - return TypeAndShape{std::move(*type)}; + return Fold(context, TypeAndShape{std::move(*type)}); } else { return std::nullopt; } @@ -180,8 +187,7 @@ return std::nullopt; } -void TypeAndShape::AcquireShape( - const semantics::ObjectEntityDetails &object, FoldingContext &context) { +void TypeAndShape::AcquireShape(const semantics::ObjectEntityDetails &object) { CHECK(shape_.empty() && !attrs_.test(Attr::AssumedRank)); corank_ = object.coshape().Rank(); if (object.IsAssumedRank()) { @@ -207,7 +213,7 @@ extent = std::move(extent) + Expr{1} - std::move(*lbound); } - shape_.emplace_back(Fold(context, std::move(extent))); + shape_.emplace_back(std::move(extent)); } else { shape_.push_back(std::nullopt); } @@ -634,7 +640,7 @@ std::optional Procedure::Characterize( const semantics::Symbol &original, FoldingContext &context) { Procedure result; - const auto &symbol{ResolveAssociations(original)}; + const auto &symbol{original.GetUltimate()}; CopyAttrs(symbol, result, { {semantics::Attr::PURE, Procedure::Attr::Pure}, @@ -732,7 +738,7 @@ const ProcedureDesignator &proc, FoldingContext &context) { if (const auto *symbol{proc.GetSymbol()}) { if (auto result{characteristics::Procedure::Characterize( - ResolveAssociations(*symbol), context)}) { + symbol->GetUltimate(), context)}) { return result; } } else if (const auto *intrinsic{proc.GetSpecificIntrinsic()}) { diff --git a/flang/lib/Evaluate/check-expression.cpp b/flang/lib/Evaluate/check-expression.cpp --- a/flang/lib/Evaluate/check-expression.cpp +++ b/flang/lib/Evaluate/check-expression.cpp @@ -30,6 +30,11 @@ IsConstantExprHelper() : Base{*this} {} using Base::operator(); + // A missing expression is not considered to be constant. + template bool operator()(const std::optional &x) const { + return x && (*this)(*x); + } + bool operator()(const TypeParamInquiry &inq) const { return semantics::IsKindTypeParameter(inq.parameter()); } @@ -42,17 +47,7 @@ bool operator()(const semantics::ParamValue ¶m) const { return param.isExplicit() && (*this)(param.GetExplicit()); } - template bool operator()(const FunctionRef &call) const { - if (const auto *intrinsic{std::get_if(&call.proc().u)}) { - // kind is always a constant, and we avoid cascading errors by calling - // invalid calls to intrinsics constant - return intrinsic->name == "kind" || - intrinsic->name == IntrinsicProcTable::InvalidName; - // TODO: other inquiry intrinsics - } else { - return false; - } - } + bool operator()(const ProcedureRef &) const; bool operator()(const StructureConstructor &constructor) const { for (const auto &[symRef, expr] : constructor) { if (!IsConstantStructureConstructorComponent(*symRef, expr.value())) { @@ -77,20 +72,64 @@ } bool operator()(const Constant &) const { return true; } + bool operator()(const DescriptorInquiry &) const { return false; } private: bool IsConstantStructureConstructorComponent( - const Symbol &component, const Expr &expr) const { - if (IsAllocatable(component)) { - return IsNullPointer(expr); - } else if (IsPointer(component)) { - return IsNullPointer(expr) || IsInitialDataTarget(expr) || - IsInitialProcedureTarget(expr); - } else { - return (*this)(expr); + const Symbol &, const Expr &) const; + bool IsConstantExprShape(const Shape &) const; +}; + +bool IsConstantExprHelper::IsConstantStructureConstructorComponent( + const Symbol &component, const Expr &expr) const { + if (IsAllocatable(component)) { + return IsNullPointer(expr); + } else if (IsPointer(component)) { + return IsNullPointer(expr) || IsInitialDataTarget(expr) || + IsInitialProcedureTarget(expr); + } else { + return (*this)(expr); + } +} + +bool IsConstantExprHelper::operator()(const ProcedureRef &call) const { + // LBOUND, UBOUND, and SIZE with DIM= arguments will have been reritten + // into DescriptorInquiry operations. + if (const auto *intrinsic{std::get_if(&call.proc().u)}) { + if (intrinsic->name == "kind" || + intrinsic->name == IntrinsicProcTable::InvalidName) { + // kind is always a constant, and we avoid cascading errors by considering + // invalid calls to intrinsics to be constant + return true; + } else if (intrinsic->name == "lbound" && call.arguments().size() == 1) { + // LBOUND(x) without DIM= + auto base{ExtractNamedEntity(call.arguments()[0]->UnwrapExpr())}; + return base && IsConstantExprShape(GetLowerBounds(*base)); + } else if (intrinsic->name == "ubound" && call.arguments().size() == 1) { + // UBOUND(x) without DIM= + auto base{ExtractNamedEntity(call.arguments()[0]->UnwrapExpr())}; + return base && IsConstantExprShape(GetUpperBounds(*base)); + } else if (intrinsic->name == "shape") { + auto shape{GetShape(call.arguments()[0]->UnwrapExpr())}; + return shape && IsConstantExprShape(*shape); + } else if (intrinsic->name == "size" && call.arguments().size() == 1) { + // SIZE(x) without DIM + auto shape{GetShape(call.arguments()[0]->UnwrapExpr())}; + return shape && IsConstantExprShape(*shape); } + // TODO: STORAGE_SIZE } -}; + return false; +} + +bool IsConstantExprHelper::IsConstantExprShape(const Shape &shape) const { + for (const auto &extent : shape) { + if (!(*this)(extent)) { + return false; + } + } + return true; +} template bool IsConstantExpr(const A &x) { return IsConstantExprHelper{}(x); diff --git a/flang/lib/Evaluate/fold.cpp b/flang/lib/Evaluate/fold.cpp --- a/flang/lib/Evaluate/fold.cpp +++ b/flang/lib/Evaluate/fold.cpp @@ -8,9 +8,16 @@ #include "flang/Evaluate/fold.h" #include "fold-implementation.h" +#include "flang/Evaluate/characteristics.h" namespace Fortran::evaluate { +characteristics::TypeAndShape Fold( + FoldingContext &context, characteristics::TypeAndShape &&x) { + x.Rewrite(context); + return std::move(x); +} + std::optional> GetConstantSubscript( FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) { ss = FoldOperation(context, std::move(ss)); diff --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp --- a/flang/lib/Evaluate/intrinsics.cpp +++ b/flang/lib/Evaluate/intrinsics.cpp @@ -1559,7 +1559,12 @@ if (const Expr *expr{arg->UnwrapExpr()}) { auto dc{characteristics::DummyArgument::FromActual( std::string{d.keyword}, *expr, context)}; - CHECK(dc); + if (!dc) { + common::die("INTERNAL: could not characterize intrinsic function " + "actual argument '%s'", + expr->AsFortran().c_str()); + return std::nullopt; + } dummyArgs.emplace_back(std::move(*dc)); if (d.typePattern.kindCode == KindCode::same && !sameDummyArg) { sameDummyArg = j; diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp --- a/flang/lib/Evaluate/shape.cpp +++ b/flang/lib/Evaluate/shape.cpp @@ -22,39 +22,42 @@ namespace Fortran::evaluate { -bool IsImpliedShape(const Symbol &symbol0) { - const Symbol &symbol{ResolveAssociations(symbol0)}; +bool IsImpliedShape(const Symbol &original) { + const Symbol &symbol{ResolveAssociations(original)}; const auto *details{symbol.detailsIf()}; - return symbol.attrs().test(semantics::Attr::PARAMETER) && details && + return details && symbol.attrs().test(semantics::Attr::PARAMETER) && details->shape().IsImpliedShape(); } -bool IsExplicitShape(const Symbol &symbol0) { - const Symbol &symbol{ResolveAssociations(symbol0)}; +bool IsExplicitShape(const Symbol &original) { + const Symbol &symbol{ResolveAssociations(original)}; if (const auto *details{symbol.detailsIf()}) { const auto &shape{details->shape()}; - return shape.Rank() == 0 || shape.IsExplicitShape(); // even if scalar + return shape.Rank() == 0 || + shape.IsExplicitShape(); // true when scalar, too } else { - return false; + return symbol + .has(); // exprs have explicit shape } } -Shape AsShape(const Constant &arrayConstant) { +Shape GetShapeHelper::ConstantShape(const Constant &arrayConstant) { CHECK(arrayConstant.Rank() == 1); Shape result; std::size_t dimensions{arrayConstant.size()}; for (std::size_t j{0}; j < dimensions; ++j) { Scalar extent{arrayConstant.values().at(j)}; - result.emplace_back(MaybeExtentExpr{ExtentExpr{extent}}); + result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}}); } return result; } -std::optional AsShape(FoldingContext &context, ExtentExpr &&arrayExpr) { - // Flatten any array expression into an array constructor if possible. - arrayExpr = Fold(context, std::move(arrayExpr)); +auto GetShapeHelper::AsShape(ExtentExpr &&arrayExpr) const -> Result { + if (context_) { + arrayExpr = Fold(*context_, std::move(arrayExpr)); + } if (const auto *constArray{UnwrapConstantValue(arrayExpr)}) { - return AsShape(*constArray); + return ConstantShape(*constArray); } if (auto *constructor{UnwrapExpr>(arrayExpr)}) { Shape result; @@ -72,6 +75,14 @@ return std::nullopt; } +Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) { + Shape shape; + for (int dimension{0}; dimension < rank; ++dimension) { + shape.emplace_back(GetExtent(base, dimension)); + } + return shape; +} + std::optional AsExtentArrayExpr(const Shape &shape) { ArrayConstructorValues values; for (const auto &dim : shape) { @@ -121,33 +132,48 @@ } } -static ExtentExpr ComputeTripCount(FoldingContext &context, ExtentExpr &&lower, - ExtentExpr &&upper, ExtentExpr &&stride) { +Shape Fold(FoldingContext &context, Shape &&shape) { + for (auto &dim : shape) { + dim = Fold(context, std::move(dim)); + } + return std::move(shape); +} + +std::optional Fold( + FoldingContext &context, std::optional &&shape) { + if (shape) { + return Fold(context, std::move(*shape)); + } else { + return std::nullopt; + } +} + +static ExtentExpr ComputeTripCount( + ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) { ExtentExpr strideCopy{common::Clone(stride)}; ExtentExpr span{ (std::move(upper) - std::move(lower) + std::move(strideCopy)) / std::move(stride)}; - ExtentExpr extent{ + return ExtentExpr{ Extremum{Ordering::Greater, std::move(span), ExtentExpr{0}}}; - return Fold(context, std::move(extent)); } -ExtentExpr CountTrips(FoldingContext &context, ExtentExpr &&lower, - ExtentExpr &&upper, ExtentExpr &&stride) { +ExtentExpr CountTrips( + ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) { return ComputeTripCount( - context, std::move(lower), std::move(upper), std::move(stride)); + std::move(lower), std::move(upper), std::move(stride)); } -ExtentExpr CountTrips(FoldingContext &context, const ExtentExpr &lower, - const ExtentExpr &upper, const ExtentExpr &stride) { - return ComputeTripCount(context, common::Clone(lower), common::Clone(upper), - common::Clone(stride)); +ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper, + const ExtentExpr &stride) { + return ComputeTripCount( + common::Clone(lower), common::Clone(upper), common::Clone(stride)); } -MaybeExtentExpr CountTrips(FoldingContext &context, MaybeExtentExpr &&lower, - MaybeExtentExpr &&upper, MaybeExtentExpr &&stride) { +MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper, + MaybeExtentExpr &&stride) { std::function bound{ - std::bind(ComputeTripCount, context, _1, _2, _3)}; + std::bind(ComputeTripCount, _1, _2, _3)}; return common::MapOptional( std::move(bound), std::move(lower), std::move(upper), std::move(stride)); } @@ -182,15 +208,13 @@ using Result = ExtentExpr; using Base = Traverse; using Base::operator(); - GetLowerBoundHelper(FoldingContext &c, int d) - : Base{*this}, context_{c}, dimension_{d} {} + explicit GetLowerBoundHelper(int d) : Base{*this}, dimension_{d} {} static ExtentExpr Default() { return ExtentExpr{1}; } static ExtentExpr Combine(Result &&, Result &&) { return Default(); } ExtentExpr operator()(const Symbol &); ExtentExpr operator()(const Component &); private: - FoldingContext &context_; int dimension_; }; @@ -201,7 +225,7 @@ for (const auto &shapeSpec : details->shape()) { if (j++ == dimension_) { if (const auto &bound{shapeSpec.lbound().GetExplicit()}) { - return Fold(context_, common::Clone(*bound)); + return *bound; } else if (IsDescriptor(symbol)) { return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0}, DescriptorInquiry::Field::LowerBound, dimension_}}; @@ -226,7 +250,7 @@ for (const auto &shapeSpec : details->shape()) { if (j++ == dimension_) { if (const auto &bound{shapeSpec.lbound().GetExplicit()}) { - return Fold(context_, common::Clone(*bound)); + return *bound; } else if (IsDescriptor(symbol)) { return ExtentExpr{ DescriptorInquiry{NamedEntity{common::Clone(component)}, @@ -241,9 +265,22 @@ return Default(); } +ExtentExpr GetLowerBound(const NamedEntity &base, int dimension) { + return GetLowerBoundHelper{dimension}(base); +} + ExtentExpr GetLowerBound( FoldingContext &context, const NamedEntity &base, int dimension) { - return GetLowerBoundHelper{context, dimension}(base); + return Fold(context, GetLowerBound(base, dimension)); +} + +Shape GetLowerBounds(const NamedEntity &base) { + Shape result; + int rank{base.Rank()}; + for (int dim{0}; dim < rank; ++dim) { + result.emplace_back(GetLowerBound(base, dim)); + } + return result; } Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) { @@ -255,13 +292,12 @@ return result; } -MaybeExtentExpr GetExtent( - FoldingContext &context, const NamedEntity &base, int dimension) { +MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) { CHECK(dimension >= 0); const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())}; if (const auto *details{symbol.detailsIf()}) { if (IsImpliedShape(symbol)) { - Shape shape{GetShape(context, symbol).value()}; + Shape shape{GetShape(symbol).value()}; return std::move(shape.at(dimension)); } int j{0}; @@ -270,11 +306,10 @@ if (shapeSpec.ubound().isExplicit()) { if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) { if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) { - return Fold(context, - common::Clone(ubound.value()) - - common::Clone(lbound.value()) + ExtentExpr{1}); + return common::Clone(ubound.value()) - + common::Clone(lbound.value()) + ExtentExpr{1}; } else { - return Fold(context, common::Clone(ubound.value())); + return ubound.value(); } } } else if (details->IsAssumedSize() && j == symbol.Rank()) { @@ -287,7 +322,7 @@ } } else if (const auto *assoc{ symbol.detailsIf()}) { - if (auto shape{GetShape(context, assoc->expr())}) { + if (auto shape{GetShape(assoc->expr())}) { if (dimension < static_cast(shape->size())) { return std::move(shape->at(dimension)); } @@ -296,24 +331,29 @@ return std::nullopt; } -MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript, - const NamedEntity &base, int dimension) { +MaybeExtentExpr GetExtent( + FoldingContext &context, const NamedEntity &base, int dimension) { + return Fold(context, GetExtent(base, dimension)); +} + +MaybeExtentExpr GetExtent( + const Subscript &subscript, const NamedEntity &base, int dimension) { return std::visit( common::visitors{ [&](const Triplet &triplet) -> MaybeExtentExpr { MaybeExtentExpr upper{triplet.upper()}; if (!upper) { - upper = GetUpperBound(context, base, dimension); + upper = GetUpperBound(base, dimension); } MaybeExtentExpr lower{triplet.lower()}; if (!lower) { - lower = GetLowerBound(context, base, dimension); + lower = GetLowerBound(base, dimension); } - return CountTrips(context, std::move(lower), std::move(upper), + return CountTrips(std::move(lower), std::move(upper), MaybeExtentExpr{triplet.stride()}); }, [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr { - if (auto shape{GetShape(context, subs.value())}) { + if (auto shape{GetShape(subs.value())}) { if (GetRank(*shape) > 0) { CHECK(GetRank(*shape) == 1); // vector-valued subscript return std::move(shape->at(0)); @@ -325,70 +365,86 @@ subscript.u); } +MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript, + const NamedEntity &base, int dimension) { + return Fold(context, GetExtent(subscript, base, dimension)); +} + MaybeExtentExpr ComputeUpperBound( - FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) { + ExtentExpr &&lower, MaybeExtentExpr &&extent) { if (extent) { - return Fold(context, std::move(*extent) - std::move(lower) + ExtentExpr{1}); + return std::move(*extent) - std::move(lower) + ExtentExpr{1}; } else { return std::nullopt; } } -MaybeExtentExpr GetUpperBound( - FoldingContext &context, const NamedEntity &base, int dimension) { +MaybeExtentExpr ComputeUpperBound( + FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) { + return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent))); +} + +MaybeExtentExpr GetUpperBound(const NamedEntity &base, int dimension) { const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())}; if (const auto *details{symbol.detailsIf()}) { int j{0}; for (const auto &shapeSpec : details->shape()) { if (j++ == dimension) { if (const auto &bound{shapeSpec.ubound().GetExplicit()}) { - return Fold(context, common::Clone(*bound)); + return *bound; } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) { break; } else { - return ComputeUpperBound(context, - GetLowerBound(context, base, dimension), - GetExtent(context, base, dimension)); + return ComputeUpperBound( + GetLowerBound(base, dimension), GetExtent(base, dimension)); } } } } else if (const auto *assoc{ symbol.detailsIf()}) { - if (auto shape{GetShape(context, assoc->expr())}) { + if (auto shape{GetShape(assoc->expr())}) { if (dimension < static_cast(shape->size())) { - return ComputeUpperBound(context, - GetLowerBound(context, base, dimension), - std::move(shape->at(dimension))); + return ComputeUpperBound( + GetLowerBound(base, dimension), std::move(shape->at(dimension))); } } } return std::nullopt; } -Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) { +MaybeExtentExpr GetUpperBound( + FoldingContext &context, const NamedEntity &base, int dimension) { + return Fold(context, GetUpperBound(base, dimension)); +} + +Shape GetUpperBounds(const NamedEntity &base) { const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())}; if (const auto *details{symbol.detailsIf()}) { Shape result; int dim{0}; for (const auto &shapeSpec : details->shape()) { if (const auto &bound{shapeSpec.ubound().GetExplicit()}) { - result.emplace_back(Fold(context, common::Clone(*bound))); + result.push_back(*bound); } else if (details->IsAssumedSize()) { CHECK(dim + 1 == base.Rank()); result.emplace_back(std::nullopt); // UBOUND folding replaces with -1 } else { - result.emplace_back(ComputeUpperBound(context, - GetLowerBound(context, base, dim), GetExtent(context, base, dim))); + result.emplace_back( + ComputeUpperBound(GetLowerBound(base, dim), GetExtent(base, dim))); } ++dim; } CHECK(GetRank(result) == symbol.Rank()); return result; } else { - return std::move(GetShape(context, base).value()); + return std::move(GetShape(symbol).value()); } } +Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) { + return Fold(context, GetUpperBounds(base)); +} + auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result { return std::visit( common::visitors{ @@ -402,13 +458,13 @@ } }, [](const semantics::EntityDetails &) { - return Scalar(); // no dimensions seen + return ScalarShape(); // no dimensions seen }, [&](const semantics::ProcEntityDetails &proc) { if (const Symbol * interface{proc.interface().symbol()}) { return (*this)(*interface); } else { - return Scalar(); + return ScalarShape(); } }, [&](const semantics::AssocEntityDetails &assoc) { @@ -436,7 +492,7 @@ [&](const semantics::HostAssocDetails &assoc) { return (*this)(assoc.symbol()); }, - [](const semantics::TypeParamDetails &) { return Scalar(); }, + [](const semantics::TypeParamDetails &) { return ScalarShape(); }, [](const auto &) { return Result{}; }, }, symbol.details()); @@ -464,7 +520,7 @@ const NamedEntity &base{arrayRef.base()}; for (const Subscript &ss : arrayRef.subscript()) { if (ss.Rank() > 0) { - shape.emplace_back(GetExtent(context_, ss, base, dimension)); + shape.emplace_back(GetExtent(ss, base, dimension)); } ++dimension; } @@ -485,7 +541,7 @@ int dimension{0}; for (const Subscript &ss : coarrayRef.subscript()) { if (ss.Rank() > 0) { - shape.emplace_back(GetExtent(context_, ss, base, dimension)); + shape.emplace_back(GetExtent(ss, base, dimension)); } ++dimension; } @@ -499,14 +555,14 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result { if (call.Rank() == 0) { - return Scalar(); + return ScalarShape(); } else if (call.IsElemental()) { for (const auto &arg : call.arguments()) { if (arg && arg->Rank() > 0) { return (*this)(*arg); } } - return Scalar(); + return ScalarShape(); } else if (const Symbol * symbol{call.proc().GetSymbol()}) { return (*this)(*symbol); } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) { @@ -565,14 +621,14 @@ if (const auto *shapeExpr{ call.arguments().at(1).value().UnwrapExpr()}) { auto shape{std::get>(shapeExpr->u)}; - return AsShape(context_, ConvertToType(std::move(shape))); + return AsShape(ConvertToType(std::move(shape))); } } } else if (intrinsic->name == "pack") { if (call.arguments().size() >= 3 && call.arguments().at(2)) { // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v) return (*this)(call.arguments().at(2)); - } else if (call.arguments().size() >= 2) { + } else if (call.arguments().size() >= 2 && context_) { if (auto maskShape{(*this)(call.arguments().at(1))}) { if (maskShape->size() == 0) { // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)] @@ -583,8 +639,8 @@ ActualArgument{AsGenericExpr(std::move(*arraySize))}, ActualArgument{AsGenericExpr(ExtentExpr{0})}, common::Clone(call.arguments().at(1))}; - auto specific{context_.intrinsics().Probe( - CallCharacteristics{"merge"}, toMerge, context_)}; + auto specific{context_->intrinsics().Probe( + CallCharacteristics{"merge"}, toMerge, *context_)}; CHECK(specific); return Shape{ExtentExpr{FunctionRef{ ProcedureDesignator{std::move(specific->specificIntrinsic)}, @@ -594,8 +650,8 @@ // Non-scalar MASK= -> [COUNT(mask)] ActualArguments toCount{ActualArgument{common::Clone( DEREF(call.arguments().at(1).value().UnwrapExpr()))}}; - auto specific{context_.intrinsics().Probe( - CallCharacteristics{"count"}, toCount, context_)}; + auto specific{context_->intrinsics().Probe( + CallCharacteristics{"count"}, toCount, *context_)}; CHECK(specific); return Shape{ExtentExpr{FunctionRef{ ProcedureDesignator{std::move(specific->specificIntrinsic)}, @@ -631,27 +687,29 @@ return Shape{ MaybeExtentExpr{ConvertToType(common::Clone(*size))}}; } - } else if (auto moldTypeAndShape{ - characteristics::TypeAndShape::Characterize( - call.arguments().at(1), context_)}) { - if (GetRank(moldTypeAndShape->shape()) == 0) { - // SIZE= is absent and MOLD= is scalar: result is scalar - return Scalar(); - } else { - // SIZE= is absent and MOLD= is array: result is vector whose - // length is determined by sizes of types. See 16.9.193p4 case(ii). - if (auto sourceTypeAndShape{ - characteristics::TypeAndShape::Characterize( - call.arguments().at(0), context_)}) { - auto sourceBytes{sourceTypeAndShape->MeasureSizeInBytes(context_)}; - auto moldElementBytes{ - moldTypeAndShape->type().MeasureSizeInBytes(context_, true)}; - if (sourceBytes && moldElementBytes) { - ExtentExpr extent{Fold(context_, - (std::move(*sourceBytes) + common::Clone(*moldElementBytes) - - ExtentExpr{1}) / - common::Clone(*moldElementBytes))}; - return Shape{MaybeExtentExpr{std::move(extent)}}; + } else if (context_) { + if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize( + call.arguments().at(1), *context_)}) { + if (GetRank(moldTypeAndShape->shape()) == 0) { + // SIZE= is absent and MOLD= is scalar: result is scalar + return ScalarShape(); + } else { + // SIZE= is absent and MOLD= is array: result is vector whose + // length is determined by sizes of types. See 16.9.193p4 case(ii). + if (auto sourceTypeAndShape{ + characteristics::TypeAndShape::Characterize( + call.arguments().at(0), *context_)}) { + auto sourceBytes{ + sourceTypeAndShape->MeasureSizeInBytes(*context_)}; + auto moldElementBytes{ + moldTypeAndShape->type().MeasureSizeInBytes(*context_, true)}; + if (sourceBytes && moldElementBytes) { + ExtentExpr extent{Fold(*context_, + (std::move(*sourceBytes) + + common::Clone(*moldElementBytes) - ExtentExpr{1}) / + common::Clone(*moldElementBytes))}; + return Shape{MaybeExtentExpr{std::move(extent)}}; + } } } } diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -659,8 +659,8 @@ } } -bool IsAssumedRank(const Symbol &symbol0) { - const Symbol &symbol{ResolveAssociations(symbol0)}; +bool IsAssumedRank(const Symbol &original) { + const Symbol &symbol{semantics::ResolveAssociations(original)}; if (const auto *details{symbol.detailsIf()}) { return details->IsAssumedRank(); } else { @@ -743,15 +743,6 @@ return iter == end ? nullptr : &**iter; } -const Symbol &ResolveAssociations(const Symbol &symbol) { - if (const auto *details{symbol.detailsIf()}) { - if (const Symbol * nested{UnwrapWholeSymbolDataRef(details->expr())}) { - return ResolveAssociations(*nested); - } - } - return symbol.GetUltimate(); -} - struct CollectSymbolsHelper : public SetTraverse { using Base = SetTraverse; @@ -909,39 +900,55 @@ namespace Fortran::semantics { +const Symbol &ResolveAssociations(const Symbol &original) { + const Symbol &symbol{original.GetUltimate()}; + if (const auto *details{symbol.detailsIf()}) { + if (const Symbol * nested{UnwrapWholeSymbolDataRef(details->expr())}) { + return ResolveAssociations(*nested); + } + } + return symbol; +} + // When a construct association maps to a variable, and that variable // is not an array with a vector-valued subscript, return the base // Symbol of that variable, else nullptr. Descends into other construct // associations when one associations maps to another. -static const Symbol *GetAssociatedVariable( - const semantics::AssocEntityDetails &details) { +static const Symbol *GetAssociatedVariable(const AssocEntityDetails &details) { if (const auto &expr{details.expr()}) { if (IsVariable(*expr) && !HasVectorSubscript(*expr)) { if (const Symbol * varSymbol{GetFirstSymbol(*expr)}) { - return GetAssociationRoot(*varSymbol); + return &GetAssociationRoot(*varSymbol); } } } return nullptr; } -const Symbol *GetAssociationRoot(const Symbol &symbol) { - const Symbol &ultimate{symbol.GetUltimate()}; - const auto *details{ultimate.detailsIf()}; - return details ? GetAssociatedVariable(*details) : &ultimate; -} - -Symbol *GetAssociationRoot(Symbol &symbol) { - return const_cast( - GetAssociationRoot(const_cast(symbol))); +const Symbol &GetAssociationRoot(const Symbol &original) { + const Symbol &symbol{ResolveAssociations(original)}; + if (const auto *details{symbol.detailsIf()}) { + if (const Symbol * root{GetAssociatedVariable(*details)}) { + return *root; + } + } + return symbol; } -bool IsVariableName(const Symbol &symbol) { - const Symbol *root{GetAssociationRoot(symbol)}; - return root && root->has() && !IsNamedConstant(*root); +bool IsVariableName(const Symbol &original) { + const Symbol &symbol{ResolveAssociations(original)}; + if (symbol.has()) { + return !IsNamedConstant(symbol); + } else if (const auto *assoc{symbol.detailsIf()}) { + const auto &expr{assoc->expr()}; + return expr && IsVariable(*expr) && !HasVectorSubscript(*expr); + } else { + return false; + } } -bool IsPureProcedure(const Symbol &symbol) { +bool IsPureProcedure(const Symbol &original) { + const Symbol &symbol{original.GetUltimate()}; if (const auto *procDetails{symbol.detailsIf()}) { if (const Symbol * procInterface{procDetails->interface().symbol()}) { // procedure component with a pure interface @@ -960,8 +967,7 @@ if (IsFunction(*ref) && !IsPureProcedure(*ref)) { return false; } - const Symbol *root{GetAssociationRoot(*ref)}; - if (root && root->attrs().test(Attr::VOLATILE)) { + if (ref->GetUltimate().attrs().test(Attr::VOLATILE)) { return false; } } @@ -990,24 +996,21 @@ return ifc.type() || (ifc.symbol() && IsFunction(*ifc.symbol())); }, [](const ProcBindingDetails &x) { return IsFunction(x.symbol()); }, - [](const UseDetails &x) { return IsFunction(x.symbol()); }, [](const auto &) { return false; }, }, - symbol.details()); + symbol.GetUltimate().details()); } bool IsProcedure(const Symbol &symbol) { - return std::visit( - common::visitors{ - [](const SubprogramDetails &) { return true; }, - [](const SubprogramNameDetails &) { return true; }, - [](const ProcEntityDetails &) { return true; }, - [](const GenericDetails &) { return true; }, - [](const ProcBindingDetails &) { return true; }, - [](const UseDetails &x) { return IsProcedure(x.symbol()); }, - [](const auto &) { return false; }, - }, - symbol.details()); + return std::visit(common::visitors{ + [](const SubprogramDetails &) { return true; }, + [](const SubprogramNameDetails &) { return true; }, + [](const ProcEntityDetails &) { return true; }, + [](const GenericDetails &) { return true; }, + [](const ProcBindingDetails &) { return true; }, + [](const auto &) { return false; }, + }, + symbol.GetUltimate().details()); } const Symbol *FindCommonBlockContaining(const Symbol &object) { @@ -1015,39 +1018,39 @@ return details ? details->commonBlock() : nullptr; } -bool IsProcedurePointer(const Symbol &symbol) { +bool IsProcedurePointer(const Symbol &original) { + const Symbol &symbol{original.GetUltimate()}; return symbol.has() && IsPointer(symbol); } bool IsSaved(const Symbol &original) { - if (const Symbol * root{GetAssociationRoot(original)}) { - const Symbol &symbol{*root}; - const Scope *scope{&symbol.owner()}; - auto scopeKind{scope->kind()}; - if (scopeKind == Scope::Kind::Module) { - return true; // BLOCK DATA entities must all be in COMMON, handled below - } else if (symbol.attrs().test(Attr::SAVE)) { - return true; - } else if (scopeKind == Scope::Kind::DerivedType) { - return false; // this is a component - } else if (IsNamedConstant(symbol)) { - return false; - } else if (const auto *object{symbol.detailsIf()}; - object && object->init()) { - return true; - } else if (IsProcedurePointer(symbol) && - symbol.get().init()) { - return true; - } else if (const Symbol * block{FindCommonBlockContaining(symbol)}; - block && block->attrs().test(Attr::SAVE)) { - return true; - } else if (IsDummy(symbol) || IsFunctionResult(symbol)) { - return false; - } else if (scope->hasSAVE() ) { - return true; - } + const Symbol &symbol{GetAssociationRoot(original)}; + const Scope &scope{symbol.owner()}; + auto scopeKind{scope.kind()}; + if (symbol.has()) { + return false; // ASSOCIATE(non-variable) + } else if (scopeKind == Scope::Kind::Module) { + return true; // BLOCK DATA entities must all be in COMMON, handled below + } else if (symbol.attrs().test(Attr::SAVE)) { + return true; + } else if (scopeKind == Scope::Kind::DerivedType) { + return false; // this is a component + } else if (IsNamedConstant(symbol)) { + return false; + } else if (const auto *object{symbol.detailsIf()}; + object && object->init()) { + return true; + } else if (IsProcedurePointer(symbol) && + symbol.get().init()) { + return true; + } else if (const Symbol * block{FindCommonBlockContaining(symbol)}; + block && block->attrs().test(Attr::SAVE)) { + return true; + } else if (IsDummy(symbol) || IsFunctionResult(symbol)) { + return false; + } else { + return scope.hasSAVE(); } - return false; } bool IsDummy(const Symbol &symbol) { @@ -1055,12 +1058,12 @@ common::visitors{[](const EntityDetails &x) { return x.isDummy(); }, [](const ObjectEntityDetails &x) { return x.isDummy(); }, [](const ProcEntityDetails &x) { return x.isDummy(); }, - [](const HostAssocDetails &x) { return IsDummy(x.symbol()); }, [](const auto &) { return false; }}, - symbol.details()); + ResolveAssociations(symbol).details()); } -bool IsFunctionResult(const Symbol &symbol) { +bool IsFunctionResult(const Symbol &original) { + const Symbol &symbol{GetAssociationRoot(original)}; return (symbol.has() && symbol.get().isFuncResult()) || (symbol.has() && @@ -1068,12 +1071,12 @@ } bool IsKindTypeParameter(const Symbol &symbol) { - const auto *param{symbol.detailsIf()}; + const auto *param{symbol.GetUltimate().detailsIf()}; return param && param->attr() == common::TypeParamAttr::Kind; } bool IsLenTypeParameter(const Symbol &symbol) { - const auto *param{symbol.detailsIf()}; + const auto *param{symbol.GetUltimate().detailsIf()}; return param && param->attr() == common::TypeParamAttr::Len; } diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp --- a/flang/lib/Semantics/check-call.cpp +++ b/flang/lib/Semantics/check-call.cpp @@ -265,10 +265,10 @@ // Rank and shape checks const auto *actualLastSymbol{evaluate::GetLastSymbol(actual)}; if (actualLastSymbol) { - actualLastSymbol = GetAssociationRoot(*actualLastSymbol); + actualLastSymbol = &ResolveAssociations(*actualLastSymbol); } const ObjectEntityDetails *actualLastObject{actualLastSymbol - ? actualLastSymbol->GetUltimate().detailsIf() + ? actualLastSymbol->detailsIf() : nullptr}; int actualRank{evaluate::GetRank(actualType.shape())}; bool actualIsPointer{(actualLastSymbol && IsPointer(*actualLastSymbol)) || diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -480,7 +480,7 @@ } } } - if (IsInitialized(symbol, true /* ignore DATA inits */)) { // C808 + if (IsStaticallyInitialized(symbol, true /* ignore DATA inits */)) { // C808 CheckPointerInitialization(symbol); if (IsAutomatic(symbol)) { messages_.Say( diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp --- a/flang/lib/Semantics/check-do-forall.cpp +++ b/flang/lib/Semantics/check-do-forall.cpp @@ -115,10 +115,10 @@ // // Only to be called for symbols with ObjectEntityDetails - static bool HasImpureFinal(const Symbol &symbol) { - if (const Symbol * root{GetAssociationRoot(symbol)}) { - CHECK(root->has()); - if (const DeclTypeSpec * symType{root->GetType()}) { + static bool HasImpureFinal(const Symbol &original) { + const Symbol &symbol{ResolveAssociations(original)}; + if (symbol.has()) { + if (const DeclTypeSpec * symType{symbol.GetType()}) { if (const DerivedTypeSpec * derived{symType->AsDerived()}) { return semantics::HasImpureFinal(*derived); } @@ -142,22 +142,21 @@ // Is it possible that we will we deallocate a polymorphic entity or one // of its components? - static bool MightDeallocatePolymorphic(const Symbol &entity, + static bool MightDeallocatePolymorphic(const Symbol &original, const std::function &WillDeallocate) { - if (const Symbol * root{GetAssociationRoot(entity)}) { - // Check the entity itself, no coarray exception here - if (IsPolymorphicAllocatable(*root)) { - return true; - } - // Check the components - if (const auto *details{root->detailsIf()}) { - if (const DeclTypeSpec * entityType{details->type()}) { - if (const DerivedTypeSpec * derivedType{entityType->AsDerived()}) { - UltimateComponentIterator ultimates{*derivedType}; - for (const auto &ultimate : ultimates) { - if (WillDeallocatePolymorphic(ultimate, WillDeallocate)) { - return true; - } + const Symbol &symbol{ResolveAssociations(original)}; + // Check the entity itself, no coarray exception here + if (IsPolymorphicAllocatable(symbol)) { + return true; + } + // Check the components + if (const auto *details{symbol.detailsIf()}) { + if (const DeclTypeSpec * entityType{details->type()}) { + if (const DerivedTypeSpec * derivedType{entityType->AsDerived()}) { + UltimateComponentIterator ultimates{*derivedType}; + for (const auto &ultimate : ultimates) { + if (WillDeallocatePolymorphic(ultimate, WillDeallocate)) { + return true; } } } @@ -561,9 +560,7 @@ // symbols for (const parser::Name &name : names->v) { if (const Symbol * symbol{parentScope.FindSymbol(name.source)}) { - if (const Symbol * root{GetAssociationRoot(*symbol)}) { - symbols.insert(*root); - } + symbols.insert(ResolveAssociations(*symbol)); } } } @@ -575,9 +572,7 @@ SymbolSet result; if (const auto *expr{GetExpr(expression)}) { for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) { - if (const Symbol * root{GetAssociationRoot(symbol)}) { - result.insert(*root); - } + result.insert(ResolveAssociations(symbol)); } } return result; diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -4945,18 +4945,19 @@ // type came from explicit type-spec } else if (!prev) { ApplyImplicitRules(symbol); - } else if (const Symbol * prevRoot{GetAssociationRoot(*prev)}) { + } else { + const Symbol &prevRoot{ResolveAssociations(*prev)}; // prev could be host- use- or construct-associated with another symbol - if (!prevRoot->has() && - !prevRoot->has()) { + if (!prevRoot.has() && + !prevRoot.has()) { Say2(name, "Index name '%s' conflicts with existing identifier"_err_en_US, *prev, "Previous declaration of '%s'"_en_US); return; } else { - if (const auto *type{prevRoot->GetType()}) { + if (const auto *type{prevRoot.GetType()}) { symbol.SetType(*type); } - if (prevRoot->IsObjectArray()) { + if (prevRoot.IsObjectArray()) { SayWithDecl(name, *prev, "Index variable '%s' is not scalar"_err_en_US); return; } @@ -5047,7 +5048,7 @@ } // Sets InDataStmt flag on a variable (or misidentified function) in a DATA -// statement so that the predicate IsInitialized(base symbol) will be true +// statement so that the predicate IsStaticallyInitialized() will be true // during semantic analysis before the symbol's initializer is constructed. bool ConstructVisitor::Pre(const parser::DataIDoObject &x) { std::visit( @@ -5090,11 +5091,10 @@ if (auto *elem{parser::Unwrap(mutableData)}) { if (const auto *name{std::get_if(&elem->base.u)}) { if (const Symbol * symbol{FindSymbol(*name)}) { - if (const Symbol * ultimate{GetAssociationRoot(*symbol)}) { - if (ultimate->has()) { - mutableData.u = elem->ConvertToStructureConstructor( - DerivedTypeSpec{name->source, *ultimate}); - } + const Symbol &ultimate{symbol->GetUltimate()}; + if (ultimate.has()) { + mutableData.u = elem->ConvertToStructureConstructor( + DerivedTypeSpec{name->source, ultimate}); } } } diff --git a/flang/lib/Semantics/semantics.cpp b/flang/lib/Semantics/semantics.cpp --- a/flang/lib/Semantics/semantics.cpp +++ b/flang/lib/Semantics/semantics.cpp @@ -264,13 +264,12 @@ void SemanticsContext::CheckIndexVarRedefine(const parser::CharBlock &location, const Symbol &variable, parser::MessageFixedText &&message) { - if (const Symbol * root{GetAssociationRoot(variable)}) { - auto it{activeIndexVars_.find(*root)}; - if (it != activeIndexVars_.end()) { - std::string kind{EnumToString(it->second.kind)}; - Say(location, std::move(message), kind, root->name()) - .Attach(it->second.location, "Enclosing %s construct"_en_US, kind); - } + const Symbol &symbol{ResolveAssociations(variable)}; + auto it{activeIndexVars_.find(symbol)}; + if (it != activeIndexVars_.end()) { + std::string kind{EnumToString(it->second.kind)}; + Say(location, std::move(message), kind, symbol.name()) + .Attach(it->second.location, "Enclosing %s construct"_en_US, kind); } } @@ -302,19 +301,16 @@ const parser::Name &name, IndexVarKind kind) { CheckIndexVarRedefine(name); if (const Symbol * indexVar{name.symbol}) { - if (const Symbol * root{GetAssociationRoot(*indexVar)}) { - activeIndexVars_.emplace(*root, IndexVarInfo{name.source, kind}); - } + activeIndexVars_.emplace( + ResolveAssociations(*indexVar), IndexVarInfo{name.source, kind}); } } void SemanticsContext::DeactivateIndexVar(const parser::Name &name) { if (Symbol * indexVar{name.symbol}) { - if (const Symbol * root{GetAssociationRoot(*indexVar)}) { - auto it{activeIndexVars_.find(*root)}; - if (it != activeIndexVars_.end() && it->second.location == name.source) { - activeIndexVars_.erase(it); - } + auto it{activeIndexVars_.find(ResolveAssociations(*indexVar))}; + if (it != activeIndexVars_.end() && it->second.location == name.source) { + activeIndexVars_.erase(it); } } } diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -510,14 +510,13 @@ IsBuiltinDerivedType(derivedTypeSpec, "lock_type"); } -bool IsOrContainsEventOrLockComponent(const Symbol &symbol) { - if (const Symbol * root{GetAssociationRoot(symbol)}) { - if (const auto *details{root->detailsIf()}) { - if (const DeclTypeSpec * type{details->type()}) { - if (const DerivedTypeSpec * derived{type->AsDerived()}) { - return IsEventTypeOrLockType(derived) || - FindEventOrLockPotentialComponent(*derived); - } +bool IsOrContainsEventOrLockComponent(const Symbol &original) { + const Symbol &symbol{ResolveAssociations(original)}; + if (const auto *details{symbol.detailsIf()}) { + if (const DeclTypeSpec * type{details->type()}) { + if (const DerivedTypeSpec * derived{type->AsDerived()}) { + return IsEventTypeOrLockType(derived) || + FindEventOrLockPotentialComponent(*derived); } } } @@ -541,35 +540,39 @@ } } -bool IsInitialized(const Symbol &symbol, bool ignoreDATAstatements, - const Symbol *derivedTypeSymbol) { +bool IsStaticallyInitialized(const Symbol &symbol, bool ignoreDATAstatements) { if (!ignoreDATAstatements && symbol.test(Symbol::Flag::InDataStmt)) { return true; } else if (IsNamedConstant(symbol)) { return false; } else if (const auto *object{symbol.detailsIf()}) { - if (object->init()) { - return true; - } else if (object->isDummy() || IsFunctionResult(symbol)) { - return false; - } else if (IsAllocatable(symbol)) { - return true; - } else if (!IsPointer(symbol) && object->type()) { - if (const auto *derived{object->type()->AsDerived()}) { - if (&derived->typeSymbol() == derivedTypeSymbol) { - // error recovery: avoid infinite recursion on invalid - // recursive usage of a derived type - } else if (derived->HasDefaultInitialization()) { - return true; - } - } - } + return object->init().has_value(); } else if (const auto *proc{symbol.detailsIf()}) { return proc->init().has_value(); } return false; } +bool IsInitialized(const Symbol &symbol, bool ignoreDATAstatements, + const Symbol *derivedTypeSymbol) { + if (IsStaticallyInitialized(symbol, ignoreDATAstatements) || + IsAllocatable(symbol)) { + return true; + } else if (IsNamedConstant(symbol) || IsFunctionResult(symbol) || + IsPointer(symbol)) { + return false; + } else if (const auto *object{symbol.detailsIf()}) { + if (!object->isDummy() && object->type()) { + const auto *derived{object->type()->AsDerived()}; + // error recovery: avoid infinite recursion on invalid + // recursive usage of a derived type + return derived && &derived->typeSymbol() != derivedTypeSymbol && + derived->HasDefaultInitialization(); + } + } + return false; +} + bool HasIntrinsicTypeName(const Symbol &symbol) { std::string name{symbol.name().ToString()}; if (name == "doubleprecision") { @@ -730,12 +733,7 @@ const Symbol *IsExternalInPureContext( const Symbol &symbol, const Scope &scope) { if (const auto *pureProc{FindPureProcedureContaining(scope)}) { - if (const Symbol * root{GetAssociationRoot(symbol)}) { - if (const Symbol * - visible{FindExternallyVisibleObject(*root, *pureProc)}) { - return visible; - } - } + return FindExternallyVisibleObject(symbol.GetUltimate(), *pureProc); } return nullptr; } @@ -753,16 +751,15 @@ }); } -bool IsOrContainsPolymorphicComponent(const Symbol &symbol) { - if (const Symbol * root{GetAssociationRoot(symbol)}) { - if (const auto *details{root->detailsIf()}) { - if (const DeclTypeSpec * type{details->type()}) { - if (type->IsPolymorphic()) { - return true; - } - if (const DerivedTypeSpec * derived{type->AsDerived()}) { - return (bool)FindPolymorphicPotentialComponent(*derived); - } +bool IsOrContainsPolymorphicComponent(const Symbol &original) { + const Symbol &symbol{ResolveAssociations(original)}; + if (const auto *details{symbol.detailsIf()}) { + if (const DeclTypeSpec * type{details->type()}) { + if (type->IsPolymorphic()) { + return true; + } + if (const DerivedTypeSpec * derived{type->AsDerived()}) { + return (bool)FindPolymorphicPotentialComponent(*derived); } } } @@ -775,20 +772,20 @@ // C1101 and C1158 std::optional WhyNotModifiable( - const Symbol &symbol, const Scope &scope) { - const Symbol *root{GetAssociationRoot(symbol)}; - if (!root) { + const Symbol &original, const Scope &scope) { + const Symbol &symbol{GetAssociationRoot(original)}; + if (symbol.has()) { return "'%s' is construct associated with an expression"_en_US; - } else if (InProtectedContext(*root, scope)) { + } else if (InProtectedContext(symbol, scope)) { return "'%s' is protected in this scope"_en_US; - } else if (IsExternalInPureContext(*root, scope)) { + } else if (IsExternalInPureContext(symbol, scope)) { return "'%s' is externally visible and referenced in a pure" " procedure"_en_US; - } else if (IsOrContainsEventOrLockComponent(*root)) { + } else if (IsOrContainsEventOrLockComponent(symbol)) { return "'%s' is an entity with either an EVENT_TYPE or LOCK_TYPE"_en_US; - } else if (IsIntentIn(*root)) { + } else if (IsIntentIn(symbol)) { return "'%s' is an INTENT(IN) dummy argument"_en_US; - } else if (!IsVariableName(*root)) { + } else if (!IsVariableName(symbol)) { return "'%s' is not a variable"_en_US; } else { return std::nullopt; @@ -940,10 +937,8 @@ bool HasCoarray(const parser::Expr &expression) { if (const auto *expr{GetExpr(expression)}) { for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) { - if (const Symbol * root{GetAssociationRoot(symbol)}) { - if (IsCoarray(*root)) { - return true; - } + if (IsCoarray(GetAssociationRoot(symbol))) { + return true; } } } diff --git a/flang/test/Semantics/data04.f90 b/flang/test/Semantics/data04.f90 --- a/flang/test/Semantics/data04.f90 +++ b/flang/test/Semantics/data04.f90 @@ -62,7 +62,6 @@ end type type(large) largeNumber type(large), allocatable :: allocatableLarge - !ERROR: An automatic variable or component must not be initialized type(large) :: largeNumberArray(i) type(large) :: largeArray(5) character :: name(i) diff --git a/flang/test/Semantics/resolve44.f90 b/flang/test/Semantics/resolve44.f90 --- a/flang/test/Semantics/resolve44.f90 +++ b/flang/test/Semantics/resolve44.f90 @@ -20,12 +20,10 @@ integer, kind :: kind integer, len :: len !ERROR: Recursive use of the derived type requires POINTER or ALLOCATABLE - !ERROR: An automatic variable or component must not be initialized type(recursive2(kind,len)) :: bad1 type(recursive2(kind,len)), pointer :: ok1 type(recursive2(kind,len)), allocatable :: ok2 !ERROR: Recursive use of the derived type requires POINTER or ALLOCATABLE - !ERROR: An automatic variable or component must not be initialized !ERROR: CLASS entity 'bad2' must be a dummy argument or have ALLOCATABLE or POINTER attribute class(recursive2(kind,len)) :: bad2 class(recursive2(kind,len)), pointer :: ok3 diff --git a/flang/test/Semantics/shape.f90 b/flang/test/Semantics/shape.f90 --- a/flang/test/Semantics/shape.f90 +++ b/flang/test/Semantics/shape.f90 @@ -6,18 +6,26 @@ integer :: arrayDummy(:) integer, allocatable :: arrayDeferred(:) integer :: arrayLocal(2) = [88, 99] + !ERROR: Dimension 1 of left operand has extent 1, but right operand has extent 0 + !ERROR: Dimension 1 of left operand has extent 1, but right operand has extent 0 if (all(shape(arrayDummy)==shape(8))) then print *, "hello" end if + !ERROR: Dimension 1 of left operand has extent 0, but right operand has extent 1 + !ERROR: Dimension 1 of left operand has extent 0, but right operand has extent 1 if (all(shape(27)==shape(arrayDummy))) then print *, "hello" end if if (all(64==shape(arrayDummy))) then print *, "hello" end if + !ERROR: Dimension 1 of left operand has extent 1, but right operand has extent 0 + !ERROR: Dimension 1 of left operand has extent 1, but right operand has extent 0 if (all(shape(arrayDeferred)==shape(8))) then print *, "hello" end if + !ERROR: Dimension 1 of left operand has extent 0, but right operand has extent 1 + !ERROR: Dimension 1 of left operand has extent 0, but right operand has extent 1 if (all(shape(27)==shape(arrayDeferred))) then print *, "hello" end if