Index: llvm/lib/Support/APFloat.cpp
===================================================================
--- llvm/lib/Support/APFloat.cpp
+++ llvm/lib/Support/APFloat.cpp
@@ -2617,27 +2617,99 @@
 }
 
 bool IEEEFloat::convertFromStringSpecials(StringRef str) {
-  if (str.equals("inf") || str.equals("INFINITY") || str.equals("+Inf")) {
-    makeInf(false);
-    return true;
-  }
+  // Every recognized name ("inf", "nan", ...) is at least 3 characters long.
+  // Any leading sign has already been consumed by convertFromString(), which
+  // stores it in `sign`; isNegative() below reads it back.
+  if (str.size() < 3)
+    return false;
 
-  if (str.equals("-inf") || str.equals("-INFINITY") || str.equals("-Inf")) {
-    makeInf(true);
-    return true;
-  }
+  bool SNaN = false;
+  char C = str[0];
+  str = str.drop_front();
+  switch (C) {
+  case 'I':
+    // Check for one of: INF, INFINITY.
+    if (str.startswith("NF")) {
+      str = str.drop_front(2);
+      if (!str.empty() && !str.equals("INITY"))
+        return false;
+
+      makeInf(isNegative());
+      return true;
+    }
+    LLVM_FALLTHROUGH;
+  case 'i':
+    // Check for one of: inf, Inf, infinity, Infinity.
+    if (!str.startswith("nf"))
+      return false;
 
-  if (str.equals("nan") || str.equals("NaN")) {
-    makeNaN(false, false);
-    return true;
-  }
+    str = str.drop_front(2);
+    if (!str.empty() && !str.equals("inity"))
+      return false;
 
-  if (str.equals("-nan") || str.equals("-NaN")) {
-    makeNaN(false, true);
+    makeInf(isNegative());
     return true;
-  }
 
-  return false;
+  case 'S':
+  case 's':
+    SNaN = true;
+    LLVM_FALLTHROUGH;
+  case 'Q':
+  case 'q':
+    if (str.size() < 3)
+      return false;
+
+    // An uppercase type prefix ('Q', 'S') must be followed by 'N'; a lowercase
+    // one ('q', 's') may be followed by either 'N' or 'n'.
+    if (str[0] != 'N' && !((C == 'q' || C == 's') && str[0] == 'n'))
+      return false;
+
+    C = str[0];
+    str = str.drop_front();
+    LLVM_FALLTHROUGH;
+  case 'N':
+  case 'n':
+    // Check for one of: nan, NaN, NAN. C holds the leading 'n'/'N', and the
+    // name must end with that same letter.
+    if ((str[0] == 'a' || (C == 'N' && str[0] == 'A')) && str[1] == C) {
+      str = str.drop_front(2);
+      if (str.empty()) {
+        makeNaN(SNaN, isNegative());
+        return true;
+      }
+
+      // Allow the payload to be inside parentheses.
+      if (str[0] == '(') {
+        if (!(str.size() > 2 && str.back() == ')'))
+          return false;
+
+        str = str.substr(1, str.size() - 2);
+      }
+
+      // Determine the payload number's radix: "0x"/"0X" followed by digits is
+      // hexadecimal, a leading '0' is octal, and anything else is decimal.
+      unsigned Radix;
+      if (str[0] == '0') {
+        if (str.size() > 2 && (str[1] == 'x' || str[1] == 'X')) {
+          str = str.drop_front(2);
+          Radix = 16;
+        } else
+          Radix = 8;
+      } else
+        Radix = 10;
+
+      // Parse the payload and make the NaN.
+      APInt Payload;
+      if (!str.getAsInteger(Radix, Payload)) {
+        makeNaN(SNaN, isNegative(), &Payload);
+        return true;
+      }
+    }
+    LLVM_FALLTHROUGH;
+
+  default:
+    return false;
+  }
 }
 
 Expected<APFloat::opStatus>
@@ -2645,29 +2717,28 @@
   if (str.empty())
     return createError("Invalid string length");
 
+  // Handle a leading plus/minus sign.
+  bool Negative = str[0] == '-';
+  if (Negative || str[0] == '+') {
+    str = str.drop_front();
+    if (str.empty())
+      return createError("String has no digits");
+  }
+
+  sign = Negative;
+
   // Handle special cases.
   if (convertFromStringSpecials(str))
     return opOK;
 
-  /* Handle a leading minus sign. */
-  StringRef::iterator p = str.begin();
-  size_t slen = str.size();
-  sign = *p == '-' ? 1 : 0;
-  if (*p == '-' || *p == '+') {
-    p++;
-    slen--;
-    if (!slen)
-      return createError("String has no digits");
-  }
-
-  if (slen >= 2 && p[0] == '0' && (p[1] == 'x' || p[1] == 'X')) {
-    if (slen == 2)
+  if (str.size() >= 2 && str[0] == '0' && (str[1] == 'x' || str[1] == 'X')) {
+    str = str.drop_front(2);
+    if (str.empty())
       return createError("Invalid string");
-    return convertFromHexadecimalString(StringRef(p + 2, slen - 2),
-                                        rounding_mode);
+    return convertFromHexadecimalString(str, rounding_mode);
   }
 
-  return convertFromDecimalString(StringRef(p, slen), rounding_mode);
+  return convertFromDecimalString(str, rounding_mode);
 }
 
 /* Write out a hexadecimal representation of the floating point value
Index: llvm/unittests/ADT/APFloatTest.cpp
===================================================================
--- llvm/unittests/ADT/APFloatTest.cpp
+++ llvm/unittests/ADT/APFloatTest.cpp
@@ -913,6 +913,153 @@
   EXPECT_EQ(2.71828, convertToDoubleFromString("2.71828"));
 }
 
+TEST(APFloatTest, fromStringSpecials) {
+  const StringRef NaNStrings[] = {"nan", "NaN", "NAN", "naN",
+                                  "nAn", "nAN", "Nan", "NAn"};
+  const size_t NumValidNaNStrings = 3; // only the first 3 are valid
+
+  const fltSemantics &Sem = APFloat::IEEEdouble();
+  const unsigned Precision = 53;
+  const unsigned PayloadBits = Precision - 2;
+  uint64_t PayloadMask = (uint64_t(1) << PayloadBits) - uint64_t(1);
+
+  // convertFromString() returns an Expected; when it holds an error, the error
+  // must be consumed before the Expected is destroyed, or builds with
+  // LLVM_ENABLE_ABI_BREAKING_CHECKS abort. Report only whether it failed.
+  auto ConvertFails = [](APFloat &F, StringRef Str) {
+    return errorToBool(
+        F.convertFromString(Str, APFloat::rmNearestTiesToEven).takeError());
+  };
+
+  uint64_t NaNPayloads[] = {
+      0,
+      1,
+      123,
+      0xDEADBEEF,
+      uint64_t(-2),
+      uint64_t(1) << PayloadBits,       // overflow bit
+      uint64_t(1) << (PayloadBits - 1), // signaling bit
+      uint64_t(1) << (PayloadBits - 2)  // highest possible bit
+  };
+
+  std::string NaNPayloadDecStrings[array_lengthof(NaNPayloads)];
+  for (size_t I = 0; I < array_lengthof(NaNPayloads); ++I)
+    NaNPayloadDecStrings[I] = utostr(NaNPayloads[I]);
+
+  std::string NaNPayloadHexStrings[array_lengthof(NaNPayloads)];
+  for (size_t I = 0; I < array_lengthof(NaNPayloads); ++I)
+    NaNPayloadHexStrings[I] = "0x" + utohexstr(NaNPayloads[I]);
+
+  // Fix payloads to expected result.
+  for (uint64_t &Payload : NaNPayloads)
+    Payload &= PayloadMask;
+
+  // Signaling NaN must have a non-zero payload. In case a zero payload is
+  // requested, a default arbitrary payload is set instead. Save this payload
+  // for testing.
+  const uint64_t SNaNDefaultPayload =
+      APFloat::getSNaN(Sem).bitcastToAPInt().getZExtValue() & PayloadMask;
+
+  const char Signs[] = {0, '+', '-'};
+  const char NaNTypes[] = {0, 'q', 'Q', 's', 'S'};
+
+  for (size_t I = 0; I < array_lengthof(NaNStrings); ++I) {
+    StringRef NaNStr = NaNStrings[I];
+
+    for (char TypeChar : NaNTypes) {
+      bool TestSuccess = I < NumValidNaNStrings &&
+                         (!isupper(TypeChar) || NaNStr.front() == 'N');
+
+      bool Signaling = (TypeChar == 's' || TypeChar == 'S');
+
+      for (size_t J = 0; J < array_lengthof(NaNPayloads); ++J) {
+        uint64_t Payload = (Signaling && !NaNPayloads[J]) ? SNaNDefaultPayload
+                                                          : NaNPayloads[J];
+        const std::string &PayloadDec = NaNPayloadDecStrings[J];
+        const std::string &PayloadHex = NaNPayloadHexStrings[J];
+
+        for (char SignChar : Signs) {
+          bool Negative = (SignChar == '-');
+
+          std::string TestStrings[5];
+          size_t NumTestStrings = 0;
+
+          std::string Prefix;
+          if (SignChar)
+            Prefix += SignChar;
+          if (TypeChar)
+            Prefix += TypeChar;
+          Prefix += NaNStr;
+
+          // Test without any payload.
+          if (!Payload)
+            TestStrings[NumTestStrings++] = Prefix;
+
+          // Test with the payload as a suffix.
+          TestStrings[NumTestStrings++] = Prefix + PayloadDec;
+          TestStrings[NumTestStrings++] = Prefix + PayloadHex;
+
+          // Test with the payload inside parentheses.
+          TestStrings[NumTestStrings++] = Prefix + '(' + PayloadDec + ')';
+          TestStrings[NumTestStrings++] = Prefix + '(' + PayloadHex + ')';
+
+          for (size_t K = 0; K < NumTestStrings; ++K) {
+            StringRef TestStr = TestStrings[K];
+
+            APFloat F(Sem);
+            bool HasError = ConvertFails(F, TestStr);
+            if (TestSuccess) {
+              EXPECT_FALSE(HasError);
+              EXPECT_TRUE(F.isNaN());
+              EXPECT_EQ(Signaling, F.isSignaling());
+              EXPECT_EQ(Negative, F.isNegative());
+              uint64_t PayloadResult =
+                  F.bitcastToAPInt().getZExtValue() & PayloadMask;
+              EXPECT_EQ(Payload, PayloadResult);
+            } else
+              EXPECT_TRUE(HasError);
+          }
+        }
+      }
+    }
+  }
+
+  const StringRef InfStrings[] = {"inf", "Inf", "INF", "infinity",
+                                  "Infinity", "INFINITY", "inF", "iNf",
+                                  "iNF", "InF", "INf", "INFinity"};
+  const size_t NumValidInfStrings = 6; // only the first 6 are valid
+
+  for (size_t I = 0; I < array_lengthof(InfStrings); ++I) {
+    StringRef InfStr = InfStrings[I];
+    bool TestSuccess = I < NumValidInfStrings;
+
+    for (char SignChar : Signs) {
+      bool Negative = (SignChar == '-');
+
+      StringRef TestStr;
+      std::string CombinedStr;
+      if (SignChar) {
+        CombinedStr += SignChar;
+        CombinedStr += InfStr;
+        TestStr = CombinedStr;
+      } else
+        TestStr = InfStr;
+
+      APFloat F(Sem);
+      bool HasError = ConvertFails(F, TestStr);
+      if (TestSuccess) {
+        EXPECT_FALSE(HasError);
+        EXPECT_TRUE(F.isInfinity());
+        EXPECT_EQ(Negative, F.isNegative());
+        uint64_t PayloadResult =
+            F.bitcastToAPInt().getZExtValue() & PayloadMask;
+        EXPECT_EQ(uint64_t(0), PayloadResult);
+      } else
+        EXPECT_TRUE(HasError);
+    }
+  }
+}
+
 TEST(APFloatTest, fromToStringSpecials) {
   auto expects = [] (const char *first, const char *second) {
     std::string roundtrip = convertToString(convertToDoubleFromString(second), 0, 3);