diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -65,6 +65,7 @@
   Expr<T> EOSHIFT(FunctionRef<T> &&);
   Expr<T> PACK(FunctionRef<T> &&);
   Expr<T> RESHAPE(FunctionRef<T> &&);
+  Expr<T> SPREAD(FunctionRef<T> &&);
   Expr<T> TRANSPOSE(FunctionRef<T> &&);
   Expr<T> UNPACK(FunctionRef<T> &&);
 
@@ -855,6 +856,66 @@
   return MakeInvalidIntrinsic(std::move(funcRef));
 }
 
+// Folds SPREAD(SOURCE, DIM, NCOPIES) when SOURCE= is a constant and DIM=
+// and NCOPIES= are constant integers. The result is SOURCE with a new
+// dimension of extent MAX(NCOPIES, 0) inserted at position DIM, each element
+// along that new dimension being a copy of SOURCE. A SOURCE= of rank
+// common::maxRank or an out-of-range DIM= is diagnosed and the call is made
+// invalid so that it is not folded again; a call whose arguments are not
+// (yet) constant is returned unfolded.
+template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
+  auto args{funcRef.arguments()};
+  CHECK(args.size() == 3);
+  const Constant<T> *source{UnwrapConstantValue<T>(args[0])};
+  auto dim{GetInt64Arg(args[1])};
+  auto ncopies{GetInt64Arg(args[2])};
+  if (!source || !dim) {
+    return Expr<T>{std::move(funcRef)};
+  }
+  int sourceRank{source->Rank()};
+  if (sourceRank >= common::maxRank) {
+    context_.messages().Say(
+        "SOURCE= argument to SPREAD has rank %d but must have rank less than %d"_err_en_US,
+        sourceRank, common::maxRank);
+  } else if (*dim < 1 || *dim > sourceRank + 1) {
+    // *dim is a std::int64_t; pass it as std::intmax_t to match %jd.
+    context_.messages().Say(
+        "DIM=%jd argument to SPREAD must be between 1 and %d"_err_en_US,
+        static_cast<std::intmax_t>(*dim), sourceRank + 1);
+  } else if (!ncopies) {
+    // NCOPIES= is not constant yet; SOURCE= and DIM= were checked above.
+    return Expr<T>{std::move(funcRef)};
+  } else {
+    if (*ncopies < 0) {
+      ncopies = 0; // a negative NCOPIES= yields a zero-sized result
+    }
+    // TODO: Consider moving this implementation (after the user error
+    // checks), along with other transformational intrinsics, into
+    // constant.h (or a new header) so that the transformationals
+    // are available for all Constant<>s without needing to be packaged
+    // as references to intrinsic functions for folding.
+    ConstantSubscripts shape{source->shape()};
+    shape.insert(shape.begin() + *dim - 1, *ncopies);
+    Constant<T> spread{source->Reshape(std::move(shape))};
+    // dimOrder lists the result dimensions from fastest- to slowest-varying
+    // as CopyFrom() steps through SOURCE in array element order
+    // (NOTE(review): assumes IncrementSubscripts() follows the RESHAPE
+    // ORDER= convention). SOURCE's own dimensions must vary first, in their
+    // original order, and the new dimension DIM, which selects the copy,
+    // must vary slowest: e.g. {1, 2, 0} for a rank-2 SOURCE with DIM=1.
+    std::vector<int> dimOrder;
+    for (int j{0}; j < sourceRank; ++j) {
+      dimOrder.push_back(j < *dim - 1 ? j : j + 1);
+    }
+    dimOrder.push_back(static_cast<int>(*dim - 1));
+    ConstantSubscripts at{spread.lbounds()}; // all 1
+    spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder);
+    return Expr<T>{std::move(spread)};
+  }
+  // Invalid, prevent re-folding
+  return MakeInvalidIntrinsic(std::move(funcRef));
+}
+
 template <typename T> Expr<T> Folder<T>::TRANSPOSE(FunctionRef<T> &&funcRef) {
   auto args{funcRef.arguments()};
   CHECK(args.size() == 1);
@@ -1017,12 +1078,13 @@
       return Folder<T>{context}.PACK(std::move(funcRef));
     } else if (name == "reshape") {
       return Folder<T>{context}.RESHAPE(std::move(funcRef));
+    } else if (name == "spread") {
+      return Folder<T>{context}.SPREAD(std::move(funcRef));
     } else if (name == "transpose") {
       return Folder<T>{context}.TRANSPOSE(std::move(funcRef));
     } else if (name == "unpack") {
       return Folder<T>{context}.UNPACK(std::move(funcRef));
     }
-    // TODO: spread
     // TODO: extends_type_of, same_type_as
     if constexpr (!std::is_same_v<T, SomeDerived>) {
       return FoldIntrinsicFunction(context, std::move(funcRef));
diff --git a/flang/test/Evaluate/folding19.f90 b/flang/test/Evaluate/errors01.f90
rename from flang/test/Evaluate/folding19.f90
rename to flang/test/Evaluate/errors01.f90
--- a/flang/test/Evaluate/folding19.f90
+++ b/flang/test/Evaluate/errors01.f90
@@ -90,4 +90,14 @@
     !CHECK: error: SHIFT=65 count for shiftl is greater than 64
     integer(8), parameter :: bad6 = shiftl(1_8, 65)
   end subroutine
+  subroutine s9
+    integer, parameter :: rank15(1,1,1,1,1,1,1,1,1,1,1,1,1,1,1) = 1
+    !CHECK: error: SOURCE= argument to SPREAD has rank 15 but must have rank less than 15
+    integer, parameter :: bad1 = spread(rank15, 1, 1)
+    integer, parameter :: matrix(2, 2) = reshape([1, 2, 3, 4], [2, 2])
+    !CHECK: error: DIM=0 argument to SPREAD must be between 1 and 3
+    integer, parameter :: bad2 = spread(matrix, 0, 1)
+    !CHECK: error: DIM=4 argument to SPREAD must be between 1 and 3
+    integer, parameter :: bad3 = spread(matrix, 4, 1)
+  end subroutine
 end module
diff --git a/flang/test/Evaluate/fold-spread.f90 b/flang/test/Evaluate/fold-spread.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Evaluate/fold-spread.f90
@@ -0,0 +1,16 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of SPREAD
+module m1
+  logical, parameter :: test_empty = size(spread(1, 1, 0)) == 0
+  logical, parameter :: test_stov = all(spread(1, 1, 2) == [1, 1])
+  logical, parameter :: test_vtom1 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2]))
+  logical, parameter :: test_vtom2 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3]))
+  integer, parameter :: matrix(2, 2) = reshape([1, 2, 3, 4], [2, 2])
+  logical, parameter :: test_mtoa1 = all(spread(matrix, 1, 2) == reshape([1, 1, 2, 2, 3, 3, 4, 4], [2, 2, 2]))
+  logical, parameter :: test_mtoa2 = all(spread(matrix, 2, 2) == reshape([1, 2, 1, 2, 3, 4, 3, 4], [2, 2, 2]))
+  logical, parameter :: test_mtoa3 = all(spread(matrix, 3, 2) == reshape([1, 2, 3, 4, 1, 2, 3, 4], [2, 2, 2]))
+  logical, parameter :: test_log1 = all(all(spread([.false., .true.], 1, 2), dim=2) .eqv. [.false., .false.])
+  logical, parameter :: test_log2 = all(all(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
+  logical, parameter :: test_log3 = all(any(spread([.false., .true.], 1, 2), dim=2) .eqv. [.true., .true.])
+  logical, parameter :: test_log4 = all(any(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
+end module