[flang] Fold complex part designators (%RE, %IM) and REAL/AIMAG of COMPLEX

Summary of the changes carried by this patch:

1) flang/include/flang/Evaluate/tools.h, flang/lib/Evaluate/tools.cpp:
   add an rvalue-reference overload of GetComplexPart() that visits each
   alternative of z.u and wraps it (moved) in a ComplexComponent, passing
   isImaginary through to select the part.
2) flang/include/flang/Evaluate/type.h: delete the TEMPLATE_INSTANTIATION
   macro.
3) flang/include/flang/Evaluate/variable.h: add a non-const
   ComplexPart::complex() accessor returning DataRef &, so the base
   reference can be folded in place.
4) flang/lib/Evaluate/fold-implementation.h:
   a) new Folder::Folding(DataRef &). It yields a constant for a named
      constant (GetNamedConstant), for a component (folded, then
      GetConstantComponent), or for an array reference (folded, then
      Folding(ArrayRef &)). A CoarrayRef always yields no constant.
   b) In Folding(Designator &&), a REAL-category designator holding a
      ComplexPart is folded first. If its COMPLEX base then folds to a
      constant, the result is Fold() of a ComplexComponent over that
      constant (imaginary when the part is IM). Otherwise the ComplexPart
      is rebuilt as a Designator.
   c) The FoldIntrinsicFunction declarations move from the forward
      declarations near the top of the file to just before the
      FoldOperation definition for FunctionRef.
   d) FoldOperation for ComplexComponent (REAL(z) and AIMAG(z)) moves into
      this header from fold-real.cpp.
5) flang/lib/Evaluate/fold-real.cpp: AIMAG now folds by building
   ComplexComponent{true, ...} over the unwrapped argument and folding it.
   Before this patch it called FoldElementalIntrinsic with Scalar::AIMAG.
6) flang/lib/Semantics/expression.cpp: ExtractDataRef() is now called on
   *zExpr as an lvalue instead of std::move(*zExpr). A comment is added
   noting that %RE/%IM are represented as a designator.
7) flang/test/Evaluate/fold-re-im.f90: new folding test covering z%re,
   z%im, REAL(z+z), AIMAG(z+z), and %RE/%IM through an array of
   derived-type components (tz%z%re, tz%z%im).

Review notes:
NOTE(review): this copy of the patch has lost its line breaks and all text
that was between angle brackets, for example "std::optional> Folding(DataRef &)"
and "Expr GetComplexPart(Expr &&, bool isImaginary = false)". It will not
apply as-is. Regenerate it from the originating commit rather than
hand-repairing hunks, whose @@ line counts must match exactly.
NOTE(review): in fold-real.cpp the new aimag branch has no else. If the
UnwrapExpr of args[0] ever fails, the call is left unfolded, whereas the
old code always folded elementally. Presumably semantic analysis guarantees
a COMPLEX argument of the result kind here; confirm.
NOTE(review): confirm that nothing else in the tree still expands
TEMPLATE_INSTANTIATION after its removal from type.h.

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -149,6 +149,7 @@ Expr GetComplexPart( const Expr &, bool isImaginary = false); +Expr GetComplexPart(Expr &&, bool isImaginary = false); template Expr MakeComplex(Expr> &&re, diff --git a/flang/include/flang/Evaluate/type.h b/flang/include/flang/Evaluate/type.h --- a/flang/include/flang/Evaluate/type.h +++ b/flang/include/flang/Evaluate/type.h @@ -461,7 +461,6 @@ #define EXPAND_FOR_EACH_CHARACTER_KIND(M, P, S) M(P, S, 1) M(P, S, 2) M(P, S, 4) #define EXPAND_FOR_EACH_LOGICAL_KIND(M, P, S) \ M(P, S, 1) M(P, S, 2) M(P, S, 4) M(P, S, 8) -#define TEMPLATE_INSTANTIATION(P, S, ARG) P S; #define FOR_EACH_INTEGER_KIND_HELP(PREFIX, SUFFIX, K) \ PREFIX> SUFFIX; diff --git a/flang/include/flang/Evaluate/variable.h b/flang/include/flang/Evaluate/variable.h --- a/flang/include/flang/Evaluate/variable.h +++ b/flang/include/flang/Evaluate/variable.h @@ -353,6 +353,7 @@ ENUM_CLASS(Part, RE, IM) CLASS_BOILERPLATE(ComplexPart) ComplexPart(DataRef &&z, Part p) : complex_{std::move(z)}, part_{p} {} + DataRef &complex() { return complex_; } const DataRef &complex() const { return complex_; } Part part() const { return part_; } int Rank() const; diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -58,6 +58,7 @@ std::optional> GetConstantComponent( Component &, const std::vector> * = nullptr); std::optional> Folding(ArrayRef &); + std::optional> Folding(DataRef &); Expr Folding(Designator &&); Constant *Folding(std::optional &); @@ -118,27 +119,12 @@ DataRef FoldOperation(FoldingContext &, DataRef &&); Substring FoldOperation(FoldingContext &, Substring &&); ComplexPart FoldOperation(FoldingContext &, ComplexPart &&); - template -Expr FoldOperation(FoldingContext &context, 
FunctionRef &&); -template -Expr> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef> &&); -template -Expr> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef> &&); -template -Expr> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef> &&); -template -Expr> FoldIntrinsicFunction( - FoldingContext &context, FunctionRef> &&); - +Expr FoldOperation(FoldingContext &, FunctionRef &&); template Expr FoldOperation(FoldingContext &context, Designator &&designator) { return Folder{context}.Folding(std::move(designator)); } - Expr FoldOperation( FoldingContext &, TypeParamInquiry &&); Expr FoldOperation( @@ -182,6 +168,25 @@ } } +template +std::optional> Folder::Folding(DataRef &ref) { + return common::visit( + common::visitors{ + [this](SymbolRef &sym) { return GetNamedConstant(*sym); }, + [this](Component &comp) { + comp = FoldOperation(context_, std::move(comp)); + return GetConstantComponent(comp); + }, + [this](ArrayRef &aRef) { + aRef = FoldOperation(context_, std::move(aRef)); + return Folding(aRef); + }, + [](CoarrayRef &) { return std::optional>{}; }, + }, + ref.u); +} + +// TODO: This would be more natural as a member function of Constant. 
template std::optional> Folder::ApplySubscripts(const Constant &array, const std::vector> &subscripts) { @@ -341,6 +346,19 @@ } } } + } else if constexpr (T::category == TypeCategory::Real) { + if (auto *zPart{std::get_if(&designator.u)}) { + *zPart = FoldOperation(context_, std::move(*zPart)); + using ComplexT = Type; + if (auto zConst{Folder{context_}.Folding(zPart->complex())}) { + return Fold(context_, + Expr{ComplexComponent{ + zPart->part() == ComplexPart::Part::IM, + Expr{std::move(*zConst)}}}); + } else { + return Expr{Designator{std::move(*zPart)}}; + } + } } return common::visit( common::visitors{ @@ -1045,6 +1063,20 @@ return common::visit(insertConversion, sx.u); } +// FoldIntrinsicFunction() +template +Expr> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef> &&); +template +Expr> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef> &&); +template +Expr> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef> &&); +template +Expr> FoldIntrinsicFunction( + FoldingContext &context, FunctionRef> &&); + template Expr FoldOperation(FoldingContext &context, FunctionRef &&funcRef) { ActualArguments &args{funcRef.arguments()}; @@ -1922,6 +1954,31 @@ return result.value(); } +// REAL(z) and AIMAG(z) +template +Expr> FoldOperation( + FoldingContext &context, ComplexComponent &&x) { + using Operand = Type; + using Result = Type; + if (auto array{ApplyElementwise(context, x, + std::function(Expr &&)>{ + [=](Expr &&operand) { + return Expr{ComplexComponent{ + x.isImaginaryPart, std::move(operand)}}; + }})}) { + return *array; + } + auto &operand{x.left()}; + if (auto value{GetScalarConstantValue(operand)}) { + if (x.isImaginaryPart) { + return Expr{Constant{value->AIMAG()}}; + } else { + return Expr{Constant{value->REAL()}}; + } + } + return Expr{std::move(x)}; +} + template Expr ExpressionBase::Rewrite(FoldingContext &context, Expr &&expr) { return common::visit( @@ -1941,6 +1998,5 @@ } FOR_EACH_TYPE_AND_KIND(extern template class 
ExpressionBase, ) - } // namespace Fortran::evaluate #endif // FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_ diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -112,8 +112,9 @@ common::die(" unexpected argument type inside abs"); } } else if (name == "aimag") { - return FoldElementalIntrinsic( - context, std::move(funcRef), &Scalar::AIMAG); + if (auto *zExpr{UnwrapExpr>(args[0])}) { + return Fold(context, Expr{ComplexComponent{true, std::move(*zExpr)}}); + } } else if (name == "aint" || name == "anint") { // ANINT rounds ties away from zero, not to even common::RoundingMode mode{name == "aint" @@ -318,31 +319,6 @@ return Expr{std::move(funcRef)}; } -template -Expr> FoldOperation( - FoldingContext &context, ComplexComponent &&x) { - using Operand = Type; - using Result = Type; - if (auto array{ApplyElementwise(context, x, - std::function(Expr &&)>{ - [=](Expr &&operand) { - return Expr{ComplexComponent{ - x.isImaginaryPart, std::move(operand)}}; - }})}) { - return *array; - } - using Part = Type; - auto &operand{x.left()}; - if (auto value{GetScalarConstantValue(operand)}) { - if (x.isImaginaryPart) { - return Expr{Constant{value->AIMAG()}}; - } else { - return Expr{Constant{value->REAL()}}; - } - } - return Expr{std::move(x)}; -} - #ifdef _MSC_VER // disable bogus warning about missing definitions #pragma warning(disable : 4661) #endif diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -238,6 +238,16 @@ z.u); } +Expr GetComplexPart(Expr &&z, bool isImaginary) { + return common::visit( + [&](auto &&zk) { + static constexpr int kind{ResultType::kind}; + return AsCategoryExpr( + ComplexComponent{isImaginary, std::move(zk)}); + }, + z.u); +} + // Convert REAL to COMPLEX of the same kind. 
Preserving the real operand kind // and then applying complex operand promotion rules allows the result to have // the highest precision of REAL and COMPLEX operands as required by Fortran diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -1073,7 +1073,8 @@ MiscKind kind{details->kind()}; if (kind == MiscKind::ComplexPartRe || kind == MiscKind::ComplexPartIm) { if (auto *zExpr{std::get_if>(&base->u)}) { - if (std::optional dataRef{ExtractDataRef(std::move(*zExpr))}) { + if (std::optional dataRef{ExtractDataRef(*zExpr)}) { + // Represent %RE/%IM as a designator Expr realExpr{common::visit( [&](const auto &z) { using PartType = typename ResultType::Part; diff --git a/flang/test/Evaluate/fold-re-im.f90 b/flang/test/Evaluate/fold-re-im.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Evaluate/fold-re-im.f90 @@ -0,0 +1,15 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1 +! Tests folding of complex components +module m + complex, parameter :: z = (1., 2.) + logical, parameter :: test_1 = z%re == 1. + logical, parameter :: test_2 = z%im == 2. + logical, parameter :: test_3 = real(z+z) == 2. + logical, parameter :: test_4 = aimag(z+z) == 4. + type :: t + complex :: z + end type + type(t), parameter :: tz(*) = [t((3., 4.)), t((5., 6.))] + logical, parameter :: test_5 = all(tz%z%re == [3., 5.]) + logical, parameter :: test_6 = all(tz%z%im == [4., 6.]) +end module