diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -992,6 +992,23 @@ std::optional lbounds_; }; +// Given a collection of element values, package them as a Constant. +// If the type is Character or a derived type, take the length or type +// (resp.) from another Constant. +template +Constant PackageConstant(std::vector> &&elements, + const Constant &reference, const ConstantSubscripts &shape) { + if constexpr (T::category == TypeCategory::Character) { + return Constant{ + reference.LEN(), std::move(elements), ConstantSubscripts{shape}}; + } else if constexpr (T::category == TypeCategory::Derived) { + return Constant{reference.GetType().GetDerivedTypeSpec(), + std::move(elements), ConstantSubscripts{shape}}; + } else { + return Constant{std::move(elements), ConstantSubscripts{shape}}; + } +} + } // namespace Fortran::evaluate namespace Fortran::semantics { diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp --- a/flang/lib/Evaluate/fold-character.cpp +++ b/flang/lib/Evaluate/fold-character.cpp @@ -102,8 +102,7 @@ CharacterUtils::TRIM(std::get>(*scalar))}}; } } - // TODO: cshift, eoshift, maxloc, minloc, pack, spread, transfer, - // transpose, unpack + // TODO: findloc, maxloc, minloc, transfer return Expr{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp --- a/flang/lib/Evaluate/fold-complex.cpp +++ b/flang/lib/Evaluate/fold-complex.cpp @@ -60,8 +60,7 @@ } else if (name == "sum") { return FoldSum(context, std::move(funcRef)); } - // TODO: cshift, dot_product, eoshift, matmul, pack, spread, transfer, - // transpose, unpack + // TODO: dot_product, matmul, transfer return Expr{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h --- a/flang/lib/Evaluate/fold-implementation.h +++ 
b/flang/lib/Evaluate/fold-implementation.h @@ -60,7 +60,9 @@ std::optional> Folding(ArrayRef &); Expr Folding(Designator &&); Constant *Folding(std::optional &); - Expr Reshape(FunctionRef &&); + + Expr CSHIFT(FunctionRef &&); + Expr RESHAPE(FunctionRef &&); private: FoldingContext &context_; @@ -546,7 +548,90 @@ ActualArguments{std::move(funcRef.arguments())}}}; } -template Expr Folder::Reshape(FunctionRef &&funcRef) { +// Folds CSHIFT(ARRAY, SHIFT [, DIM]) when ARRAY, SHIFT, and DIM are constant. +template Expr Folder::CSHIFT(FunctionRef &&funcRef) { + auto args{funcRef.arguments()}; + CHECK(args.size() == 3); + const auto *array{UnwrapConstantValue(args[0])}; + const auto *shiftExpr{UnwrapExpr>(args[1])}; + auto dim{GetInt64ArgOr(args[2], 1)}; + if (!array || !shiftExpr || !dim) { + return Expr{std::move(funcRef)}; + } + auto convertedShift{Fold(context_, + ConvertToType(Expr{*shiftExpr}))}; + const auto *shift{UnwrapConstantValue(convertedShift)}; + if (!shift) { + return Expr{std::move(funcRef)}; + } + // Arguments are constant + if (*dim < 1 || *dim > array->Rank()) { + context_.messages().Say("Invalid 'dim=' argument (%jd) in CSHIFT"_err_en_US, + static_cast(*dim)); + } else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) { + // message already emitted from intrinsic look-up + } else { + int rank{array->Rank()}; + int zbDim{static_cast(*dim) - 1}; + bool ok{true}; + if (shift->Rank() > 0) { + int k{0}; + for (int j{0}; j < rank; ++j) { + if (j != zbDim) { + if (array->shape()[j] != shift->shape()[k]) { + context_.messages().Say( + "Invalid 'shift=' argument in CSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US, + k + 1, static_cast(shift->shape()[k]), + static_cast(array->shape()[j])); + ok = false; + } + ++k; + } + } + } + if (ok) { + std::vector> resultElements; + ConstantSubscripts arrayLB{array->lbounds()}; + ConstantSubscripts arrayAt{arrayLB}; + ConstantSubscript dimLB{arrayLB[zbDim]}; + ConstantSubscript dimExtent{array->shape()[zbDim]}; + ConstantSubscripts shiftLB{shift->lbounds()}; + ConstantSubscripts shiftAt{shiftLB}; + // Produce the result in array element order (column-major); each + // element comes from 'shift' positions away along dimension 'dim'. + for (auto n{GetSize(array->shape())}; n > 0; --n) { + if (shift->Rank() > 0) { + // The SHIFT element's subscripts are arrayAt's without 'dim'. + int k{0}; + for (int j{0}; j < rank; ++j) { + if (j != zbDim) { + shiftAt[k] = shiftLB[k] + arrayAt[j] - arrayLB[j]; + ++k; + } + } + } + ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()}; + // Reduce the count first so that huge shifts cannot overflow. + ConstantSubscript zbSourceIndex{ + (arrayAt[zbDim] - dimLB + shiftCount % dimExtent) % dimExtent}; + if (zbSourceIndex < 0) { + zbSourceIndex += dimExtent; + } + ConstantSubscript dimIndex{arrayAt[zbDim]}; + arrayAt[zbDim] = dimLB + zbSourceIndex; + resultElements.push_back(array->At(arrayAt)); + arrayAt[zbDim] = dimIndex; + array->IncrementSubscripts(arrayAt); + } + return Expr{PackageConstant( + std::move(resultElements), *array, array->shape())}; + } + } + // Invalid, prevent re-folding + return MakeInvalidIntrinsic(std::move(funcRef)); +} + +template Expr Folder::RESHAPE(FunctionRef &&funcRef) { auto args{funcRef.arguments()}; CHECK(args.size() == 4); const auto *source{UnwrapConstantValue(args[0])}; @@ -679,10 +764,13 @@ } if (auto *intrinsic{std::get_if(&funcRef.proc().u)}) { const std::string name{intrinsic->name}; - if (name == "reshape") { - return Folder{context}.Reshape(std::move(funcRef)); + if (name == "cshift") { + return Folder{context}.CSHIFT(std::move(funcRef)); + } else if (name == "reshape") { + return Folder{context}.RESHAPE(std::move(funcRef)); } - // TODO: other type independent transformationals + // TODO: eoshift, pack, spread, unpack, transpose + // TODO: extends_type_of, same_type_as if constexpr (!std::is_same_v) { return FoldIntrinsicFunction(context, std::move(funcRef)); } 
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -689,10 +689,8 @@ } else if (name == "ubound") { return UBOUND(context, std::move(funcRef)); } - // TODO: - // cshift, dot_product, eoshift, findloc, ibits, image_status, ishftc, - // matmul, maxloc, minloc, not, pack, sign, spread, transfer, transpose, - // unpack + // TODO: count(w/ dim), dot_product, findloc, ibits, image_status, ishftc, + // matmul, maxloc, minloc, sign, transfer return Expr{std::move(funcRef)}; } 
diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp --- a/flang/lib/Evaluate/fold-logical.cpp +++ b/flang/lib/Evaluate/fold-logical.cpp @@ -125,10 +125,9 @@ name == "__builtin_ieee_support_underflow_control") { return Expr{true}; } - // TODO: btest, cshift, dot_product, eoshift, is_iostat_end, + // TODO: btest, dot_product, is_iostat_end, // is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range, - // pack, parity, spread, transfer, transpose, unpack, extends_type_of, - // same_type_as + // parity, transfer return Expr{std::move(funcRef)}; } 
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -135,9 +135,9 @@ } else if (name == "tiny") { return Expr{Scalar::TINY()}; } - // TODO: cshift, dim, dot_product, eoshift, fraction, matmul, - // maxloc, minloc, modulo, nearest, norm2, pack, rrspacing, scale, - // set_exponent, spacing, spread, transfer, transpose, unpack, + // TODO: dim, dot_product, fraction, matmul, + // maxloc, minloc, modulo, nearest, norm2, rrspacing, scale, + // set_exponent, spacing, transfer, // bessel_jn (transformational) and bessel_yn (transformational) return Expr{std::move(funcRef)}; } 
diff --git a/flang/test/Evaluate/folding22.f90 b/flang/test/Evaluate/folding22.f90 --- a/flang/test/Evaluate/folding22.f90 +++ b/flang/test/Evaluate/folding22.f90 @@ -20,4 +20,3 @@ logical, parameter :: test_zero_sized = len(zero_sized).eq.6 end - 
diff --git a/flang/test/Evaluate/folding27.f90 b/flang/test/Evaluate/folding27.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Evaluate/folding27.f90 @@ -0,0 +1,17 @@ +! RUN: %S/test_folding.sh %s %t %flang_fc1 +! REQUIRES: shell +! Tests folding of CSHIFT (valid cases) +module m + integer, parameter :: arr(2,3) = reshape([1, 2, 3, 4, 5, 6], shape(arr)) + logical, parameter :: test_sanity = all([arr] == [1, 2, 3, 4, 5, 6]) + logical, parameter :: test_cshift_0 = all(cshift([1, 2, 3], 0) == [1, 2, 3]) + logical, parameter :: test_cshift_1 = all(cshift([1, 2, 3], 1) == [2, 3, 1]) + logical, parameter :: test_cshift_2 = all(cshift([1, 2, 3], 3) == [1, 2, 3]) + logical, parameter :: test_cshift_3 = all(cshift([1, 2, 3], 4) == [2, 3, 1]) + logical, parameter :: test_cshift_4 = all(cshift([1, 2, 3], -1) == [3, 1, 2]) + logical, parameter :: test_cshift_5 = all([cshift(arr, 1, dim=1)] == [2, 1, 4, 3, 6, 5]) + logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 4, 5, 6, 1, 2]) + logical, parameter :: test_cshift_7 = all([cshift(arr, [1, 2, 3])] == [2, 1, 3, 4, 6, 5]) + logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 6, 5, 2, 1, 4]) + logical, parameter :: test_cshift_9 = all([cshift(arr, -1, dim=2)] == [5, 6, 1, 2, 3, 4]) +end module