diff --git a/flang/include/flang/Evaluate/characteristics.h b/flang/include/flang/Evaluate/characteristics.h --- a/flang/include/flang/Evaluate/characteristics.h +++ b/flang/include/flang/Evaluate/characteristics.h @@ -16,6 +16,7 @@ #include "common.h" #include "expression.h" #include "shape.h" +#include "tools.h" #include "type.h" #include "flang/Common/Fortran-features.h" #include "flang/Common/Fortran.h" diff --git a/flang/include/flang/Evaluate/shape.h b/flang/include/flang/Evaluate/shape.h --- a/flang/include/flang/Evaluate/shape.h +++ b/flang/include/flang/Evaluate/shape.h @@ -13,11 +13,9 @@ #define FORTRAN_EVALUATE_SHAPE_H_ #include "expression.h" -#include "fold.h" #include "traverse.h" #include "variable.h" #include "flang/Common/indirection.h" -#include "flang/Evaluate/tools.h" #include "flang/Evaluate/type.h" #include #include @@ -201,12 +199,7 @@ ExtentExpr result{0}; for (const auto &value : values) { if (MaybeExtentExpr n{GetArrayConstructorValueExtent(value)}) { - result = std::move(result) + std::move(*n); - if (context_) { - // Fold during expression creation to avoid creating an expression so - // large we can't evalute it without overflowing the stack. 
- result = Fold(*context_, std::move(result)); - } + AccumulateExtent(result, std::move(*n)); } else { return std::nullopt; } @@ -214,6 +207,9 @@ return result; } + // Add an extent to another, with folding + void AccumulateExtent(ExtentExpr &, ExtentExpr &&) const; + FoldingContext *context_{nullptr}; bool useResultSymbolShape_{true}; }; diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -15,6 +15,8 @@ #include "flang/Common/unwrap.h" #include "flang/Evaluate/constant.h" #include "flang/Evaluate/expression.h" +#include "flang/Evaluate/shape.h" +#include "flang/Evaluate/type.h" #include "flang/Parser/message.h" #include "flang/Semantics/attr.h" #include "flang/Semantics/symbol.h" @@ -1026,8 +1028,14 @@ }; template -bool IsExpandableScalar(const Expr &expr, bool admitPureCall = false) { - return !UnexpandabilityFindingVisitor{admitPureCall}(expr); +bool IsExpandableScalar(const Expr &expr, FoldingContext &context, + const Shape &shape, bool admitPureCall = false) { + if (UnexpandabilityFindingVisitor{admitPureCall}(expr)) { + auto extents{AsConstantExtents(context, shape)}; + return extents && GetSize(*extents) == 1; + } else { + return true; + } } // Common handling for procedure pointer compatibility of left- and right-hand diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -1539,17 +1539,19 @@ std::move(resultLength), std::move(*left), std::move(*right)); } } - } else if (IsExpandableScalar(rightExpr)) { + } else if (IsExpandableScalar(rightExpr, context, *leftShape)) { return MapOperation(context, std::move(f), *leftShape, std::move(resultLength), std::move(*left), rightExpr); } } } - } else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) { - if (std::optional shape{GetShape(context, 
rightExpr)}) { - if (auto right{AsFlatArrayConstructor(rightExpr)}) { - return MapOperation(context, std::move(f), *shape, - std::move(resultLength), leftExpr, std::move(*right)); + } else if (rightExpr.Rank() > 0) { + if (std::optional rightShape{GetShape(context, rightExpr)}) { + if (IsExpandableScalar(leftExpr, context, *rightShape)) { + if (auto right{AsFlatArrayConstructor(rightExpr)}) { + return MapOperation(context, std::move(f), *rightShape, + std::move(resultLength), leftExpr, std::move(*right)); + } } } } diff --git a/flang/lib/Evaluate/shape.cpp b/flang/lib/Evaluate/shape.cpp --- a/flang/lib/Evaluate/shape.cpp +++ b/flang/lib/Evaluate/shape.cpp @@ -1021,6 +1021,16 @@ return Shape(static_cast(call.Rank()), MaybeExtentExpr{}); } +void GetShapeHelper::AccumulateExtent( + ExtentExpr &result, ExtentExpr &&n) const { + result = std::move(result) + std::move(n); + if (context_) { + // Fold during expression creation to avoid creating an expression so + // large we can't evaluate it without overflowing the stack. + result = Fold(*context_, std::move(result)); + } +} + // Check conformance of the passed shapes. std::optional CheckConformance(parser::ContextualMessages &messages, const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags, diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -1836,7 +1836,8 @@ "component", "value")}; if (checked && *checked && GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0 && - !IsExpandableScalar(*converted, true /*admit PURE call*/)) { + !IsExpandableScalar(*converted, GetFoldingContext(), + *componentShape, true /*admit PURE call*/)) { AttachDeclaration( Say(expr.source, "Scalar value cannot be expanded to shape of array component '%s'"_err_en_US,