// Patch to flang/lib/Evaluate/fold-implementation.h.
// It has seven hunks, at old lines 1310, 1330, 1352, 1369, 1404, 1421 and 1453.
//
// What the hunks below change:
//  * FromArrayConstructor
//    - Parameter: it now takes `const Shape &shape`. It previously took the
//      std::optional that each call site built with
//      AsConstantExtents(context, shape).
//    - Return type: std::optional of the result expression instead of a
//      plain expression.
//    - It now calls AsConstantExtents itself.
//    - It folds the array constructor only when that call yields constant
//      extents.
//    - It returns:
//        - the folded constant, reshaped to those extents, when the folded
//          result unwraps to a constant value;
//        - otherwise, the folded (non-constant) result unchanged, when the
//          constant shape has exactly one extent and GetShape of the folded
//          result also converts to exactly one constant extent with the
//          same value;
//        - std::nullopt in every other case, including a non-constant
//          shape. Previously the folded, un-reshaped result was returned
//          in those cases.
//  * MapOperation overloads (unary, array*array, array*scalar, scalar*array)
//    - They now return std::optional of the result expression. The three
//      binary overloads spell this as a trailing return type.
//    - They pass `shape` straight through to FromArrayConstructor instead of
//      calling AsConstantExtents at each return site.
//
// NOTE(review): the unchanged context comment above FromArrayConstructor
// still says it "returns the resulting wrapped array constructor or
// constant array value". It can now also return std::nullopt, so that
// comment should be updated to match.
//
// NOTE(review): the callers of MapOperation are not visible in the hunks
// shown here. Confirm they all handle the new std::nullopt result.
//
// NOTE(review): this copy of the patch has lost its line breaks and all
// `<...>` template argument lists. For example, "template" appears with no
// parameter list and "Expr" with no type argument. It therefore cannot be
// applied with `git apply` or `patch` as-is. Regenerate it from the
// original commit before use.
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -1310,15 +1310,28 @@ // into an Expr, folds it, and returns the resulting wrapped // array constructor or constant array value. template -Expr FromArrayConstructor(FoldingContext &context, - ArrayConstructor &&values, std::optional &&shape) { - Expr result{Fold(context, Expr{std::move(values)})}; - if (shape) { +std::optional> FromArrayConstructor( + FoldingContext &context, ArrayConstructor &&values, const Shape &shape) { + if (auto constShape{AsConstantExtents(context, shape)}) { + Expr result{Fold(context, Expr{std::move(values)})}; if (auto *constant{UnwrapConstantValue(result)}) { - return Expr{constant->Reshape(std::move(*shape))}; + // Elements and shape are both constant. + return Expr{constant->Reshape(std::move(*constShape))}; + } + if (constShape->size() == 1) { + if (auto elements{GetShape(context, result)}) { + if (auto constElements{AsConstantExtents(context, *elements)}) { + if (constElements->size() == 1 && + constElements->at(0) == constShape->at(0)) { + // Elements are not constant, but array constructor has + // the right known shape and can be simply returned as is. 
+ return std::move(result); + } + } + } } } - return result; + return std::nullopt; } // MapOperation is a utility for various specializations of ApplyElementwise() @@ -1330,7 +1343,7 @@ // Unary case template -Expr MapOperation(FoldingContext &context, +std::optional> MapOperation(FoldingContext &context, std::function(Expr &&)> &&f, const Shape &shape, Expr &&values) { ArrayConstructor result{values}; @@ -1352,8 +1365,7 @@ result.Push(Fold(context, f(std::move(scalar)))); } } - return FromArrayConstructor( - context, std::move(result), AsConstantExtents(context, shape)); + return FromArrayConstructor(context, std::move(result), shape); } template @@ -1369,10 +1381,11 @@ // array * array case template -Expr MapOperation(FoldingContext &context, +auto MapOperation(FoldingContext &context, std::function(Expr &&, Expr &&)> &&f, const Shape &shape, std::optional> &&length, - Expr &&leftValues, Expr &&rightValues) { + Expr &&leftValues, Expr &&rightValues) + -> std::optional> { auto result{ArrayConstructorFromMold(leftValues, std::move(length))}; auto &leftArrConst{std::get>(leftValues.u)}; if constexpr (common::HasMember) { @@ -1404,16 +1417,16 @@ ++rightIter; } } - return FromArrayConstructor( - context, std::move(result), AsConstantExtents(context, shape)); + return FromArrayConstructor(context, std::move(result), shape); } // array * scalar case template -Expr MapOperation(FoldingContext &context, +auto MapOperation(FoldingContext &context, std::function(Expr &&, Expr &&)> &&f, const Shape &shape, std::optional> &&length, - Expr &&leftValues, const Expr &rightScalar) { + Expr &&leftValues, const Expr &rightScalar) + -> std::optional> { auto result{ArrayConstructorFromMold(leftValues, std::move(length))}; auto &leftArrConst{std::get>(leftValues.u)}; for (auto &leftValue : leftArrConst) { @@ -1421,16 +1434,16 @@ result.Push( Fold(context, f(std::move(leftScalar), Expr{rightScalar}))); } - return FromArrayConstructor( - context, std::move(result), 
AsConstantExtents(context, shape)); + return FromArrayConstructor(context, std::move(result), shape); } // scalar * array case template -Expr MapOperation(FoldingContext &context, +auto MapOperation(FoldingContext &context, std::function(Expr &&, Expr &&)> &&f, const Shape &shape, std::optional> &&length, - const Expr &leftScalar, Expr &&rightValues) { + const Expr &leftScalar, Expr &&rightValues) + -> std::optional> { auto result{ArrayConstructorFromMold(leftScalar, std::move(length))}; if constexpr (common::HasMember) { common::visit( @@ -1453,8 +1466,7 @@ Fold(context, f(Expr{leftScalar}, std::move(rightScalar)))); } } - return FromArrayConstructor( - context, std::move(result), AsConstantExtents(context, shape)); + return FromArrayConstructor(context, std::move(result), shape); } template