diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -63,6 +63,7 @@ Expr CSHIFT(FunctionRef &&); Expr EOSHIFT(FunctionRef &&); + Expr PACK(FunctionRef &&); Expr RESHAPE(FunctionRef &&); private: @@ -580,7 +581,7 @@ if (j != zbDim) { if (array->shape()[j] != shift->shape()[k]) { context_.messages().Say( - "Invalid 'shift=' argument in CSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US, + "Invalid 'shift=' argument in CSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US, k + 1, static_cast(shift->shape()[k]), static_cast(array->shape()[j])); ok = false; @@ -653,6 +654,9 @@ static_cast(*dim)); } else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) { // message already emitted from intrinsic look-up + } else if (boundary && boundary->Rank() > 0 && + boundary->Rank() != array->Rank() - 1) { + // ditto } else { int rank{array->Rank()}; int zbDim{static_cast(*dim) - 1}; @@ -663,15 +667,23 @@ if (j != zbDim) { if (array->shape()[j] != shift->shape()[k]) { context_.messages().Say( - "Invalid 'shift=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US, + "Invalid 'shift=' argument in EOSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US, k + 1, static_cast(shift->shape()[k]), static_cast(array->shape()[j])); ok = false; } - if (boundary && array->shape()[j] != boundary->shape()[k]) { + ++k; + } + } + } + if (boundary && boundary->Rank() > 0) { + int k{0}; + for (int j{0}; j < rank; ++j) { + if (j != zbDim) { + if (array->shape()[j] != boundary->shape()[k]) { context_.messages().Say( - "Invalid 'boundary=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US, - k + 1, static_cast(shift->shape()[k]), + "Invalid 'boundary=' argument in EOSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US, + k + 1, static_cast(boundary->shape()[k]), static_cast(array->shape()[j])); ok = false; } @@ -726,6 +738,70 @@ return MakeInvalidIntrinsic(std::move(funcRef)); } +template Expr Folder::PACK(FunctionRef &&funcRef) { + auto args{funcRef.arguments()}; + CHECK(args.size() == 3); + const auto *array{UnwrapConstantValue(args[0])}; + const auto *vector{UnwrapConstantValue(args[2])}; + auto convertedMask{Fold(context_, + ConvertToType( + Expr{DEREF(UnwrapExpr>(args[1]))}))}; + const auto *mask{UnwrapConstantValue(convertedMask)}; + if (!array || !mask || (args[2] && !vector)) { + return Expr{std::move(funcRef)}; + } + // Arguments are constant. + ConstantSubscript arrayElements{GetSize(array->shape())}; + ConstantSubscript truths{0}; + ConstantSubscripts maskAt{mask->lbounds()}; + if (mask->Rank() == 0) { + if (mask->At(maskAt).IsTrue()) { + truths = arrayElements; + } + } else if (array->shape() != mask->shape()) { + // Error already emitted from intrinsic processing + return MakeInvalidIntrinsic(std::move(funcRef)); + } else { + for (ConstantSubscript j{0}; j < arrayElements; + ++j, mask->IncrementSubscripts(maskAt)) { + if (mask->At(maskAt).IsTrue()) { + ++truths; + } + } + } + std::vector> resultElements; + ConstantSubscripts arrayAt{array->lbounds()}; + ConstantSubscript resultSize{truths}; + if (vector) { + resultSize = vector->shape().at(0); + if (resultSize < truths) { + context_.messages().Say( + "Invalid 'vector=' argument in PACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements"_err_en_US, + static_cast(truths), + static_cast(resultSize)); + return MakeInvalidIntrinsic(std::move(funcRef)); + } + } + for (ConstantSubscript j{0}; j < truths;) { + if (mask->At(maskAt).IsTrue()) { + resultElements.push_back(array->At(arrayAt)); + ++j; + } + array->IncrementSubscripts(arrayAt); + mask->IncrementSubscripts(maskAt); + } + if (vector) { + ConstantSubscripts vectorAt{vector->lbounds()}; + vectorAt.at(0) += truths; + for (ConstantSubscript j{truths}; j < resultSize; ++j) { + resultElements.push_back(vector->At(vectorAt)); + ++vectorAt[0]; + } + } + return Expr{PackageConstant(std::move(resultElements), *array, + ConstantSubscripts{static_cast(resultSize)})}; +} + template Expr Folder::RESHAPE(FunctionRef &&funcRef) { auto args{funcRef.arguments()}; CHECK(args.size() == 4); @@ -863,10 +939,12 @@ return Folder{context}.CSHIFT(std::move(funcRef)); } else if (name == "eoshift") { return Folder{context}.EOSHIFT(std::move(funcRef)); + } else if (name == "pack") { + return Folder{context}.PACK(std::move(funcRef)); } else if (name == "reshape") { return Folder{context}.RESHAPE(std::move(funcRef)); } - // TODO: eoshift, pack, spread, unpack, transpose + // TODO: spread, unpack, transpose // TODO: extends_type_of, same_type_as if constexpr (!std::is_same_v) { return FoldIntrinsicFunction(context, std::move(funcRef)); diff --git a/flang/test/Evaluate/folding19.f90 b/flang/test/Evaluate/folding19.f90 --- a/flang/test/Evaluate/folding19.f90 +++ b/flang/test/Evaluate/folding19.f90 @@ -18,5 +18,30 @@ !CHECK: error: DIM=2 dimension is out of range for rank-1 array integer :: lb3(lbound(b,2)) end subroutine + subroutine s2 + integer, parameter :: array(2,3) = reshape([(j, j=1, 6)], shape(array)) + integer :: x(2, 3) + !CHECK: error: Invalid 'dim=' argument (0) in CSHIFT + x = cshift(array, [1, 2], dim=0) + !CHECK: error: Invalid 'shift=' argument in CSHIFT: extent on dimension 1 is 2 but must be 3 + x = cshift(array, [1, 2], dim=1) + end subroutine + subroutine s3 + integer, parameter :: array(2,3) = reshape([(j, j=1, 6)], shape(array)) + integer :: x(2, 3) + !CHECK: error: Invalid 'dim=' argument (0) in EOSHIFT + x = eoshift(array, [1, 2], dim=0) + !CHECK: error: Invalid 'shift=' argument in EOSHIFT: extent on dimension 1 is 2 but must be 3 + x = eoshift(array, [1, 2], dim=1) + !CHECK: error: Invalid 'boundary=' argument in EOSHIFT: extent on dimension 1 is 3 but must be 2 + x = eoshift(array, 1, [0, 0, 0], 2) + end subroutine + subroutine s4 + integer, parameter :: array(2,3) = reshape([(j, j=1, 6)], shape(array)) + logical, parameter :: mask(*,*) = reshape([(.true., j=1,3),(.false., j=1,3)], shape(array)) + integer :: x(3) + !CHECK: error: Invalid 'vector=' argument in PACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements + x = pack(array, mask, [0,0]) + end subroutine end module diff --git a/flang/test/Evaluate/folding24.f90 b/flang/test/Evaluate/folding24.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Evaluate/folding24.f90 @@ -0,0 +1,16 @@ +! RUN: %S/test_folding.sh %s %t %flang_fc1 +! REQUIRES: shell +! Tests folding of PACK (valid cases) +module m + integer, parameter :: arr(2,3) = reshape([1, 2, 3, 4, 5, 6], shape(arr)) + logical, parameter :: odds(*,*) = modulo(arr, 2) /= 0 + integer, parameter :: vect(*) = [(j, j=-10, -1)] + logical, parameter :: test_pack_1 = all(pack(arr, .true.) == [arr]) + logical, parameter :: test_pack_2 = all(pack(arr, .false.) == [integer::]) + logical, parameter :: test_pack_3 = all(pack(arr, odds) == [1, 3, 5]) + logical, parameter :: test_pack_4 = all(pack(arr, .not. odds) == [2, 4, 6]) + logical, parameter :: test_pack_5 = all(pack(arr, .true., vect) == [1, 2, 3, 4, 5, 6, -4, -3, -2, -1]) + logical, parameter :: test_pack_6 = all(pack(arr, .false., vect) == vect) + logical, parameter :: test_pack_7 = all(pack(arr, odds, vect) == [1, 3, 5, -7, -6, -5, -4, -3, -2, -1]) + logical, parameter :: test_pack_8 = all(pack(arr, .not. odds, vect) == [2, 4, 6, -7, -6, -5, -4, -3, -2, -1]) +end module