[flang] Fold substrings of character constants (arrays and all kinds)

Substring::Fold now folds substrings of named constants (scalars and
arrays) through the new Constant<Character>::Substring(), and of
character literals of every kind (1, 2, and 4).  Bound diagnostics are
issued only when folding fails and only for nonempty ranges, so valid
zero-length substrings such as c(0:-1) are not reported.

diff --git a/flang/include/flang/Evaluate/constant.h b/flang/include/flang/Evaluate/constant.h
--- a/flang/include/flang/Evaluate/constant.h
+++ b/flang/include/flang/Evaluate/constant.h
@@ -183,6 +183,11 @@
   // Apply subscripts, if any.
   Scalar At(const ConstantSubscripts &) const;
 
+  // Extract the substring (lo:hi) of every element.  An empty range
+  // (lo > hi) yields zero-length results; otherwise, returns nullopt when
+  // lo < 1 or hi exceeds the character length.
+  std::optional<Constant> Substring(ConstantSubscript, ConstantSubscript) const;
+
   Constant Reshape(ConstantSubscripts &&) const;
   llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
   static constexpr DynamicType GetType() {
diff --git a/flang/lib/Evaluate/constant.cpp b/flang/lib/Evaluate/constant.cpp
--- a/flang/lib/Evaluate/constant.cpp
+++ b/flang/lib/Evaluate/constant.cpp
@@ -221,6 +221,29 @@
   return values_.substr(offset * length_, length_);
 }
 
+template <int KIND>
+auto Constant<Type<TypeCategory::Character, KIND>>::Substring(
+    ConstantSubscript lo, ConstantSubscript hi) const
+    -> std::optional<Constant> {
+  std::vector<Element> elements;
+  ConstantSubscript n{GetSize(shape())};
+  elements.reserve(n);
+  ConstantSubscript newLength{0};
+  if (lo > hi) { // zero-length results
+    while (n-- > 0) {
+      elements.emplace_back(); // ""
+    }
+  } else if (lo < 1 || hi > length_) {
+    return std::nullopt;
+  } else {
+    newLength = hi - lo + 1;
+    for (ConstantSubscripts at{lbounds()}; n-- > 0; IncrementSubscripts(at)) {
+      elements.emplace_back(At(at).substr(lo - 1, newLength));
+    }
+  }
+  return Constant{newLength, std::move(elements), ConstantSubscripts{shape()}};
+}
+
 template <int KIND>
 auto Constant<Type<TypeCategory::Character, KIND>>::Reshape(
     ConstantSubscripts &&dims) const -> Constant {
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -331,8 +331,8 @@
     if (auto *substring{common::Unwrap<Substring>(designator.u)}) {
       if (std::optional<Expr<SomeCharacter>> folded{
               substring->Fold(context_)}) {
-        if (auto value{GetScalarConstantValue<T>(*folded)}) {
-          return Expr<T>{*value};
+        if (auto *specific{std::get_if<Expr<T>>(&folded->u)}) {
+          return std::move(*specific);
         }
       }
       if (auto length{ToInt64(Fold(context_, substring->LEN()))}) {
diff --git a/flang/lib/Evaluate/variable.cpp b/flang/lib/Evaluate/variable.cpp
--- a/flang/lib/Evaluate/variable.cpp
+++ b/flang/lib/Evaluate/variable.cpp
@@ -163,81 +163,98 @@
 }
 
 std::optional<Expr<SomeCharacter>> Substring::Fold(FoldingContext &context) {
+  // Both bounds must fold to constants.  Bound errors are reported below
+  // only for nonempty ranges: c(0:-1) is a valid zero-length substring.
+  if (!upper_) {
+    upper_ = upper();
+    if (!upper_) {
+      return std::nullopt;
+    }
+  }
+  upper_.value() = evaluate::Fold(context, std::move(upper_.value().value()));
+  std::optional<ConstantSubscript> ubi{ToInt64(upper_.value().value())};
+  if (!ubi) {
+    return std::nullopt;
+  }
   if (!lower_) {
     lower_ = AsExpr(Constant<SubscriptInteger>{1});
   }
   lower_.value() = evaluate::Fold(context, std::move(lower_.value().value()));
   std::optional<ConstantSubscript> lbi{ToInt64(lower_.value().value())};
-  if (lbi && *lbi < 1) {
-    context.messages().Say(
-        "Lower bound (%jd) on substring is less than one"_en_US, *lbi);
+  if (!lbi) {
+    return std::nullopt;
+  }
+  if (*lbi > *ubi) { // empty result; canonicalize
     *lbi = 1;
-    lower_ = AsExpr(Constant<SubscriptInteger>{1});
+    *ubi = 0;
+    lower_ = AsExpr(Constant<SubscriptInteger>{*lbi});
+    upper_ = AsExpr(Constant<SubscriptInteger>{*ubi});
   }
-  if (!upper_) {
-    upper_ = upper();
-    if (!upper_) {
-      return std::nullopt;
+  std::optional<ConstantSubscript> length;
+  std::optional<Expr<SomeCharacter>> strings; // a Constant<Character>
+  if (const auto *literal{std::get_if<StaticDataObject::Pointer>(&parent_)}) {
+    // data() holds raw bytes; the substring bounds count characters.
+    length = (*literal)->data().size() / (*literal)->itemBytes();
+    if (auto str{(*literal)->AsString()}) {
+      strings =
+          Expr<SomeCharacter>(Expr<Ascii>(Constant<Ascii>{std::move(*str)}));
+    } else if (auto str16{(*literal)->AsU16String()}) {
+      strings = AsCategoryExpr(
+          Constant<Type<TypeCategory::Character, 2>>{std::move(*str16)});
+    } else if (auto str32{(*literal)->AsU32String()}) {
+      strings = AsCategoryExpr(
+          Constant<Type<TypeCategory::Character, 4>>{std::move(*str32)});
     }
-  }
-  upper_.value() = evaluate::Fold(context, std::move(upper_.value().value()));
-  if (std::optional<ConstantSubscript> ubi{ToInt64(upper_.value().value())}) {
-    auto *literal{std::get_if<StaticDataObject::Pointer>(&parent_)};
-    std::optional<ConstantSubscript> length;
-    if (literal) {
-      length = (*literal)->data().size();
-    } else if (const Symbol * symbol{GetLastSymbol()}) {
-      if (const semantics::DeclTypeSpec * type{symbol->GetType()}) {
-        if (type->category() == semantics::DeclTypeSpec::Character) {
-          length = ToInt64(type->characterTypeSpec().length().GetExplicit());
+  } else if (const auto *dataRef{std::get_if<DataRef>(&parent_)}) {
+    // The declared length (when explicit) is needed only for diagnostics.
+    if (const Symbol *symbol{GetLastSymbol()}) {
+      if (const semantics::DeclTypeSpec *type{symbol->GetType()}) {
+        if (type->category() == semantics::DeclTypeSpec::Character) {
+          length = ToInt64(type->characterTypeSpec().length().GetExplicit());
+        }
+      }
+    }
+    if (auto expr{AsGenericExpr(DataRef{*dataRef})}) {
+      auto folded{evaluate::Fold(context, std::move(*expr))};
+      if (IsActuallyConstant(folded)) {
+        if (const auto *value{UnwrapExpr<Expr<SomeCharacter>>(folded)}) {
+          strings = *value;
         }
       }
     }
-    if (*ubi < 1 || (lbi && *ubi < *lbi)) {
-      // Zero-length string: canonicalize
-      *lbi = 1, *ubi = 0;
-      lower_ = AsExpr(Constant<SubscriptInteger>{*lbi});
-      upper_ = AsExpr(Constant<SubscriptInteger>{*ubi});
-    } else if (length && *ubi > *length) {
-      context.messages().Say("Upper bound (%jd) on substring is greater "
-                             "than character length (%jd)"_en_US,
-          *ubi, *length);
-      *ubi = *length;
-    }
-    if (lbi && literal) {
-      auto newStaticData{StaticDataObject::Create()};
-      auto items{0}; // If the lower bound is greater, the length is 0
-      if (*ubi >= *lbi) {
-        items = *ubi - *lbi + 1;
-      }
-      auto width{(*literal)->itemBytes()};
-      auto bytes{items * width};
-      auto startByte{(*lbi - 1) * width};
-      const auto *from{&(*literal)->data()[0] + startByte};
-      for (auto j{0}; j < bytes; ++j) {
-        newStaticData->data().push_back(from[j]);
-      }
-      parent_ = newStaticData;
+  }
+  std::optional<Expr<SomeCharacter>> result;
+  if (strings) {
+    result = std::visit(
+        [&](const auto &expr) -> std::optional<Expr<SomeCharacter>> {
+          using Type = typename std::decay_t<decltype(expr)>::Result;
+          if (const auto *cc{std::get_if<Constant<Type>>(&expr.u)}) {
+            if (auto substr{cc->Substring(*lbi, *ubi)}) {
+              return Expr<SomeCharacter>{Expr<Type>{std::move(*substr)}};
+            }
+          }
+          return std::nullopt;
+        },
+        strings->u);
+  }
+  if (!result) { // error cases
+    if (*lbi < 1) {
+      context.messages().Say(
+          "Lower bound (%jd) on substring is less than one"_en_US,
+          static_cast<std::intmax_t>(*lbi));
+      *lbi = 1;
       lower_ = AsExpr(Constant<SubscriptInteger>{1});
-      ConstantSubscript length = newStaticData->data().size();
-      upper_ = AsExpr(Constant<SubscriptInteger>{length});
-      switch (width) {
-      case 1:
-        return {
-            AsCategoryExpr(AsExpr(Constant<Type<TypeCategory::Character, 1>>{
-                *newStaticData->AsString()}))};
-      case 2:
-        return {AsCategoryExpr(Constant<Type<TypeCategory::Character, 2>>{
-            *newStaticData->AsU16String()})};
-      case 4:
-        return {AsCategoryExpr(Constant<Type<TypeCategory::Character, 4>>{
-            *newStaticData->AsU32String()})};
-      default:
-        CRASH_NO_CASE;
-      }
+    }
+    if (length && *ubi > *length) {
+      context.messages().Say(
+          "Upper bound (%jd) on substring is greater than character length (%jd)"_en_US,
+          static_cast<std::intmax_t>(*ubi),
+          static_cast<std::intmax_t>(*length));
+      *ubi = *length;
+      upper_ = AsExpr(Constant<SubscriptInteger>{*ubi});
     }
   }
-  return std::nullopt;
+  return result;
 }
 
 DescriptorInquiry::DescriptorInquiry(
diff --git a/flang/test/Evaluate/fold-substr.f90 b/flang/test/Evaluate/fold-substr.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Evaluate/fold-substr.f90
@@ -0,0 +1,19 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Test folding of substrings
+module m
+  logical, parameter :: test_01a = "abc"(1:3) == "abc"
+  logical, parameter :: test_01b = len("abc"(1:3)) == 3
+  logical, parameter :: test_02a = "abc"(-1:-2) == ""
+  logical, parameter :: test_02b = len("abc"(-1:-2)) == 0
+  logical, parameter :: test_03a = "abc"(9999:4) == ""
+  logical, parameter :: test_03b = len("abc"(9999:4)) == 0
+  character(4), parameter :: ca(3) = ["abcd", "efgh", "ijkl"]
+  logical, parameter :: test_04a = ca(2)(2:4) == "fgh"
+  logical, parameter :: test_04b = len(ca(2)(2:4)) == 3
+  logical, parameter :: test_05a = all(ca(:)(2:4) == ["bcd", "fgh", "jkl"])
+  logical, parameter :: test_05b = len(ca(:)(2:4)) == 3
+  logical, parameter :: test_06a = ca(1)(1:2)//ca(2)(2:3)//ca(3)(3:4) == "abfgkl"
+  logical, parameter :: test_06b = len(ca(1)(1:2)//ca(2)(2:3)//ca(3)(3:4)) == 6
+  logical, parameter :: test_07a = 4_"abcd"(2:3) == 4_"bc"
+  logical, parameter :: test_07b = len(4_"abcd"(2:3)) == 2
+end module