diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h
--- a/flang/lib/Evaluate/fold-implementation.h
+++ b/flang/lib/Evaluate/fold-implementation.h
@@ -65,6 +65,8 @@
   Expr<T> EOSHIFT(FunctionRef<T> &&);
   Expr<T> PACK(FunctionRef<T> &&);
   Expr<T> RESHAPE(FunctionRef<T> &&);
+  Expr<T> TRANSPOSE(FunctionRef<T> &&);
+  Expr<T> UNPACK(FunctionRef<T> &&);
 
 private:
   FoldingContext &context_;
@@ -853,6 +855,89 @@
   return MakeInvalidIntrinsic(std::move(funcRef));
 }
 
+// Folds TRANSPOSE(MATRIX) when MATRIX is constant: its elements are
+// visited in transposed order and the result has MATRIX's two extents
+// swapped.
+template <typename T> Expr<T> Folder<T>::TRANSPOSE(FunctionRef<T> &&funcRef) {
+  auto args{funcRef.arguments()};
+  CHECK(args.size() == 1);
+  const auto *matrix{UnwrapConstantValue<T>(args[0])};
+  if (!matrix) {
+    return Expr<T>{std::move(funcRef)};
+  }
+  // Argument is constant.  Traverse its elements in transposed order.
+  std::vector<Scalar<T>> resultElements;
+  ConstantSubscripts at(2);
+  for (ConstantSubscript j{0}; j < matrix->shape()[0]; ++j) {
+    at[0] = matrix->lbounds()[0] + j;
+    for (ConstantSubscript k{0}; k < matrix->shape()[1]; ++k) {
+      at[1] = matrix->lbounds()[1] + k;
+      resultElements.push_back(matrix->At(at));
+    }
+  }
+  at = matrix->shape();
+  std::swap(at[0], at[1]);
+  return Expr<T>{PackageConstant<T>(std::move(resultElements), *matrix, at)};
+}
+
+// Folds UNPACK(VECTOR, MASK, FIELD) when all three arguments are constant.
+// Each result element is the next unused VECTOR element where MASK is true,
+// or the corresponding FIELD element (FIELD may be a scalar) elsewhere.
+template <typename T> Expr<T> Folder<T>::UNPACK(FunctionRef<T> &&funcRef) {
+  auto args{funcRef.arguments()};
+  CHECK(args.size() == 3);
+  const auto *vector{UnwrapConstantValue<T>(args[0])};
+  auto convertedMask{Fold(context_,
+      ConvertToType<LogicalResult>(
+          Expr<SomeLogical>{DEREF(UnwrapExpr<Expr<SomeLogical>>(args[1]))}))};
+  const auto *mask{UnwrapConstantValue<LogicalResult>(convertedMask)};
+  const auto *field{UnwrapConstantValue<T>(args[2])};
+  if (!vector || !mask || !field) {
+    return Expr<T>{std::move(funcRef)};
+  }
+  // Arguments are constant.
+  if (field->Rank() > 0 && field->shape() != mask->shape()) {
+    // Error already emitted from intrinsic processing
+    return MakeInvalidIntrinsic(std::move(funcRef));
+  }
+  // First pass: count the true MASK elements so that a too-short VECTOR is
+  // diagnosed before any result elements are built.
+  ConstantSubscript maskElements{GetSize(mask->shape())};
+  ConstantSubscript truths{0};
+  ConstantSubscripts maskAt{mask->lbounds()};
+  for (ConstantSubscript j{0}; j < maskElements;
+       ++j, mask->IncrementSubscripts(maskAt)) {
+    if (mask->At(maskAt).IsTrue()) {
+      ++truths;
+    }
+  }
+  if (truths > GetSize(vector->shape())) {
+    context_.messages().Say(
+        "Invalid 'vector=' argument in UNPACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements"_err_en_US,
+        static_cast<std::intmax_t>(truths),
+        static_cast<std::intmax_t>(GetSize(vector->shape())));
+    return MakeInvalidIntrinsic(std::move(funcRef));
+  }
+  // Second pass: build the result.  Restart the MASK traversal explicitly
+  // rather than relying on IncrementSubscripts() having wrapped maskAt.
+  maskAt = mask->lbounds();
+  std::vector<Scalar<T>> resultElements;
+  ConstantSubscripts vectorAt{vector->lbounds()};
+  ConstantSubscripts fieldAt{field->lbounds()};
+  for (ConstantSubscript j{0}; j < maskElements; ++j) {
+    if (mask->At(maskAt).IsTrue()) {
+      resultElements.push_back(vector->At(vectorAt));
+      vector->IncrementSubscripts(vectorAt);
+    } else {
+      resultElements.push_back(field->At(fieldAt));
+    }
+    mask->IncrementSubscripts(maskAt);
+    field->IncrementSubscripts(fieldAt);
+  }
+  return Expr<T>{
+      PackageConstant<T>(std::move(resultElements), *vector, mask->shape())};
+}
+
 template <typename T>
 Expr<T> FoldMINorMAX(
     FoldingContext &context, FunctionRef<T> &&funcRef, Ordering order) {
@@ -943,8 +1028,12 @@
       return Folder<T>{context}.PACK(std::move(funcRef));
     } else if (name == "reshape") {
       return Folder<T>{context}.RESHAPE(std::move(funcRef));
+    } else if (name == "transpose") {
+      return Folder<T>{context}.TRANSPOSE(std::move(funcRef));
+    } else if (name == "unpack") {
+      return Folder<T>{context}.UNPACK(std::move(funcRef));
     }
-    // TODO: spread, unpack, transpose
+    // TODO: spread
     // TODO: extends_type_of, same_type_as
     if constexpr (!std::is_same_v<T, SomeDerived>) {
       return FoldIntrinsicFunction(context, std::move(funcRef));
diff --git a/flang/test/Evaluate/folding19.f90 b/flang/test/Evaluate/folding19.f90
--- a/flang/test/Evaluate/folding19.f90
+++ b/flang/test/Evaluate/folding19.f90
@@ -43,5 +43,11 @@
     !CHECK: error: Invalid 'vector=' argument in PACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements
     x = pack(array, mask, [0,0])
   end subroutine
+  subroutine s5
+    logical, parameter :: mask(2,3) = reshape([.false., .true., .true., .false., .false., .true.], shape(mask))
+    integer, parameter :: field(3,2) = reshape([(-j,j=1,6)], shape(field))
+    integer :: x(2,3)
+    !CHECK: error: Invalid 'vector=' argument in UNPACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements
+    x = unpack([1,2], mask, 0)
+  end subroutine
 end module
-
diff --git a/flang/test/Evaluate/folding25.f90 b/flang/test/Evaluate/folding25.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Evaluate/folding25.f90
@@ -0,0 +1,10 @@
+! RUN: %S/test_folding.sh %s %t %flang_fc1
+! REQUIRES: shell
+! Tests folding of UNPACK (valid cases)
+module m
+  integer, parameter :: vector(*) = [1, 2, 3, 4]
+  integer, parameter :: field(2,3) = reshape([(-j,j=1,6)], shape(field))
+  logical, parameter :: mask(*,*) = reshape([.false., .true., .true., .false., .false., .true.], shape(field))
+  logical, parameter :: test_unpack_1 = all(unpack(vector, mask, 0) == reshape([0,1,2,0,0,3], shape(mask)))
+  logical, parameter :: test_unpack_2 = all(unpack(vector, mask, field) == reshape([-1,1,2,-4,-5,3], shape(mask)))
+end module
diff --git a/flang/test/Evaluate/folding26.f90 b/flang/test/Evaluate/folding26.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Evaluate/folding26.f90
@@ -0,0 +1,7 @@
+! RUN: %S/test_folding.sh %s %t %flang_fc1
+! REQUIRES: shell
+! Tests folding of TRANSPOSE
+module m
+  integer, parameter :: matrix(0:1,0:2) = reshape([1,2,3,4,5,6],shape(matrix))
+  logical, parameter :: test_transpose_1 = all(transpose(matrix) == reshape([1,3,5,2,4,6],[3,2]))
+end module