diff --git a/flang/lib/Decimal/decimal-to-binary.cpp b/flang/lib/Decimal/decimal-to-binary.cpp --- a/flang/lib/Decimal/decimal-to-binary.cpp +++ b/flang/lib/Decimal/decimal-to-binary.cpp @@ -408,19 +408,37 @@ } else { // Could not parse a decimal floating-point number. p has been // advanced over any leading spaces. - if (toupper(p[0]) == 'N' && toupper(p[1]) == 'A' && toupper(p[2]) == 'N') { + if ((!limit || limit >= p + 3) && toupper(p[0]) == 'N' && + toupper(p[1]) == 'A' && toupper(p[2]) == 'N') { // NaN p += 3; + if ((!limit || p < limit) && *p == '(') { + int depth{1}; + do { + ++p; + if (limit && p >= limit) { + // Invalid input + return {Real{NaN()}, Invalid}; + } else if (*p == '(') { + ++depth; + } else if (*p == ')') { + --depth; + } + } while (depth > 0); + ++p; + } return {Real{NaN()}}; } else { // Try to parse Inf, maybe with a sign const char *q{p}; - isNegative_ = *q == '-'; - if (*q == '-' || *q == '+') { - ++q; + if (!limit || q < limit) { + isNegative_ = *q == '-'; + if (isNegative_ || *q == '+') { + ++q; + } } - if (toupper(q[0]) == 'I' && toupper(q[1]) == 'N' && - toupper(q[2]) == 'F') { + if ((!limit || limit >= q + 3) && toupper(q[0]) == 'I' && + toupper(q[1]) == 'N' && toupper(q[2]) == 'F') { p = q + 3; return {Real{Infinity()}}; } else { diff --git a/flang/runtime/edit-input.cpp b/flang/runtime/edit-input.cpp --- a/flang/runtime/edit-input.cpp +++ b/flang/runtime/edit-input.cpp @@ -176,9 +176,19 @@ } } if (next && *next == '(') { // NaN(...) - while (next && *next != ')') { + Put('('); + int depth{1}; + do { next = io.NextInField(remaining, edit); - } + if (!next) { + break; + } else if (*next == '(') { + ++depth; + } else if (*next == ')') { + --depth; + } + Put(*next); + } while (depth > 0); } exponent = 0; } else if (first == decimal || (first >= '0' && first <= '9') || @@ -225,7 +235,7 @@ exponent = -edit.modes.scale; if (next && (*next == '-' || *next == '+' || (*next >= '0' && *next <= '9') || - (bzMode && (*next == ' ' || *next == '\t')))) { + *next == ' ' || *next == '\t')) { bool negExpo{*next == '-'}; if (negExpo || *next == '+') { next = io.NextInField(remaining, edit); @@ -233,8 +243,10 @@ for (exponent = 0; next; next = io.NextInField(remaining, edit)) { if (*next >= '0' && *next <= '9') { exponent = 10 * exponent + *next - '0'; - } else if (bzMode && (*next == ' ' || *next == '\t')) { - exponent = 10 * exponent; + } else if (*next == ' ' || *next == '\t') { + if (bzMode) { + exponent = 10 * exponent; + } } else { break; } @@ -328,11 +340,19 @@ if (converted.flags & decimal::Invalid) { return false; } - if (edit.digits.value_or(0) != 0 && - std::memchr(str, '.', p - str) == nullptr) { - // No explicit decimal point, and edit descriptor is Fw.d (or other) - // with d != 0, which implies scaling. - return false; + if (edit.digits.value_or(0) != 0) { + // Edit descriptor is Fw.d (or other) with d != 0, which + // implies scaling + const char *q{str}; + for (; q < limit; ++q) { + if (*q == '.' || *q == 'n' || *q == 'N') { + break; + } + } + if (q == limit) { + // No explicit decimal point, and not NaN/Inf. + return false; + } } for (; p < limit && (*p == ' ' || *p == '\t'); ++p) { } @@ -422,6 +442,10 @@ converted.flags = static_cast( converted.flags | decimal::Inexact); } + if (*p) { // unprocessed junk after value + io.GetIoErrorHandler().SignalError(IostatBadRealInput); + return false; + } *reinterpret_cast *>(n) = converted.binary; // Set FP exception flags diff --git a/flang/unittests/Decimal/quick-sanity-test.cpp b/flang/unittests/Decimal/quick-sanity-test.cpp --- a/flang/unittests/Decimal/quick-sanity-test.cpp +++ b/flang/unittests/Decimal/quick-sanity-test.cpp @@ -61,13 +61,15 @@ if (!(x == x)) { if (y == y || *p != '\0' || (rflags & Invalid)) { u.x = y; - failed(x) << " (NaN) " << flags << ": -> '" << result.str << "' -> 0x"; - failed(x).write_hex(u.u) << " '" << p << "' " << rflags << '\n'; + (failed(x) << " (NaN) " << flags << ": -> '" << result.str << "' -> 0x") + .write_hex(u.u) + << " '" << p << "' " << rflags << '\n'; } } else if (x != y || *p != '\0' || (rflags & Invalid)) { - u.x = y; - failed(x) << ' ' << flags << ": -> '" << result.str << "' -> 0x"; - failed(x).write_hex(u.u) << " '" << p << "' " << rflags << '\n'; + u.x = x; + (failed(x) << ' ' << flags << ": -> '" << result.str << "' -> 0x") + .write_hex(u.u) + << " '" << p << "' " << rflags << '\n'; } } }