Index: flang/include/flang/Evaluate/complex.h =================================================================== --- flang/include/flang/Evaluate/complex.h +++ flang/include/flang/Evaluate/complex.h @@ -77,6 +77,11 @@ ValueWithRealFlags Divide( const Complex &, Rounding rounding = defaultRounding) const; + // ABS/CABS = HYPOT(re_, im_) = SQRT(re_**2 + im_**2) + ValueWithRealFlags ABS(Rounding rounding = defaultRounding) const { + return re_.HYPOT(im_, rounding); + } + constexpr Complex FlushSubnormalToZero() const { return {re_.FlushSubnormalToZero(), im_.FlushSubnormalToZero()}; } @@ -88,7 +93,6 @@ std::string DumpHexadecimal() const; llvm::raw_ostream &AsFortran(llvm::raw_ostream &, int kind) const; - // TODO: (C)ABS once Real::HYPOT is done // TODO: unit testing private: Index: flang/include/flang/Evaluate/real.h =================================================================== --- flang/include/flang/Evaluate/real.h +++ flang/include/flang/Evaluate/real.h @@ -115,8 +115,10 @@ ValueWithRealFlags Divide( const Real &, Rounding rounding = defaultRounding) const; - // SQRT(x**2 + y**2) but computed so as to avoid spurious overflow - // TODO: not yet implemented; needed for CABS + ValueWithRealFlags SQRT(Rounding rounding = defaultRounding) const; + + // HYPOT(x,y)=SQRT(x**2 + y**2) computed so as to avoid spurious + // intermediate overflows. 
ValueWithRealFlags HYPOT( const Real &, Rounding rounding = defaultRounding) const; Index: flang/lib/Evaluate/fold-real.cpp =================================================================== --- flang/lib/Evaluate/fold-real.cpp +++ flang/lib/Evaluate/fold-real.cpp @@ -27,8 +27,8 @@ name == "bessel_y1" || name == "cos" || name == "cosh" || name == "erf" || name == "erfc" || name == "erfc_scaled" || name == "exp" || name == "gamma" || name == "log" || name == "log10" || - name == "log_gamma" || name == "sin" || name == "sinh" || - name == "sqrt" || name == "tan" || name == "tanh") { + name == "log_gamma" || name == "sin" || name == "sinh" || name == "tan" || + name == "tanh") { CHECK(args.size() == 1); if (auto callable{GetHostRuntimeWrapper(name)}) { return FoldElementalIntrinsic( @@ -40,8 +40,7 @@ } else if (name == "amax0" || name == "amin0" || name == "amin1" || name == "amax1" || name == "dmin1" || name == "dmax1") { return RewriteSpecificMINorMAX(context, std::move(funcRef)); - } else if (name == "atan" || name == "atan2" || name == "hypot" || - name == "mod") { + } else if (name == "atan" || name == "atan2" || name == "mod") { std::string localName{name == "atan" ? 
"atan2" : name}; CHECK(args.size() == 2); if (auto callable{GetHostRuntimeWrapper(localName)}) { @@ -71,13 +70,10 @@ return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::ABS); } else if (auto *z{UnwrapExpr>(args[0])}) { - if (auto callable{GetHostRuntimeWrapper("abs")}) { - return FoldElementalIntrinsic( - context, std::move(funcRef), *callable); - } else { - context.messages().Say( - "abs(complex(kind=%d)) cannot be folded on host"_en_US, KIND); - } + return FoldElementalIntrinsic(context, std::move(funcRef), + ScalarFunc([](const Scalar &z) -> Scalar { + return z.ABS().value; + })); } else { common::die(" unexpected argument type inside abs"); } @@ -108,6 +104,13 @@ return Expr{Scalar::EPSILON()}; } else if (name == "huge") { return Expr{Scalar::HUGE()}; + } else if (name == "hypot") { + CHECK(args.size() == 2); + return FoldElementalIntrinsic(context, std::move(funcRef), + ScalarFunc( + [](const Scalar &x, const Scalar &y) -> Scalar { + return x.HYPOT(y).value; + })); } else if (name == "max") { return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater); } else if (name == "maxval") { @@ -130,6 +133,10 @@ } else if (name == "sign") { return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::SIGN); + } else if (name == "sqrt") { + return FoldElementalIntrinsic(context, std::move(funcRef), + ScalarFunc( + [](const Scalar &x) -> Scalar { return x.SQRT().value; })); } else if (name == "sum") { return FoldSum(context, std::move(funcRef)); } else if (name == "tiny") { Index: flang/lib/Evaluate/intrinsics-library.cpp =================================================================== --- flang/lib/Evaluate/intrinsics-library.cpp +++ flang/lib/Evaluate/intrinsics-library.cpp @@ -222,7 +222,6 @@ FolderFactory::Create("erfc"), FolderFactory::Create("exp"), FolderFactory::Create("gamma"), - FolderFactory::Create("hypot"), FolderFactory::Create("log"), FolderFactory::Create("log10"), FolderFactory::Create("log_gamma"), @@ -230,7 +229,6 @@ 
FolderFactory::Create("pow"), FolderFactory::Create("sin"), FolderFactory::Create("sinh"), - FolderFactory::Create("sqrt"), FolderFactory::Create("tan"), FolderFactory::Create("tanh"), };
Index: flang/lib/Evaluate/real.cpp
===================================================================
--- flang/lib/Evaluate/real.cpp
+++ flang/lib/Evaluate/real.cpp
@@ -261,6 +261,124 @@
   return result;
 }
 
+// SQRT(x).  A NaN argument yields NaN (with InvalidArgument when it is
+// signaling), SQRT(-0) is -0, any other negative argument yields NaN with
+// InvalidArgument, and SQRT(+Inf) is +Inf.  Otherwise the root is built one
+// bit at a time and then checked against the next larger value; Inexact is
+// set unless the square of the returned value compares equal to the
+// argument.
+template
+ValueWithRealFlags> Real::SQRT(Rounding rounding) const {
+  ValueWithRealFlags result;
+  if (IsNotANumber()) {
+    result.value = NotANumber();
+    if (IsSignalingNaN()) {
+      result.flags.set(RealFlag::InvalidArgument);
+    }
+  } else if (IsNegative()) {
+    if (IsZero()) {
+      // SQRT(-0) == -0 in IEEE-754.
+      result.value.word_ = result.value.word_.IBSET(bits - 1);
+    } else {
+      // SQRT of a negative nonzero value is an invalid operation in
+      // IEEE-754, so report it along with the NaN result.
+      result.flags.set(RealFlag::InvalidArgument);
+      result.value = NotANumber();
+    }
+  } else if (IsInfinite()) {
+    // SQRT(+Inf) == +Inf
+    result.value = Infinity(false);
+  } else {
+    // Slow but reliable bit-at-a-time method.  Start with a clear significand
+    // and half the unbiased exponent, and then try to set significand bits
+    // in descending order of magnitude without exceeding the exact result.
+    int expo{UnbiasedExponent()};
+    if (IsSubnormal()) {
+      expo -= GetFraction().LEADZ();
+    }
+    expo = expo / 2 + exponentBias;
+    result.value.Normalize(false, expo, Fraction::MASKL(1));
+    for (int bit{significandBits - 1}; bit >= 0; --bit) {
+      Word word{result.value.word_};
+      result.value.word_ = word.IBSET(bit);
+      auto squared{result.value.Multiply(result.value, rounding)};
+      if (squared.flags.test(RealFlag::Overflow) ||
+          squared.flags.test(RealFlag::Underflow) ||
+          Compare(squared.value) == Relation::Less) {
+        result.value.word_ = word;
+      }
+    }
+    // The computed square root, when squared, has a square that's not greater
+    // than the original argument.  Check this square against the square of the
+    // next Real value, and return that one if its square is closer in magnitude
+    // to the original argument.
+    Real resultSq{result.value.Multiply(result.value).value};
+    Real diff{Subtract(resultSq).value.ABS()};
+    if (diff.IsZero()) {
+      return result; // exact
+    }
+    Real ulp;
+    ulp.Normalize(false, expo, Fraction::MASKR(1));
+    Real nextAfter{result.value.Add(ulp).value};
+    auto nextAfterSq{nextAfter.Multiply(nextAfter)};
+    if (!nextAfterSq.flags.test(RealFlag::Overflow) &&
+        !nextAfterSq.flags.test(RealFlag::Underflow)) {
+      Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()};
+      if (nextAfterDiff.Compare(diff) == Relation::Less) {
+        result.value = nextAfter;
+        if (nextAfterDiff.IsZero()) {
+          return result; // exact
+        }
+      }
+    }
+    result.flags.set(RealFlag::Inexact);
+  }
+  return result;
+}
+
+// HYPOT(x,y) = SQRT(x**2 + y**2) by definition, but those squared intermediate
+// values are susceptible to over/underflow when computed naively.
+// Assuming that |x|>=|y|, calculate instead:
+//   HYPOT(x,y) = SQRT(x**2 * (1+(y/x)**2))
+//              = ABS(x) * SQRT(1+(y/x)**2)
+// The operands are swapped when |x|<|y|.  A NaN operand yields NaN with
+// InvalidArgument, two zero operands return a default-constructed result, and
+// an infinite operand yields +Inf.  Inexact from any intermediate step is
+// carried into the result.
+template
+ValueWithRealFlags> Real::HYPOT(
+    const Real &y, Rounding rounding) const {
+  ValueWithRealFlags result;
+  if (IsNotANumber() || y.IsNotANumber()) {
+    result.flags.set(RealFlag::InvalidArgument);
+    result.value = NotANumber();
+  } else if (ABS().Compare(y.ABS()) == Relation::Less) {
+    return y.HYPOT(*this, rounding); // swap, keeping the caller's rounding
+  } else if (IsZero()) {
+    return result; // x==y==0
+  } else if (IsInfinite()) {
+    // |x|>=|y| here, so this covers every infinite operand, including
+    // HYPOT(Inf,Inf), for which the general path would evaluate Inf/Inf.
+    result.value = Infinity(false);
+  } else {
+    auto yOverX{y.Divide(*this, rounding)}; // y/x
+    bool inexact{yOverX.flags.test(RealFlag::Inexact)};
+    auto squared{yOverX.value.Multiply(yOverX.value, rounding)}; // (y/x)**2
+    inexact |= squared.flags.test(RealFlag::Inexact);
+    Real one;
+    one.Normalize(false, exponentBias, Fraction::MASKL(1)); // 1.0
+    auto sum{squared.value.Add(one, rounding)}; // 1.0 + (y/x)**2
+    inexact |= sum.flags.test(RealFlag::Inexact);
+    auto sqrt{sum.value.SQRT(rounding)};
+    inexact |= sqrt.flags.test(RealFlag::Inexact);
+    result = sqrt.value.Multiply(ABS(), rounding);
+    if (inexact) {
+      result.flags.set(RealFlag::Inexact);
+    }
+  }
+  return result;
+}
+
 template
ValueWithRealFlags> Real::ToWholeNumber( common::RoundingMode mode) const { Index: flang/test/Evaluate/folding28.f90 =================================================================== --- /dev/null +++ flang/test/Evaluate/folding28.f90 @@ -0,0 +1,40 @@ +! RUN: %S/test_folding.sh %s %t %flang_fc1 +! REQUIRES: shell +! Tests folding of SQRT() on REAL(8) edge cases: infinity, HUGE, -0, TINY, and the largest and smallest subnormals +module m + implicit none + ! +Inf + real(8), parameter :: inf8 = z'7ff0000000000000' + logical, parameter :: test_inf8 = sqrt(inf8) == inf8 + ! max finite + real(8), parameter :: h8 = huge(1.0_8), h8z = z'7fefffffffffffff' + logical, parameter :: test_h8 = h8 == h8z + real(8), parameter :: sqrt_h8 = sqrt(h8), sqrt_h8z = z'5fefffffffffffff' + logical, parameter :: test_sqrt_h8 = sqrt_h8 == sqrt_h8z + real(8), parameter :: sqr_sqrt_h8 = sqrt_h8 * sqrt_h8, sqr_sqrt_h8z = z'7feffffffffffffe' + logical, parameter :: test_sqr_sqrt_h8 = sqr_sqrt_h8 == sqr_sqrt_h8z + ! -0 (sqrt is -0) + real(8), parameter :: n08 = z'8000000000000000' + real(8), parameter :: sqrt_n08 = sqrt(n08) +!WARN: division by zero + real(8), parameter :: inf_n08 = 1.0_8 / sqrt_n08, inf_n08z = z'fff0000000000000' + logical, parameter :: test_n08 = inf_n08 == inf_n08z + ! min normal + real(8), parameter :: t8 = tiny(1.0_8), t8z = z'0010000000000000' + logical, parameter :: test_t8 = t8 == t8z + real(8), parameter :: sqrt_t8 = sqrt(t8), sqrt_t8z = z'2000000000000000' + logical, parameter :: test_sqrt_t8 = sqrt_t8 == sqrt_t8z + real(8), parameter :: sqr_sqrt_t8 = sqrt_t8 * sqrt_t8 + logical, parameter :: test_sqr_sqrt_t8 = sqr_sqrt_t8 == t8 + ! max subnormal (NOTE(review): nearest REAL(8) to the exact root looks like z'1fffffffffffffff'; confirm) + real(8), parameter :: maxs8 = z'000fffffffffffff' + real(8), parameter :: sqrt_maxs8 = sqrt(maxs8), sqrt_maxs8z = z'2000000000000000' + logical, parameter :: test_sqrt_maxs8 = sqrt_maxs8 == sqrt_maxs8z + ! 
min subnormal + real(8), parameter :: mins8 = z'1' + real(8), parameter :: sqrt_mins8 = sqrt(mins8), sqrt_mins8z = z'1e60000000000000' + logical, parameter :: test_sqrt_mins8 = sqrt_mins8 == sqrt_mins8z + real(8), parameter :: sqr_sqrt_mins8 = sqrt_mins8 * sqrt_mins8 + logical, parameter :: test_sqr_sqrt_mins8 = sqr_sqrt_mins8 == mins8 +end module +