diff --git a/flang/include/flang/Evaluate/fold.h b/flang/include/flang/Evaluate/fold.h --- a/flang/include/flang/Evaluate/fold.h +++ b/flang/include/flang/Evaluate/fold.h @@ -57,10 +57,8 @@ if (auto *c{UnwrapExpr>(expr)}) { return c; } else { - if constexpr (!std::is_same_v) { - if (auto *parens{UnwrapExpr>(expr)}) { - return UnwrapConstantValue(parens->left()); - } + if (auto *parens{UnwrapExpr>(expr)}) { + return UnwrapConstantValue(parens->left()); } return nullptr; } diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp --- a/flang/lib/Evaluate/fold-complex.cpp +++ b/flang/lib/Evaluate/fold-complex.cpp @@ -62,6 +62,8 @@ ToReal(context, std::move(im))}}); } } + } else if (name == "dot_product") { + return FoldDotProduct(context, std::move(funcRef)); } else if (name == "merge") { return FoldMerge(context, std::move(funcRef)); } else if (name == "product") { @@ -70,7 +72,7 @@ } else if (name == "sum") { return FoldSum(context, std::move(funcRef)); } - // TODO: dot_product, matmul + // TODO: matmul (dot_product is folded above via FoldDotProduct) return Expr{std::move(funcRef)}; } diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -552,6 +552,8 @@ } else if (name == "dim") { return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::DIM); + } else if (name == "dot_product") { + return FoldDotProduct(context, std::move(funcRef)); } else if (name == "dshiftl" || name == "dshiftr") { const auto fptr{ name == "dshiftl" ? 
&Scalar::DSHIFTL : &Scalar::DSHIFTR}; diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp --- a/flang/lib/Evaluate/fold-logical.cpp +++ b/flang/lib/Evaluate/fold-logical.cpp @@ -140,6 +140,8 @@ }, ix->u); } + } else if (name == "dot_product") { + return FoldDotProduct(context, std::move(funcRef)); } else if (name == "extends_type_of") { // Type extension testing with EXTENDS_TYPE_OF() ignores any type // parameters. Returns a constant truth value when the result is known now. @@ -231,7 +233,7 @@ name == "__builtin_ieee_support_underflow_control") { return Expr{true}; } - // TODO: dot_product, is_iostat_end, + // TODO: is_iostat_end, // is_iostat_eor, logical, matmul, out_of_range, // parity return Expr{std::move(funcRef)}; diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -136,6 +136,8 @@ [](const Scalar &x, const Scalar &y) -> Scalar { return x.DIM(y).value; })); + } else if (name == "dot_product") { + return FoldDotProduct(context, std::move(funcRef)); } else if (name == "dprod") { if (auto scalars{GetScalarConstantArguments(context, args)}) { return Fold(context, diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -// TODO: DOT_PRODUCT, NORM2, PARITY +// TODO: NORM2, PARITY #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ @@ -15,10 +15,96 @@ namespace Fortran::evaluate { -// Fold and validate a DIM= argument. Returns true (with &dim empty) -// when DIM= is not present or (with &dim set) when DIM= is present, constant, -// and valid. Returns false, possibly with an error message, when -// DIM= is present but either not constant or not valid.
+// DOT_PRODUCT(VECTOR_A, VECTOR_B): folded only when both arguments are constant. +template +static Expr FoldDotProduct( + FoldingContext &context, FunctionRef &&funcRef) { + using Element = typename Constant::Element; + auto args{funcRef.arguments()}; + CHECK(args.size() == 2); + Folder folder{context}; + Constant *va{folder.Folding(args[0])}; + Constant *vb{folder.Folding(args[1])}; + if (va && vb) { + CHECK(va->Rank() == 1 && vb->Rank() == 1); + if (va->size() != vb->size()) { + context.messages().Say( + "Vector arguments to DOT_PRODUCT have distinct extents %zd and %zd"_err_en_US, + va->size(), vb->size()); + return MakeInvalidIntrinsic(std::move(funcRef)); + } + Element sum{}; + bool overflow{false}; + if constexpr (T::category == TypeCategory::Complex) { + std::vector conjugates; + for (const Element &x : va->values()) { + conjugates.emplace_back(x.CONJG()); + } + Constant conjgA{ + std::move(conjugates), ConstantSubscripts{va->shape()}}; + Expr products{Fold( + context, Expr{std::move(conjgA)} * Expr{Constant{*vb}})}; + Constant &cProducts{DEREF(UnwrapConstantValue(products))}; + Element correction{}; // Kahan compensation, subtracted from each term.
+ const auto &rounding{context.targetCharacteristics().roundingMode()}; + for (const Element &x : cProducts.values()) { + auto next{x.Subtract(correction, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } + } else if constexpr (T::category == TypeCategory::Logical) { + Expr conjunctions{Fold(context, + Expr{LogicalOperation{LogicalOperator::And, + Expr{Constant{*va}}, Expr{Constant{*vb}}}})}; + Constant &cConjunctions{DEREF(UnwrapConstantValue(conjunctions))}; + for (const Element &x : cConjunctions.values()) { + if (x.IsTrue()) { + sum = Element{true}; + break; + } + } + } else if constexpr (T::category == TypeCategory::Integer) { + Expr products{ + Fold(context, Expr{Constant{*va}} * Expr{Constant{*vb}})}; + Constant &cProducts{DEREF(UnwrapConstantValue(products))}; + for (const Element &x : cProducts.values()) { + auto next{sum.AddSigned(x)}; + overflow |= next.overflow; + sum = std::move(next.value); + } + } else { // T::category == TypeCategory::Real + Expr products{ + Fold(context, Expr{Constant{*va}} * Expr{Constant{*vb}})}; + Constant &cProducts{DEREF(UnwrapConstantValue(products))}; + Element correction{}; // Kahan compensation, subtracted from each term.
+ const auto &rounding{context.targetCharacteristics().roundingMode()}; + for (const Element &x : cProducts.values()) { + auto next{x.Subtract(correction, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } + } + if (overflow) { + context.messages().Say( + "DOT_PRODUCT of %s data overflowed during computation"_warn_en_US, + T::AsFortran()); + } + return Expr{Constant{std::move(sum)}}; + } + return Expr{std::move(funcRef)}; +} + +// Fold and validate a DIM= argument. Returns false on error. bool CheckReductionDIM(std::optional &dim, FoldingContext &, ActualArguments &, std::optional dimIndex, int rank); @@ -203,13 +289,15 @@ overflow |= sum.overflow; element = sum.value; } else { // Real & Complex: use Kahan summation - auto next{array->At(at).Add(correction)}; + const auto &rounding{context.targetCharacteristics().roundingMode()}; + auto next{array->At(at).Subtract(correction, rounding)}; overflow |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value)}; + auto sum{element.Add(next.value, rounding)}; overflow |= sum.flags.test(RealFlag::Overflow); // correction = (sum - element) - next; algebraically zero - correction = - sum.value.Subtract(element).value.Subtract(next.value).value; + correction = sum.value.Subtract(element, rounding) + .value.Subtract(next.value, rounding) + .value; element = sum.value; } }}; diff --git a/flang/test/Evaluate/fold-dot.f90 b/flang/test/Evaluate/fold-dot.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Evaluate/fold-dot.f90 @@ -0,0 +1,11 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of DOT_PRODUCT() +module m + logical, parameter :: test_i4a = dot_product([(j,j=1,10)],[(j,j=1,10)]) == sum([(j*j,j=1,10)]) + logical, parameter :: test_r4a = dot_product([(1.*j,j=1,10)],[(j,j=1,10)]) == sum([(j*j,j=1,10)]) + logical, parameter :: test_r4b = dot_product([1., (3.e-8,j=1,10)], [(1.,j=1,11)]) > 1. + logical, parameter :: test_z4a = dot_product([((j,j),j=1,10)],[((j,j),j=1,10)]) == sum([(((j,-j)*(j,j)),j=1,10)]) + logical, parameter :: test_l4a = .not. dot_product([logical::],[logical::]) + logical, parameter :: test_l4b = .not. dot_product([(j==2,j=1,10)], [(j==3,j=1,10)]) + logical, parameter :: test_l4c = dot_product([(j==4,j=1,10)], [(j==4,j=1,10)]) +end