diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -149,6 +149,7 @@ Expr<SomeReal> GetComplexPart( const Expr<SomeComplex> &, bool isImaginary = false); +Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&, bool isImaginary = false); template <int KIND> Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re, diff --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h --- a/flang/include/flang/Evaluate/type.h +++ b/flang/include/flang/Evaluate/type.h @@ -461,7 +461,6 @@ #define EXPAND_FOR_EACH_CHARACTER_KIND(M, P, S) M(P, S, 1) M(P, S, 2) M(P, S, 4) #define EXPAND_FOR_EACH_LOGICAL_KIND(M, P, S) \ M(P, S, 1) M(P, S, 2) M(P, S, 4) M(P, S, 8) -#define TEMPLATE_INSTANTIATION(P, S, ARG) P<ARG> S; #define FOR_EACH_INTEGER_KIND_HELP(PREFIX, SUFFIX, K) \ PREFIX<Type<TypeCategory::Integer, K>> SUFFIX; diff --git a/flang/include/flang/Evaluate/variable.h b/flang/include/flang/Evaluate/variable.h --- a/flang/include/flang/Evaluate/variable.h +++ b/flang/include/flang/Evaluate/variable.h @@ -353,6 +353,7 @@ ENUM_CLASS(Part, RE, IM) CLASS_BOILERPLATE(ComplexPart) ComplexPart(DataRef &&z, Part p) : complex_{std::move(z)}, part_{p} {} + DataRef &complex() { return complex_; } const DataRef &complex() const { return complex_; } Part part() const { return part_; } int Rank() const; diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -58,6 +58,7 @@ std::optional<Constant<T>> GetConstantComponent( Component &, const std::vector<Constant<SubscriptInteger>> * = nullptr); std::optional<Constant<T>> Folding(ArrayRef &); + std::optional<Constant<T>> Folding(DataRef &); Expr<T> Folding(Designator<T> &&); Constant<T> *Folding(std::optional<ActualArgument> &); @@ -118,27 +119,12 @@ DataRef FoldOperation(FoldingContext &, DataRef &&); Substring FoldOperation(FoldingContext &, Substring &&); ComplexPart FoldOperation(FoldingContext &, ComplexPart &&); - template <typename T> -Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&); -template <int KIND> -Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef<Type<TypeCategory::Integer, KIND>> &&); -template <int KIND> -Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef<Type<TypeCategory::Real, KIND>> &&); -template <int KIND> -Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef<Type<TypeCategory::Complex, KIND>> &&); -template <int KIND> -Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef<Type<TypeCategory::Logical, KIND>> &&); - +Expr<T> FoldOperation(FoldingContext &, FunctionRef<T> &&); template <typename T> Expr<T> FoldOperation(FoldingContext &context, Designator<T> &&designator) { return Folder<T>{context}.Folding(std::move(designator)); } - Expr<TypeParamInquiry::Result> FoldOperation( FoldingContext &, TypeParamInquiry &&); Expr<ImpliedDoIndex::Result> FoldOperation( @@ -182,6 +168,25 @@ } } +template <typename T> +std::optional<Constant<T>> Folder<T>::Folding(DataRef &ref) { + return common::visit( + common::visitors{ + [this](SymbolRef &sym) { return GetNamedConstant(*sym); }, + [this](Component &comp) { + comp = FoldOperation(context_, std::move(comp)); + return GetConstantComponent(comp); + }, + [this](ArrayRef &aRef) { + aRef = FoldOperation(context_, std::move(aRef)); + return Folding(aRef); + }, + [](CoarrayRef &) { return std::optional<Constant<T>>{}; }, + }, + ref.u); +} + +// TODO: This would be more natural as a member function of Constant<T>. template <typename T> std::optional<Constant<T>> Folder<T>::ApplySubscripts(const Constant<T> &array, const std::vector<Constant<SubscriptInteger>> &subscripts) { @@ -341,6 +346,19 @@ } } } + } else if constexpr (T::category == TypeCategory::Real) { + if (auto *zPart{std::get_if<ComplexPart>(&designator.u)}) { + *zPart = FoldOperation(context_, std::move(*zPart)); + using ComplexT = Type<TypeCategory::Complex, T::kind>; + if (auto zConst{Folder<ComplexT>{context_}.Folding(zPart->complex())}) { + return Fold(context_, + Expr<T>{ComplexComponent<T::kind>{ + zPart->part() == ComplexPart::Part::IM, + Expr<ComplexT>{std::move(*zConst)}}}); + } else { + return Expr<T>{Designator<T>{std::move(*zPart)}}; + } + } } return common::visit( common::visitors{ @@ -1045,6 +1063,20 @@ return common::visit(insertConversion, sx.u); } +// FoldIntrinsicFunction() +template <int KIND> +Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef<Type<TypeCategory::Integer, KIND>> &&); +template <int KIND> +Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef<Type<TypeCategory::Real, KIND>> &&); +template <int KIND> +Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef<Type<TypeCategory::Complex, KIND>> &&); +template <int KIND> +Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef<Type<TypeCategory::Logical, KIND>> &&); + template <typename T> Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) { ActualArguments &args{funcRef.arguments()}; @@ -1922,6 +1954,31 @@ return result.value(); } +// REAL(z) and AIMAG(z) +template <int KIND> +Expr<Type<TypeCategory::Real, KIND>> FoldOperation( + FoldingContext &context, ComplexComponent<KIND> &&x) { + using Operand = Type<TypeCategory::Complex, KIND>; + using Result = Type<TypeCategory::Real, KIND>; + if (auto array{ApplyElementwise(context, x, + std::function<Expr<Result>(Expr<Operand> &&)>{ + [=](Expr<Operand> &&operand) { + return Expr<Result>{ComplexComponent<KIND>{ + x.isImaginaryPart, std::move(operand)}}; + }})}) { + return *array; + } + auto &operand{x.left()}; + if (auto value{GetScalarConstantValue<Operand>(operand)}) { + if (x.isImaginaryPart) { + return Expr<Result>{Constant<Result>{value->AIMAG()}}; + } else { + return Expr<Result>{Constant<Result>{value->REAL()}}; + } + } + return Expr<Result>{std::move(x)}; +} + template <typename T> Expr<T> ExpressionBase<T>::Rewrite(FoldingContext &context, Expr<T> &&expr) { return common::visit( @@ -1941,6 +1998,5 @@ } FOR_EACH_TYPE_AND_KIND(extern template class ExpressionBase, ) - } // namespace Fortran::evaluate #endif // FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_ diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -112,8 +112,9 @@ common::die(" unexpected argument type inside abs"); } } else if (name == "aimag") { - return FoldElementalIntrinsic<T, ComplexT>( - context, std::move(funcRef), &Scalar<ComplexT>::AIMAG); + if (auto *zExpr{UnwrapExpr<Expr<ComplexT>>(args[0])}) { + return Fold(context, Expr<T>{ComplexComponent{true, std::move(*zExpr)}}); + } } else if (name == "aint" || name == "anint") { // ANINT rounds ties away from zero, not to even common::RoundingMode mode{name == "aint" @@ -318,31 +319,6 @@ return Expr<T>{std::move(funcRef)}; } -template <int KIND> -Expr<Type<TypeCategory::Real, KIND>> FoldOperation( - FoldingContext &context, ComplexComponent<KIND> &&x) { - using Operand = Type<TypeCategory::Complex, KIND>; - using Result = Type<TypeCategory::Real, KIND>; - if (auto array{ApplyElementwise(context, x, - std::function<Expr<Result>(Expr<Operand> &&)>{ - [=](Expr<Operand> &&operand) { - return Expr<Result>{ComplexComponent<KIND>{ - x.isImaginaryPart, std::move(operand)}}; - }})}) { - return *array; - } - using Part = Type<TypeCategory::Real, KIND>; - auto &operand{x.left()}; - if (auto value{GetScalarConstantValue<Operand>(operand)}) { - if (x.isImaginaryPart) { - return Expr<Part>{Constant<Part>{value->AIMAG()}}; - } else { - return Expr<Part>{Constant<Part>{value->REAL()}}; - } - } - return Expr<Part>{std::move(x)}; -} - #ifdef _MSC_VER // disable bogus warning about missing definitions #pragma warning(disable : 4661) #endif diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -238,6 +238,16 @@ z.u); } +Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&z, bool isImaginary) { + return common::visit( + [&](auto &&zk) { + static constexpr int kind{ResultType<decltype(zk)>::kind}; + return AsCategoryExpr( + ComplexComponent<kind>{isImaginary, std::move(zk)}); + }, + z.u); +} + // Convert REAL to COMPLEX of the same kind. Preserving the real operand kind // and then applying complex operand promotion rules allows the result to have // the highest precision of REAL and COMPLEX operands as required by Fortran diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -1073,7 +1073,8 @@ MiscKind kind{details->kind()}; if (kind == MiscKind::ComplexPartRe || kind == MiscKind::ComplexPartIm) { if (auto *zExpr{std::get_if<Expr<SomeComplex>>(&base->u)}) { - if (std::optional<DataRef> dataRef{ExtractDataRef(std::move(*zExpr))}) { + if (std::optional<DataRef> dataRef{ExtractDataRef(*zExpr)}) { + // Represent %RE/%IM as a designator Expr<SomeReal> realExpr{common::visit( [&](const auto &z) { using PartType = typename ResultType<decltype(z)>::Part; diff --git a/flang/test/Evaluate/fold-re-im.f90 b/flang/test/Evaluate/fold-re-im.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Evaluate/fold-re-im.f90 @@ -0,0 +1,15 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1 +! Tests folding of complex components +module m + complex, parameter :: z = (1., 2.) + logical, parameter :: test_1 = z%re == 1. + logical, parameter :: test_2 = z%im == 2. + logical, parameter :: test_3 = real(z+z) == 2. + logical, parameter :: test_4 = aimag(z+z) == 4. + type :: t + complex :: z + end type + type(t), parameter :: tz(*) = [t((3., 4.)), t((5., 6.))] + logical, parameter :: test_5 = all(tz%z%re == [3., 5.]) + logical, parameter :: test_6 = all(tz%z%im == [4., 6.]) +end module