diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp --- a/flang/lib/Evaluate/real.cpp +++ b/flang/lib/Evaluate/real.cpp @@ -274,6 +274,7 @@ // SQRT(-0) == -0 in IEEE-754. result.value = NegativeZero(); } else { + result.flags.set(RealFlag::InvalidArgument); result.value = NotANumber(); } } else if (IsInfinite()) { @@ -297,53 +298,31 @@ result.value.GetFraction()); return result; } - // Compute the square root of the reduced value with the slow but - // reliable bit-at-a-time method. Start with a clear significand and - // half of the unbiased exponent, and then try to set significand bits - // in descending order of magnitude without exceeding the exact result. - expo = expo / 2 + exponentBias; - result.value.Normalize(false, expo, Fraction::MASKL(1)); - Real initialSq{result.value.Multiply(result.value).value}; - if (Compare(initialSq) == Relation::Less) { - // Initial estimate is too large; this can happen for values just - // under 1.0. - --expo; - result.value.Normalize(false, expo, Fraction::MASKL(1)); - } - for (int bit{significandBits - 1}; bit >= 0; --bit) { - Word word{result.value.word_}; - result.value.word_ = word.IBSET(bit); - auto squared{result.value.Multiply(result.value, rounding)}; - if (squared.flags.test(RealFlag::Overflow) || - squared.flags.test(RealFlag::Underflow) || - Compare(squared.value) == Relation::Less) { - result.value.word_ = word; - } - } - // The computed square root has a square that's not greater than the - // original argument. Check this square against the square of the next - // larger Real and return that one if its square is closer in magnitude to - // the original argument. - Real resultSq{result.value.Multiply(result.value).value}; - Real diff{Subtract(resultSq).value.ABS()}; - if (diff.IsZero()) { - return result; // exact - } - Real ulp; - ulp.Normalize(false, expo, Fraction::MASKR(1)); - Real nextAfter{result.value.Add(ulp).value}; - auto nextAfterSq{nextAfter.Multiply(nextAfter)}; - if (!nextAfterSq.flags.test(RealFlag::Overflow) && - !nextAfterSq.flags.test(RealFlag::Underflow)) { - Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()}; - if (nextAfterDiff.Compare(diff) == Relation::Less) { - result.value = nextAfter; - if (nextAfterDiff.IsZero()) { - return result; // exact - } + // (-1) <= expo <= 1; use it as a shift to set the desired square. + using Extended = typename value::Integer<(binaryPrecision + 2)>; + Extended goal{ + Extended::ConvertUnsigned(GetFraction()).value.SHIFTL(expo + 1)}; + // Calculate the exact square root by maximizing a value whose square + // does not exceed the goal. Use two extra bits of precision for + // rounding. + bool sticky{true}; + Extended extFrac{}; + for (int bit{Extended::bits - 1}; bit >= 0; --bit) { + Extended next{extFrac.IBSET(bit)}; + auto squared{next.MultiplyUnsigned(next)}; + auto cmp{squared.upper.CompareUnsigned(goal)}; + if (cmp == Ordering::Less) { + extFrac = next; + } else if (cmp == Ordering::Equal && squared.lower.IsZero()) { + extFrac = next; + sticky = false; + break; // exact result } } - result.flags.set(RealFlag::Inexact); + RoundingBits roundingBits{extFrac.BTEST(1), extFrac.BTEST(0), sticky}; + NormalizeAndRound(result, false, exponentBias, + Fraction::ConvertUnsigned(extFrac.SHIFTR(2)).value, rounding, + roundingBits); } return result; } diff --git a/flang/test/Evaluate/folding28.f90 b/flang/test/Evaluate/folding28.f90 --- a/flang/test/Evaluate/folding28.f90 +++ b/flang/test/Evaluate/folding28.f90 @@ -49,4 +49,25 @@ logical, parameter :: test_sqrt_zero_4 = sqrt_zero_4 == 0.0 real(8), parameter :: sqrt_zero_8 = sqrt(0.0) logical, parameter :: test_sqrt_zero_8 = sqrt_zero_8 == 0.0 + ! Some common values to get right + real(8), parameter :: sqrt_1_8 = sqrt(1.d0) + logical, parameter :: test_sqrt_1_8 = sqrt_1_8 == 1.d0 + real(8), parameter :: sqrt_2_8 = sqrt(2.d0) + logical, parameter :: test_sqrt_2_8 = sqrt_2_8 == 1.4142135623730951454746218587388284504413604736328125d0 + real(8), parameter :: sqrt_3_8 = sqrt(3.d0) + logical, parameter :: test_sqrt_3_8 = sqrt_3_8 == 1.732050807568877193176604123436845839023590087890625d0 + real(8), parameter :: sqrt_4_8 = sqrt(4.d0) + logical, parameter :: test_sqrt_4_8 = sqrt_4_8 == 2.d0 + real(8), parameter :: sqrt_5_8 = sqrt(5.d0) + logical, parameter :: test_sqrt_5_8 = sqrt_5_8 == 2.236067977499789805051477742381393909454345703125d0 + real(8), parameter :: sqrt_6_8 = sqrt(6.d0) + logical, parameter :: test_sqrt_6_8 = sqrt_6_8 == 2.44948974278317788133563226438127458095550537109375d0 + real(8), parameter :: sqrt_7_8 = sqrt(7.d0) + logical, parameter :: test_sqrt_7_8 = sqrt_7_8 == 2.64575131106459071617109657381661236286163330078125d0 + real(8), parameter :: sqrt_8_8 = sqrt(8.d0) + logical, parameter :: test_sqrt_8_8 = sqrt_8_8 == 2.828427124746190290949243717477656900882720947265625d0 + real(8), parameter :: sqrt_9_8 = sqrt(9.d0) + logical, parameter :: test_sqrt_9_8 = sqrt_9_8 == 3.d0 + real(8), parameter :: sqrt_10_8 = sqrt(10.d0) + logical, parameter :: test_sqrt_10_8 = sqrt_10_8 == 3.162277660168379522787063251598738133907318115234375d0 end module diff --git a/flang/unittests/Evaluate/real.cpp b/flang/unittests/Evaluate/real.cpp --- a/flang/unittests/Evaluate/real.cpp +++ b/flang/unittests/Evaluate/real.cpp @@ -392,6 +392,22 @@ ("%d AINT(0x%jx)", pass, static_cast(rj)); } + { + ValueWithRealFlags root{x.SQRT(rounding)}; +#ifndef __clang__ // broken and also slow + fpenv.ClearFlags(); +#endif + FLT fcheck{std::sqrt(fj)}; + auto actualFlags{FlagsToBits(fpenv.CurrentFlags())}; + u.f = fcheck; + UINT rcheck{NormalizeNaN(u.ui)}; + UINT check = root.value.RawBits().ToUInt64(); + MATCH(rcheck, check) + ("%d SQRT(0x%jx)", pass, static_cast(rj)); + MATCH(actualFlags, FlagsToBits(root.flags)) + ("%d SQRT(0x%jx)", pass, static_cast(rj)); + } + { MATCH(IsNaN(rj), x.IsNotANumber()) ("%d IsNaN(0x%jx)", pass, static_cast(rj));