diff --git a/flang/include/flang/Evaluate/constant.h b/flang/include/flang/Evaluate/constant.h
--- a/flang/include/flang/Evaluate/constant.h
+++ b/flang/include/flang/Evaluate/constant.h
@@ -64,6 +64,8 @@
   ~ConstantBounds();
   const ConstantSubscripts &shape() const { return shape_; }
   const ConstantSubscripts &lbounds() const { return lbounds_; }
+  // Upper bounds of all dimensions, or of dimension *dim only when present.
+  ConstantSubscripts ComputeUbounds(std::optional<int> dim) const;
   void set_lbounds(ConstantSubscripts &&);
   void SetLowerBoundsToOne();
   int Rank() const { return GetRank(shape_); }
diff --git a/flang/lib/Evaluate/constant.cpp b/flang/lib/Evaluate/constant.cpp
--- a/flang/lib/Evaluate/constant.cpp
+++ b/flang/lib/Evaluate/constant.cpp
@@ -32,6 +32,24 @@
   }
 }
 
+// Returns the upper bound of each dimension, or of dimension *dim only when
+// dim is present; each upper bound is lower bound + extent - 1.
+ConstantSubscripts ConstantBounds::ComputeUbounds(
+    std::optional<int> dim) const {
+  // Evaluated as lbound + (extent - 1) so that no intermediate sum can
+  // overflow when the resulting upper bound itself is representable.
+  if (dim) {
+    CHECK(*dim >= 0 && *dim < Rank());
+    return {lbounds()[*dim] + (shape()[*dim] - 1)};
+  } else {
+    ConstantSubscripts ubounds(Rank());
+    for (int i{0}; i < Rank(); ++i) {
+      ubounds[i] = lbounds()[i] + (shape()[i] - 1);
+    }
+    return ubounds;
+  }
+}
+
 void ConstantBounds::SetLowerBoundsToOne() {
   for (auto &n : lbounds_) {
     n = 1;
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -12,21 +12,28 @@
 
 namespace Fortran::evaluate {
 
+namespace {
+// Tag types for the BOUND parameter of GetConstantArrayBoundHelper
+class LowerBound {};
+class UpperBound {};
+
-// Class to retrieve the constant lower bound of an expression which is an
-// array that devolves to a type of Constant<T>
-class GetConstantArrayLboundHelper {
+// Class to retrieve the constant lower or upper bounds (selected by BOUND) of
+// an expression which is an array that devolves to a type of Constant<T>
+template <typename BOUND> class GetConstantArrayBoundHelper {
 public:
-  GetConstantArrayLboundHelper(std::optional<int> dim)
+  explicit GetConstantArrayBoundHelper(std::optional<int> dim)
       : dim_{dim} {}
 
-  template <typename T> ConstantSubscripts GetLbound(const T &) {
+  template <typename T> ConstantSubscripts Get(const T &) {
     // The method is needed for template expansion, but we should never get
     // here in practice.
     CHECK(false);
     return {0};
   }
 
-  template <typename T> ConstantSubscripts GetLbound(const Constant<T> &x) {
+  template <typename T, typename U = BOUND,
+      std::enable_if_t<std::is_same_v<U, LowerBound>, bool> = true>
+  ConstantSubscripts Get(const Constant<T> &x) {
     // Return the lower bound
     if (dim_) {
       return {x.lbounds().at(*dim_)};
@@ -35,14 +42,25 @@
     }
   }
 
+  template <typename T, typename U = BOUND,
+      std::enable_if_t<std::is_same_v<U, UpperBound>, bool> = true>
+  ConstantSubscripts Get(const Constant<T> &x) {
+    // Return the upper bound.
+    // NOTE(review): a zero-extent dimension yields lbound - 1 here, whereas
+    // Fortran defines its UBOUND as 0 -- confirm lbounds are normalized.
+    return x.ComputeUbounds(dim_);
+  }
+
-  template <typename T> ConstantSubscripts GetLbound(const Parentheses<T> &x) {
-    // Strip off the parentheses
-    return GetLbound(x.left());
+  template <typename T> ConstantSubscripts Get(const Parentheses<T> &x) {
+    // Strip off the parentheses.
+    // NOTE(review): (x) is not a whole array, so Fortran defines its bounds
+    // as 1 and the extent rather than x's declared bounds -- confirm.
+    return Get(x.left());
   }
 
-  template <typename T> ConstantSubscripts GetLbound(const Expr<T> &x) {
-    // recurse through Expr<T>'a until we hit a constant
-    return std::visit([&](const auto &inner) { return GetLbound(inner); },
+  template <typename T> ConstantSubscripts Get(const Expr<T> &x) {
+    // Recurse through Expr<T>'s until we hit a constant.
+    return std::visit([&](const auto &inner) { return Get(inner); },
         //      [&](const auto &) { return 0; },
         x.u);
   }
@@ -50,6 +68,7 @@
 private:
   std::optional<int> dim_;
 };
+} // namespace
 
 template <int KIND>
 Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
@@ -95,7 +114,7 @@
     }
     if (IsActuallyConstant(*array)) {
       const ConstantSubscripts bounds =
-          GetConstantArrayLboundHelper{dim}.GetLbound(*array);
+          GetConstantArrayBoundHelper<LowerBound>{dim}.Get(*array);
       return Expr<T>{PackageConstant<T>(std::move(bounds), dim.has_value())};
     }
     if (lowerBoundsAreOne) {
@@ -160,6 +179,11 @@
           takeBoundsFromShape = symbol.Rank() == 0; // UBOUND(array%component)
         }
       }
+    if (IsActuallyConstant(*array)) {
+      const ConstantSubscripts ubounds =
+          GetConstantArrayBoundHelper<UpperBound>{dim}.Get(*array);
+      return Expr<T>{PackageConstant<T>(std::move(ubounds), dim.has_value())};
+    }
     if (takeBoundsFromShape) {
       if (auto shape{GetContextFreeShape(context, *array)}) {
         if (dim) {
diff --git a/flang/test/Evaluate/folding08.f90 b/flang/test/Evaluate/folding08.f90
--- a/flang/test/Evaluate/folding08.f90
+++ b/flang/test/Evaluate/folding08.f90
@@ -95,4 +95,24 @@
     integer, parameter :: lba4(*) = lbound(a4)
     logical, parameter :: test_lba4 = all(lba4 == [2, 1, 4])
   end subroutine
+  subroutine test4_ubound_parameter
+    ! Test ubound with constant arrays
+    integer, parameter :: a1(1) = 0
+    integer, parameter :: uba1(*) = ubound(a1)
+    logical, parameter :: test_uba1 = all(uba1 == [1])
+
+    integer, parameter :: a2(0:0) = 0
+    integer, parameter :: uba2(*) = ubound(a2)
+    logical, parameter :: test_uba2 = all(uba2 == [0])
+
+    integer, parameter :: a3(2:4,4:6) = 0
+    integer, parameter :: uba3(*) = ubound(a3)
+    logical, parameter :: test_uba3 = all(uba3 == [4, 6])
+    logical, parameter :: test_uba3_dim1 = ubound(a3, 1) == 4
+    logical, parameter :: test_uba3_dim2 = ubound(a3, 2) == 6
+
+    integer, parameter :: a4(2:4,1,4:6) = 0
+    integer, parameter :: uba4(*) = ubound(a4)
+    logical, parameter :: test_uba4 = all(uba4 == [4, 1, 6])
+  end subroutine
 end