diff --git a/flang/include/flang/Common/long-double.h b/flang/include/flang/Common/long-double.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Common/long-double.h @@ -0,0 +1,23 @@ +/*===-- include/flang/Common/long-double.h --------------------------*- C -*-=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * ===-----------------------------------------------------------------------=== + */ + +/* Usable from both C and C++; defines (or undefines) LONG_DOUBLE. */ + +#ifndef FORTRAN_COMMON_LONG_DOUBLE_H +#define FORTRAN_COMMON_LONG_DOUBLE_H + +#ifdef _MSC_VER /* no long double */ +#undef LONG_DOUBLE +#elif __x86_64__ /* x87 extended precision */ +#define LONG_DOUBLE 80 +#else +#define LONG_DOUBLE 128 /* every other target */ +#endif + +#endif /* FORTRAN_COMMON_LONG_DOUBLE_H */ diff --git a/flang/include/flang/Common/uint128.h b/flang/include/flang/Common/uint128.h --- a/flang/include/flang/Common/uint128.h +++ b/flang/include/flang/Common/uint128.h @@ -25,31 +25,31 @@ namespace Fortran::common { -class UnsignedInt128 { +template class Int128 { public: - constexpr UnsignedInt128() {} + constexpr Int128() {} // This means of definition provides some portability for // "size_t" operands. 
- constexpr UnsignedInt128(unsigned n) : low_{n} {} - constexpr UnsignedInt128(unsigned long n) : low_{n} {} - constexpr UnsignedInt128(unsigned long long n) : low_{n} {} - constexpr UnsignedInt128(int n) + constexpr Int128(unsigned n) : low_{n} {} + constexpr Int128(unsigned long n) : low_{n} {} + constexpr Int128(unsigned long long n) : low_{n} {} + constexpr Int128(int n) : low_{static_cast(n)}, high_{-static_cast( n < 0)} {} - constexpr UnsignedInt128(long n) + constexpr Int128(long n) : low_{static_cast(n)}, high_{-static_cast( n < 0)} {} - constexpr UnsignedInt128(long long n) + constexpr Int128(long long n) : low_{static_cast(n)}, high_{-static_cast( n < 0)} {} - constexpr UnsignedInt128(const UnsignedInt128 &) = default; - constexpr UnsignedInt128(UnsignedInt128 &&) = default; - constexpr UnsignedInt128 &operator=(const UnsignedInt128 &) = default; - constexpr UnsignedInt128 &operator=(UnsignedInt128 &&) = default; + constexpr Int128(const Int128 &) = default; + constexpr Int128(Int128 &&) = default; + constexpr Int128 &operator=(const Int128 &) = default; + constexpr Int128 &operator=(Int128 &&) = default; - constexpr UnsignedInt128 operator+() const { return *this; } - constexpr UnsignedInt128 operator~() const { return {~high_, ~low_}; } - constexpr UnsignedInt128 operator-() const { return ~*this + 1; } + constexpr Int128 operator+() const { return *this; } + constexpr Int128 operator~() const { return {~high_, ~low_}; } + constexpr Int128 operator-() const { return ~*this + 1; } constexpr bool operator!() const { return !low_ && !high_; } constexpr explicit operator bool() const { return low_ || high_; } constexpr explicit operator std::uint64_t() const { return low_; } @@ -59,36 +59,36 @@ constexpr std::uint64_t high() const { return high_; } constexpr std::uint64_t low() const { return low_; } - constexpr UnsignedInt128 operator++(/*prefix*/) { + constexpr Int128 operator++(/*prefix*/) { *this += 1; return *this; } - constexpr UnsignedInt128 
operator++(int /*postfix*/) { - UnsignedInt128 result{*this}; + constexpr Int128 operator++(int /*postfix*/) { + Int128 result{*this}; *this += 1; return result; } - constexpr UnsignedInt128 operator--(/*prefix*/) { + constexpr Int128 operator--(/*prefix*/) { *this -= 1; return *this; } - constexpr UnsignedInt128 operator--(int /*postfix*/) { - UnsignedInt128 result{*this}; + constexpr Int128 operator--(int /*postfix*/) { + Int128 result{*this}; *this -= 1; return result; } - constexpr UnsignedInt128 operator&(UnsignedInt128 that) const { + constexpr Int128 operator&(Int128 that) const { return {high_ & that.high_, low_ & that.low_}; } - constexpr UnsignedInt128 operator|(UnsignedInt128 that) const { + constexpr Int128 operator|(Int128 that) const { return {high_ | that.high_, low_ | that.low_}; } - constexpr UnsignedInt128 operator^(UnsignedInt128 that) const { + constexpr Int128 operator^(Int128 that) const { return {high_ ^ that.high_, low_ ^ that.low_}; } - constexpr UnsignedInt128 operator<<(UnsignedInt128 that) const { + constexpr Int128 operator<<(Int128 that) const { if (that >= 128) { return {}; } else if (that == 0) { @@ -102,7 +102,7 @@ } } } - constexpr UnsignedInt128 operator>>(UnsignedInt128 that) const { + constexpr Int128 operator>>(Int128 that) const { if (that >= 128) { return {}; } else if (that == 0) { @@ -117,43 +117,41 @@ } } - constexpr UnsignedInt128 operator+(UnsignedInt128 that) const { + constexpr Int128 operator+(Int128 that) const { std::uint64_t lower{(low_ & ~topBit) + (that.low_ & ~topBit)}; bool carry{((lower >> 63) + (low_ >> 63) + (that.low_ >> 63)) > 1}; return {high_ + that.high_ + carry, low_ + that.low_}; } - constexpr UnsignedInt128 operator-(UnsignedInt128 that) const { - return *this + -that; - } + constexpr Int128 operator-(Int128 that) const { return *this + -that; } - constexpr UnsignedInt128 operator*(UnsignedInt128 that) const { + constexpr Int128 operator*(Int128 that) const { std::uint64_t mask32{0xffffffff}; if 
(high_ == 0 && that.high_ == 0) { std::uint64_t x0{low_ & mask32}, x1{low_ >> 32}; std::uint64_t y0{that.low_ & mask32}, y1{that.low_ >> 32}; - UnsignedInt128 x0y0{x0 * y0}, x0y1{x0 * y1}; - UnsignedInt128 x1y0{x1 * y0}, x1y1{x1 * y1}; + Int128 x0y0{x0 * y0}, x0y1{x0 * y1}; + Int128 x1y0{x1 * y0}, x1y1{x1 * y1}; return x0y0 + ((x0y1 + x1y0) << 32) + (x1y1 << 64); } else { std::uint64_t x0{low_ & mask32}, x1{low_ >> 32}, x2{high_ & mask32}, x3{high_ >> 32}; std::uint64_t y0{that.low_ & mask32}, y1{that.low_ >> 32}, y2{that.high_ & mask32}, y3{that.high_ >> 32}; - UnsignedInt128 x0y0{x0 * y0}, x0y1{x0 * y1}, x0y2{x0 * y2}, x0y3{x0 * y3}; - UnsignedInt128 x1y0{x1 * y0}, x1y1{x1 * y1}, x1y2{x1 * y2}; - UnsignedInt128 x2y0{x2 * y0}, x2y1{x2 * y1}; - UnsignedInt128 x3y0{x3 * y0}; + Int128 x0y0{x0 * y0}, x0y1{x0 * y1}, x0y2{x0 * y2}, x0y3{x0 * y3}; + Int128 x1y0{x1 * y0}, x1y1{x1 * y1}, x1y2{x1 * y2}; + Int128 x2y0{x2 * y0}, x2y1{x2 * y1}; + Int128 x3y0{x3 * y0}; return x0y0 + ((x0y1 + x1y0) << 32) + ((x0y2 + x1y1 + x2y0) << 64) + ((x0y3 + x1y2 + x2y1 + x3y0) << 96); } } - constexpr UnsignedInt128 operator/(UnsignedInt128 that) const { + constexpr Int128 operator/(Int128 that) const { int j{LeadingZeroes()}; - UnsignedInt128 bits{*this}; + Int128 bits{*this}; bits <<= j; - UnsignedInt128 numerator{}; - UnsignedInt128 quotient{}; + Int128 numerator{}; + Int128 quotient{}; for (; j < 128; ++j) { numerator <<= 1; if (bits.high_ & topBit) { @@ -169,11 +167,11 @@ return quotient; } - constexpr UnsignedInt128 operator%(UnsignedInt128 that) const { + constexpr Int128 operator%(Int128 that) const { int j{LeadingZeroes()}; - UnsignedInt128 bits{*this}; + Int128 bits{*this}; bits <<= j; - UnsignedInt128 remainder{}; + Int128 remainder{}; for (; j < 128; ++j) { remainder <<= 1; if (bits.high_ & topBit) { @@ -187,65 +185,63 @@ return remainder; } - constexpr bool operator<(UnsignedInt128 that) const { + constexpr bool operator<(Int128 that) const { + if (IS_SIGNED && (high_ ^ 
that.high_) & topBit) { + return (high_ & topBit) != 0; + } return high_ < that.high_ || (high_ == that.high_ && low_ < that.low_); } - constexpr bool operator<=(UnsignedInt128 that) const { - return !(*this > that); - } - constexpr bool operator==(UnsignedInt128 that) const { + constexpr bool operator<=(Int128 that) const { return !(*this > that); } + constexpr bool operator==(Int128 that) const { return low_ == that.low_ && high_ == that.high_; } - constexpr bool operator!=(UnsignedInt128 that) const { - return !(*this == that); - } - constexpr bool operator>=(UnsignedInt128 that) const { return that <= *this; } - constexpr bool operator>(UnsignedInt128 that) const { return that < *this; } + constexpr bool operator!=(Int128 that) const { return !(*this == that); } + constexpr bool operator>=(Int128 that) const { return that <= *this; } + constexpr bool operator>(Int128 that) const { return that < *this; } - constexpr UnsignedInt128 &operator&=(const UnsignedInt128 &that) { + constexpr Int128 &operator&=(const Int128 &that) { *this = *this & that; return *this; } - constexpr UnsignedInt128 &operator|=(const UnsignedInt128 &that) { + constexpr Int128 &operator|=(const Int128 &that) { *this = *this | that; return *this; } - constexpr UnsignedInt128 &operator^=(const UnsignedInt128 &that) { + constexpr Int128 &operator^=(const Int128 &that) { *this = *this ^ that; return *this; } - constexpr UnsignedInt128 &operator<<=(const UnsignedInt128 &that) { + constexpr Int128 &operator<<=(const Int128 &that) { *this = *this << that; return *this; } - constexpr UnsignedInt128 &operator>>=(const UnsignedInt128 &that) { + constexpr Int128 &operator>>=(const Int128 &that) { *this = *this >> that; return *this; } - constexpr UnsignedInt128 &operator+=(const UnsignedInt128 &that) { + constexpr Int128 &operator+=(const Int128 &that) { *this = *this + that; return *this; } - constexpr UnsignedInt128 &operator-=(const UnsignedInt128 &that) { + constexpr Int128 &operator-=(const Int128 
&that) { *this = *this - that; return *this; } - constexpr UnsignedInt128 &operator*=(const UnsignedInt128 &that) { + constexpr Int128 &operator*=(const Int128 &that) { *this = *this * that; return *this; } - constexpr UnsignedInt128 &operator/=(const UnsignedInt128 &that) { + constexpr Int128 &operator/=(const Int128 &that) { *this = *this / that; return *this; } - constexpr UnsignedInt128 &operator%=(const UnsignedInt128 &that) { + constexpr Int128 &operator%=(const Int128 &that) { *this = *this % that; return *this; } private: - constexpr UnsignedInt128(std::uint64_t hi, std::uint64_t lo) - : low_{lo}, high_{hi} {} + constexpr Int128(std::uint64_t hi, std::uint64_t lo) : low_{lo}, high_{hi} {} constexpr int LeadingZeroes() const { if (high_ == 0) { return 64 + LeadingZeroBitCount(low_); @@ -257,12 +253,16 @@ std::uint64_t low_{0}, high_{0}; }; -#if AVOID_NATIVE_UINT128_T -using uint128_t = UnsignedInt128; -#elif (defined __GNUC__ || defined __clang__) && defined __SIZEOF_INT128__ +using UnsignedInt128 = Int128; +using SignedInt128 = Int128; +// Use native 128-bit integers unless AVOID_NATIVE_UINT128_T is set. +#if !AVOID_NATIVE_UINT128_T && (defined __GNUC__ || defined __clang__) && \ + defined __SIZEOF_INT128__ using uint128_t = __uint128_t; +using int128_t = __int128_t; #else using uint128_t = UnsignedInt128; +using int128_t = SignedInt128; #endif template struct HostUnsignedIntTypeHelper { @@ -271,8 +271,16 @@ std::conditional_t<(BITS <= 32), std::uint32_t, std::conditional_t<(BITS <= 64), std::uint64_t, uint128_t>>>>; }; +template struct HostSignedIntTypeHelper { + using type = std::conditional_t<(BITS <= 8), std::int8_t, + std::conditional_t<(BITS <= 16), std::int16_t, + std::conditional_t<(BITS <= 32), std::int32_t, + std::conditional_t<(BITS <= 64), std::int64_t, int128_t>>>>; +}; template using HostUnsignedIntType = typename HostUnsignedIntTypeHelper::type; +template +using HostSignedIntType = typename HostSignedIntTypeHelper::type; } // namespace Fortran::common #endif diff --git a/flang/include/flang/Decimal/decimal.h 
b/flang/include/flang/Decimal/decimal.h --- a/flang/include/flang/Decimal/decimal.h +++ b/flang/include/flang/Decimal/decimal.h @@ -129,20 +129,16 @@ struct NS(ConversionToDecimalResult) ConvertDoubleToDecimal(char *, size_t, enum NS(DecimalConversionFlags), int digits, enum NS(FortranRounding), double); -#if __x86_64__ && !defined(_MSC_VER) struct NS(ConversionToDecimalResult) ConvertLongDoubleToDecimal(char *, size_t, enum NS(DecimalConversionFlags), int digits, enum NS(FortranRounding), long double); -#endif enum NS(ConversionResultFlags) ConvertDecimalToFloat(const char **, float *, enum NS(FortranRounding)); enum NS(ConversionResultFlags) ConvertDecimalToDouble(const char **, double *, enum NS(FortranRounding)); -#if __x86_64__ && !defined(_MSC_VER) enum NS(ConversionResultFlags) ConvertDecimalToLongDouble( const char **, long double *, enum NS(FortranRounding)); -#endif #undef NS #ifdef __cplusplus } // extern "C" diff --git a/flang/lib/Decimal/binary-to-decimal.cpp b/flang/lib/Decimal/binary-to-decimal.cpp --- a/flang/lib/Decimal/binary-to-decimal.cpp +++ b/flang/lib/Decimal/binary-to-decimal.cpp @@ -350,14 +350,12 @@ rounding, Fortran::decimal::BinaryFloatingPointNumber<53>(x)); } -#if __x86_64__ && !defined(_MSC_VER) ConversionToDecimalResult ConvertLongDoubleToDecimal(char *buffer, std::size_t size, enum DecimalConversionFlags flags, int digits, enum FortranRounding rounding, long double x) { return Fortran::decimal::ConvertToDecimal(buffer, size, flags, digits, rounding, Fortran::decimal::BinaryFloatingPointNumber<64>(x)); } -#endif } template diff --git a/flang/lib/Decimal/decimal-to-binary.cpp b/flang/lib/Decimal/decimal-to-binary.cpp --- a/flang/lib/Decimal/decimal-to-binary.cpp +++ b/flang/lib/Decimal/decimal-to-binary.cpp @@ -454,7 +454,6 @@ reinterpret_cast(&result.binary), sizeof *d); return result.flags; } -#if __x86_64__ && !defined(_MSC_VER) enum ConversionResultFlags ConvertDecimalToLongDouble( const char **p, long double *ld, enum 
FortranRounding rounding) { auto result{Fortran::decimal::ConvertToBinary<64>(*p, rounding)}; @@ -462,6 +461,5 @@ reinterpret_cast(&result.binary), sizeof *ld); return result.flags; } -#endif } } // namespace Fortran::decimal diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt --- a/flang/runtime/CMakeLists.txt +++ b/flang/runtime/CMakeLists.txt @@ -34,6 +34,7 @@ ISO_Fortran_binding.cpp allocatable.cpp buffer.cpp + complex-reduction.c character.cpp connection.cpp derived.cpp @@ -50,6 +51,7 @@ io-stmt.cpp main.cpp memory.cpp + reduction.cpp stat.cpp stop.cpp terminator.cpp diff --git a/flang/runtime/character.h b/flang/runtime/character.h --- a/flang/runtime/character.h +++ b/flang/runtime/character.h @@ -18,6 +18,16 @@ class Descriptor; +template +int CharacterScalarCompare( + const CHAR *x, const CHAR *y, std::size_t xChars, std::size_t yChars); +extern template int CharacterScalarCompare( + const char *x, const char *y, std::size_t xChars, std::size_t yChars); +extern template int CharacterScalarCompare(const char16_t *x, + const char16_t *y, std::size_t xChars, std::size_t yChars); +extern template int CharacterScalarCompare(const char32_t *x, + const char32_t *y, std::size_t xChars, std::size_t yChars); + extern "C" { // Appends the corresponding (or expanded) characters of 'operand' diff --git a/flang/runtime/character.cpp b/flang/runtime/character.cpp --- a/flang/runtime/character.cpp +++ b/flang/runtime/character.cpp @@ -7,8 +7,10 @@ //===----------------------------------------------------------------------===// #include "character.h" +#include "cpp-type.h" #include "descriptor.h" #include "terminator.h" +#include "tools.h" #include "flang/Common/bit-population-count.h" #include "flang/Common/uint128.h" #include @@ -30,7 +32,7 @@ } template -static int Compare( +int CharacterScalarCompare( const CHAR *x, const CHAR *y, std::size_t xChars, std::size_t yChars) { auto minChars{std::min(xChars, yChars)}; if constexpr (sizeof(CHAR) == 1) { 
@@ -63,6 +65,13 @@ return -CompareToBlankPadding(y, yChars - minChars); } +template int CharacterScalarCompare( + const char *x, const char *y, std::size_t xChars, std::size_t yChars); +template int CharacterScalarCompare(const char16_t *x, + const char16_t *y, std::size_t xChars, std::size_t yChars); +template int CharacterScalarCompare(const char32_t *x, + const char32_t *y, std::size_t xChars, std::size_t yChars); + // Shift count to use when converting between character lengths // and byte counts. template @@ -103,8 +112,8 @@ std::size_t yChars{y.ElementBytes() >> shift}; for (SubscriptValue resultAt{0}; elements-- > 0; ++resultAt, x.IncrementSubscripts(xAt), y.IncrementSubscripts(yAt)) { - *result.OffsetElement(resultAt) = - Compare(x.Element(xAt), y.Element(yAt), xChars, yChars); + *result.OffsetElement(resultAt) = CharacterScalarCompare( + x.Element(xAt), y.Element(yAt), xChars, yChars); } } @@ -216,38 +225,30 @@ const Terminator &terminator) { switch (kind) { case 1: - LenTrim(result, string, terminator); + LenTrim, CHAR>( + result, string, terminator); break; case 2: - LenTrim(result, string, terminator); + LenTrim, CHAR>( + result, string, terminator); break; case 4: - LenTrim(result, string, terminator); + LenTrim, CHAR>( + result, string, terminator); break; case 8: - LenTrim(result, string, terminator); + LenTrim, CHAR>( + result, string, terminator); break; case 16: - LenTrim(result, string, terminator); + LenTrim, CHAR>( + result, string, terminator); break; default: terminator.Crash("LEN_TRIM: bad KIND=%d", kind); } } -// Utility for dealing with elemental LOGICAL arguments -static bool IsLogicalElementTrue( - const Descriptor &logical, const SubscriptValue at[]) { - // A LOGICAL value is false if and only if all of its bytes are zero. 
- const char *p{logical.Element(at)}; - for (std::size_t j{logical.ElementBytes()}; j-- > 0; ++p) { - if (*p) { - return true; - } - } - return false; -} - // INDEX implementation template inline std::size_t Index(const CHAR *x, std::size_t xLen, const CHAR *want, @@ -419,23 +420,23 @@ const Terminator &terminator) { switch (kind) { case 1: - GeneralCharFunc( + GeneralCharFunc, CHAR, FUNC>( result, string, arg, back, terminator); break; case 2: - GeneralCharFunc( + GeneralCharFunc, CHAR, FUNC>( result, string, arg, back, terminator); break; case 4: - GeneralCharFunc( + GeneralCharFunc, CHAR, FUNC>( result, string, arg, back, terminator); break; case 8: - GeneralCharFunc( + GeneralCharFunc, CHAR, FUNC>( result, string, arg, back, terminator); break; case 16: - GeneralCharFunc( + GeneralCharFunc, CHAR, FUNC>( result, string, arg, back, terminator); break; default: @@ -509,7 +510,7 @@ for (CHAR *result{accumulator.OffsetElement()}; elements-- > 0; accumData += accumChars, result += chars, x.IncrementSubscripts(xAt)) { const CHAR *xData{x.Element(xAt)}; - int cmp{Compare(accumData, xData, accumChars, xChars)}; + int cmp{CharacterScalarCompare(accumData, xData, accumChars, xChars)}; if constexpr (ISMIN) { cmp = -cmp; } @@ -754,14 +755,16 @@ RUNTIME_CHECK(terminator, x.raw().type == y.raw().type); switch (x.raw().type) { case CFI_type_char: - return Compare(x.OffsetElement(), y.OffsetElement(), - x.ElementBytes(), y.ElementBytes()); + return CharacterScalarCompare(x.OffsetElement(), + y.OffsetElement(), x.ElementBytes(), y.ElementBytes()); case CFI_type_char16_t: - return Compare(x.OffsetElement(), y.OffsetElement(), - x.ElementBytes() >> 1, y.ElementBytes() >> 1); + return CharacterScalarCompare(x.OffsetElement(), + y.OffsetElement(), x.ElementBytes() >> 1, + y.ElementBytes() >> 1); case CFI_type_char32_t: - return Compare(x.OffsetElement(), y.OffsetElement(), - x.ElementBytes() >> 2, y.ElementBytes() >> 2); + return CharacterScalarCompare(x.OffsetElement(), + 
y.OffsetElement(), x.ElementBytes() >> 2, + y.ElementBytes() >> 2); default: terminator.Crash("CharacterCompareScalar: bad string type code %d", static_cast(x.raw().type)); @@ -771,17 +774,17 @@ int RTNAME(CharacterCompareScalar1)( const char *x, const char *y, std::size_t xChars, std::size_t yChars) { - return Compare(x, y, xChars, yChars); + return CharacterScalarCompare(x, y, xChars, yChars); } int RTNAME(CharacterCompareScalar2)(const char16_t *x, const char16_t *y, std::size_t xChars, std::size_t yChars) { - return Compare(x, y, xChars, yChars); + return CharacterScalarCompare(x, y, xChars, yChars); } int RTNAME(CharacterCompareScalar4)(const char32_t *x, const char32_t *y, std::size_t xChars, std::size_t yChars) { - return Compare(x, y, xChars, yChars); + return CharacterScalarCompare(x, y, xChars, yChars); } void RTNAME(CharacterCompare)( diff --git a/flang/runtime/complex-reduction.h b/flang/runtime/complex-reduction.h new file mode 100644 --- /dev/null +++ b/flang/runtime/complex-reduction.h @@ -0,0 +1,51 @@ +/*===-- flang/runtime/complex-reduction.h ---------------------------*- C -*-=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * ===-----------------------------------------------------------------------=== + */ + +/* Wraps the C++-coded complex-valued SUM and PRODUCT reductions with + * C-coded wrapper functions returning _Complex values, to avoid problems + * with C++ build compilers that don't support C's _Complex. 
+ */ + +#ifndef FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_ +#define FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_ + +#include "entry-names.h" + +struct CppDescriptor; /* dummy type name for Fortran::runtime::Descriptor */ + +#ifdef _MSC_VER /* MSVC's own complex type names */ +typedef _Fcomplex float_Complex_t; +typedef _Dcomplex double_Complex_t; +typedef _Lcomplex long_double_Complex_t; +#else /* C99 _Complex */ +typedef float _Complex float_Complex_t; +typedef double _Complex double_Complex_t; +typedef long double _Complex long_double_Complex_t; +#endif + +#define REDUCTION_ARGS \ + const struct CppDescriptor *x, const char *source, int line, int dim /*=0*/, \ + const struct CppDescriptor *mask /*=NULL*/ +#define REDUCTION_ARG_NAMES x, source, line, dim, mask + +float_Complex_t RTNAME(SumComplex2)(REDUCTION_ARGS); +float_Complex_t RTNAME(SumComplex3)(REDUCTION_ARGS); +float_Complex_t RTNAME(SumComplex4)(REDUCTION_ARGS); +double_Complex_t RTNAME(SumComplex8)(REDUCTION_ARGS); +long_double_Complex_t RTNAME(SumComplex10)(REDUCTION_ARGS); +long_double_Complex_t RTNAME(SumComplex16)(REDUCTION_ARGS); + +float_Complex_t RTNAME(ProductComplex2)(REDUCTION_ARGS); +float_Complex_t RTNAME(ProductComplex3)(REDUCTION_ARGS); +float_Complex_t RTNAME(ProductComplex4)(REDUCTION_ARGS); +double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS); +long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS); +long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS); + +#endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_ diff --git a/flang/runtime/complex-reduction.c b/flang/runtime/complex-reduction.c new file mode 100644 --- /dev/null +++ b/flang/runtime/complex-reduction.c @@ -0,0 +1,68 @@ +/*===-- flang/runtime/complex-reduction.c ---------------------------*- C -*-=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. 
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * ===-----------------------------------------------------------------------=== + */ + +#include "complex-reduction.h" +#include "flang/Common/long-double.h" +#include + +/* These are the C standard's names for _Complex constructors; not all C + * compilers support them. + */ +#if !defined(CMPLXF) && defined(__clang__) +#define CMPLXF __builtin_complex +#define CMPLX __builtin_complex +#define CMPLXL __builtin_complex +#endif + +struct CppComplexFloat { + float r, i; +}; +struct CppComplexDouble { + double r, i; +}; +struct CppComplexLongDouble { + long double r, i; +}; + +/* RTNAME(SumComplex4) calls the void RTNAME(CppSumComplex4), which stores its + * C++ complex result through its first argument, and returns it as _Complex. + */ + +#define CPP_NAME(name) Cpp##name +#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro) \ + void RTNAME(CPP_NAME(name))(struct cpptype *, REDUCTION_ARGS); \ + cComplex RTNAME(name)(REDUCTION_ARGS) { \ + struct cpptype result; \ + RTNAME(CPP_NAME(name))(&result, REDUCTION_ARG_NAMES); \ + return cmplxMacro(result.r, result.i); \ + } + +/* TODO: COMPLEX(2 & 3) */ + +/* SUM() */ +ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF) +ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX) +#if LONG_DOUBLE == 80 +ADAPT_REDUCTION( + SumComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL) +#elif LONG_DOUBLE == 128 +ADAPT_REDUCTION( + SumComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL) +#endif + +/* PRODUCT() */ +ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF) +ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX) +#if LONG_DOUBLE == 80 +ADAPT_REDUCTION( + ProductComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL) +#elif LONG_DOUBLE == 128 +ADAPT_REDUCTION( + ProductComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL) +#endif diff --git 
a/flang/runtime/cpp-type.h b/flang/runtime/cpp-type.h new file mode 100644 --- /dev/null +++ b/flang/runtime/cpp-type.h @@ -0,0 +1,67 @@ +//===-- runtime/cpp-type.h --------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// Maps Fortran intrinsic types to C++ types used in the runtime. + +#ifndef FORTRAN_RUNTIME_CPP_TYPE_H_ +#define FORTRAN_RUNTIME_CPP_TYPE_H_ + +#include "flang/Common/Fortran.h" +#include "flang/Common/uint128.h" +#include +#include + +namespace Fortran::runtime { + +using common::TypeCategory; + +template struct CppTypeForHelper {}; +template +using CppTypeFor = typename CppTypeForHelper::type; + +template struct CppTypeForHelper { + using type = common::HostSignedIntType<8 * KIND>; +}; + +// TODO: REAL/COMPLEX(2 & 3) +template <> struct CppTypeForHelper { + using type = float; +}; +template <> struct CppTypeForHelper { + using type = double; +}; +template <> struct CppTypeForHelper { + using type = long double; +}; +template <> struct CppTypeForHelper { + using type = long double; +}; + +template struct CppTypeForHelper { + using type = std::complex>; +}; + +template <> struct CppTypeForHelper { + using type = char; +}; +template <> struct CppTypeForHelper { + using type = char16_t; +}; +template <> struct CppTypeForHelper { + using type = char32_t; +}; + +template struct CppTypeForHelper { + using type = common::HostSignedIntType<8 * KIND>; +}; +template <> struct CppTypeForHelper { + using type = bool; +}; + +} // namespace Fortran::runtime +#endif // FORTRAN_RUNTIME_CPP_TYPE_H_ diff --git a/flang/runtime/descriptor-io.h b/flang/runtime/descriptor-io.h --- a/flang/runtime/descriptor-io.h +++ b/flang/runtime/descriptor-io.h @@ -11,6 +11,7 @@ // 
Implementation of I/O data list item transfers based on descriptors. +#include "cpp-type.h" #include "descriptor.h" #include "edit-input.h" #include "edit-output.h" @@ -260,15 +261,20 @@ case TypeCategory::Integer: switch (kind) { case 1: - return FormattedIntegerIO(io, descriptor); + return FormattedIntegerIO, DIR>( + io, descriptor); case 2: - return FormattedIntegerIO(io, descriptor); + return FormattedIntegerIO, DIR>( + io, descriptor); case 4: - return FormattedIntegerIO(io, descriptor); + return FormattedIntegerIO, DIR>( + io, descriptor); case 8: - return FormattedIntegerIO(io, descriptor); + return FormattedIntegerIO, DIR>( + io, descriptor); case 16: - return FormattedIntegerIO(io, descriptor); + return FormattedIntegerIO, DIR>( + io, descriptor); default: io.GetIoErrorHandler().Crash( "DescriptorIO: Unimplemented INTEGER kind (%d) in descriptor", @@ -330,13 +336,17 @@ case TypeCategory::Logical: switch (kind) { case 1: - return FormattedLogicalIO(io, descriptor); + return FormattedLogicalIO, DIR>( + io, descriptor); case 2: - return FormattedLogicalIO(io, descriptor); + return FormattedLogicalIO, DIR>( + io, descriptor); case 4: - return FormattedLogicalIO(io, descriptor); + return FormattedLogicalIO, DIR>( + io, descriptor); case 8: - return FormattedLogicalIO(io, descriptor); + return FormattedLogicalIO, DIR>( + io, descriptor); default: io.GetIoErrorHandler().Crash( "DescriptorIO: Unimplemented LOGICAL kind (%d) in descriptor", diff --git a/flang/runtime/descriptor.h b/flang/runtime/descriptor.h --- a/flang/runtime/descriptor.h +++ b/flang/runtime/descriptor.h @@ -53,6 +53,19 @@ raw_.extent = upper >= lower ? upper - lower + 1 : 0; return *this; } + Dimension &SetLowerBound(SubscriptValue lower) { + raw_.lower_bound = lower; + return *this; + } + Dimension &SetUpperBound(SubscriptValue upper) { + auto lower{raw_.lower_bound}; + raw_.extent = upper >= lower ? 
upper - lower + 1 : 0; + return *this; + } + Dimension &SetExtent(SubscriptValue extent) { + raw_.extent = extent; + return *this; + } Dimension &SetByteStride(SubscriptValue bytes) { raw_.sm = bytes; return *this; @@ -137,8 +150,8 @@ raw_.f18Addendum = false; } Descriptor(const Descriptor &); - ~Descriptor(); + Descriptor &operator=(const Descriptor &); static constexpr std::size_t BytesFor(TypeCategory category, int kind) { return category == TypeCategory::Complex ? kind * 2 : kind; diff --git a/flang/runtime/descriptor.cpp b/flang/runtime/descriptor.cpp --- a/flang/runtime/descriptor.cpp +++ b/flang/runtime/descriptor.cpp @@ -17,9 +17,7 @@ namespace Fortran::runtime { -Descriptor::Descriptor(const Descriptor &that) { - std::memcpy(this, &that, that.SizeInBytes()); -} +Descriptor::Descriptor(const Descriptor &that) { *this = that; } Descriptor::~Descriptor() { if (raw_.attribute != CFI_attribute_pointer) { @@ -27,6 +25,11 @@ } } +Descriptor &Descriptor::operator=(const Descriptor &that) { + std::memcpy(this, &that, that.SizeInBytes()); + return *this; +} + void Descriptor::Establish(TypeCode t, std::size_t elementBytes, void *p, int rank, const SubscriptValue *extent, ISO::CFI_attribute_t attribute, bool addendum) { @@ -224,10 +227,9 @@ for (int j{raw_.rank - 1}; j >= 0; --j) { int k{permutation ? permutation[j] : j}; const Dimension &dim{GetDimension(k)}; - std::size_t quotient{j ? 
elementNumber / dimCoefficient[j] : 0}; - subscript[k] = - dim.LowerBound() + elementNumber - dimCoefficient[j] * quotient; - elementNumber = quotient; + std::size_t quotient{elementNumber / dimCoefficient[j]}; + subscript[k] = quotient + dim.LowerBound(); + elementNumber -= quotient * dimCoefficient[j]; } return true; } diff --git a/flang/runtime/entry-names.h b/flang/runtime/entry-names.h --- a/flang/runtime/entry-names.h +++ b/flang/runtime/entry-names.h @@ -1,20 +1,21 @@ -//===-- runtime/entry-names.h -----------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -// Defines the macro RTNAME(n) which decorates the external name of a runtime -// library function or object with extra characters so that it -// (a) is not in the user's name space, -// (b) doesn't conflict with other libraries, and -// (c) prevents incompatible versions of the runtime library from linking -// -// The value of REVISION should not be changed until/unless the API to the -// runtime library must change in some way that breaks backward compatibility. +/*===-- runtime/entry-names.h ---------------------------------------*- C -*-=== + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + * See https://llvm.org/LICENSE.txt for license information. 
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + *===------------------------------------------------------------------------=== + */ +/* Defines the macro RTNAME(n) which decorates the external name of a runtime + * library function or object with extra characters so that it + * (a) is not in the user's name space, + * (b) doesn't conflict with other libraries, and + * (c) prevents incompatible versions of the runtime library from linking + * + * The value of REVISION should not be changed until/unless the API to the + * runtime library must change in some way that breaks backward compatibility. + */ #ifndef RTNAME #define NAME_WITH_PREFIX_AND_REVISION(prefix, revision, name) \ prefix##revision##name diff --git a/flang/runtime/io-api.cpp b/flang/runtime/io-api.cpp --- a/flang/runtime/io-api.cpp +++ b/flang/runtime/io-api.cpp @@ -856,26 +856,6 @@ return false; } -template -static bool SetInteger(INT &x, int kind, std::int64_t value) { - switch (kind) { - case 1: - reinterpret_cast(x) = value; - return true; - case 2: - reinterpret_cast(x) = value; - return true; - case 4: - reinterpret_cast(x) = value; - return true; - case 8: - reinterpret_cast(x) = value; - return true; - default: - return false; - } -} - bool IONAME(GetNewUnit)(Cookie cookie, int &unit, int kind) { IoStatementState &io{*cookie}; auto *open{io.get_if()}; diff --git a/flang/runtime/reduction.h b/flang/runtime/reduction.h new file mode 100644 --- /dev/null +++ b/flang/runtime/reduction.h @@ -0,0 +1,230 @@ +//===-- runtime/reduction.h -------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// Defines the API for the reduction transformational intrinsic functions. 
+// (Except the complex-valued total reduction forms of SUM and PRODUCT; +// the API for those is in complex-reduction.h so that C's _Complex can +// be used for their return types.) + +#ifndef FORTRAN_RUNTIME_REDUCTION_H_ +#define FORTRAN_RUNTIME_REDUCTION_H_ + +#include "descriptor.h" +#include "entry-names.h" +#include "flang/Common/uint128.h" +#include +#include + +namespace Fortran::runtime { +extern "C" { + +// Reductions that are known to return scalars have per-type entry +// points. These cover the cases that either have no DIM= +// argument, or have an argument rank of 1. Pass 0 for no DIM= +// or the value of the DIM= argument so that it may be checked. +// The data type in the descriptor is checked against the expected +// return type. +// +// Reductions that return arrays are the remaining cases in which +// the argument rank is greater than one and there is a DIM= +// argument present. These cases establish and allocate their +// results in a caller-supplied descriptor, which is assumed to +// be large enough. +// +// Complex-valued SUM and PRODUCT reductions have their API +// entry points defined in complex-reduction.h; these are C wrappers +// around C++ implementations so as to keep usage of C's _Complex +// types out of C++ code. 
+ +// SUM() + +/* Scalar-result entry points, one per INTEGER and REAL kind; in each, dim defaults to 0 and mask defaults to nullptr. */ +std::int8_t RTNAME(SumInteger1)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int16_t RTNAME(SumInteger2)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int32_t RTNAME(SumInteger4)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int64_t RTNAME(SumInteger8)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +common::int128_t RTNAME(SumInteger16)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); + +// REAL/COMPLEX(2 & 3) return 32-bit float results for the caller to downconvert +float RTNAME(SumReal2)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(SumReal3)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(SumReal4)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +double RTNAME(SumReal8)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +long double RTNAME(SumReal10)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +long double RTNAME(SumReal16)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); + +/* Complex SUM entry points return void; presumably the result is stored through the leading std::complex reference (confirm against the implementation). */ +void RTNAME(CppSumComplex2)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppSumComplex3)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppSumComplex4)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void 
RTNAME(CppSumComplex8)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppSumComplex10)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppSumComplex16)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); + +/* Array-valued SUM with DIM=: takes a result descriptor as its first argument; dim has no default here. */ +void RTNAME(SumDim)(Descriptor &result, const Descriptor &array, int dim, + const char *source, int line, const Descriptor *mask = nullptr); + +// PRODUCT() + +/* Scalar-result entry points, one per INTEGER and REAL kind; in each, dim defaults to 0 and mask defaults to nullptr. */ +std::int8_t RTNAME(ProductInteger1)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int16_t RTNAME(ProductInteger2)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int32_t RTNAME(ProductInteger4)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int64_t RTNAME(ProductInteger8)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +common::int128_t RTNAME(ProductInteger16)(const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); + +// REAL/COMPLEX(2 & 3) return 32-bit float results for the caller to downconvert +float RTNAME(ProductReal2)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(ProductReal3)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(ProductReal4)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +double RTNAME(ProductReal8)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +long double RTNAME(ProductReal10)(const Descriptor &, const char *source, + int line, int dim = 0, const 
Descriptor *mask = nullptr); +long double RTNAME(ProductReal16)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); + +/* Complex PRODUCT entry points return void; presumably the result is stored through the leading std::complex reference (confirm against the implementation). */ +void RTNAME(CppProductComplex2)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppProductComplex3)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppProductComplex4)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppProductComplex8)(std::complex &, const Descriptor &, + const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppProductComplex10)(std::complex &, + const Descriptor &, const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); +void RTNAME(CppProductComplex16)(std::complex &, + const Descriptor &, const char *source, int line, int dim = 0, + const Descriptor *mask = nullptr); + +/* Array-valued PRODUCT with DIM=: takes a result descriptor as its first argument; dim has no default here. */ +void RTNAME(ProductDim)(Descriptor &result, const Descriptor &array, int dim, + const char *source, int line, const Descriptor *mask = nullptr); + +// MAXLOC and MINLOC +// These return allocated arrays in the supplied descriptor. +// The default value for KIND= should be the default INTEGER in effect at +// compilation time. 
+void RTNAME(Maxloc)(Descriptor &, const Descriptor &, int kind, + const char *source, int line, const Descriptor *mask = nullptr, + bool back = false); +void RTNAME(MaxlocDim)(Descriptor &, const Descriptor &, int kind, int dim, + const char *source, int line, const Descriptor *mask = nullptr, + bool back = false); +void RTNAME(Minloc)(Descriptor &, const Descriptor &, int kind, + const char *source, int line, const Descriptor *mask = nullptr, + bool back = false); +void RTNAME(MinlocDim)(Descriptor &, const Descriptor &, int kind, int dim, + const char *source, int line, const Descriptor *mask = nullptr, + bool back = false); + +// MAXVAL and MINVAL +std::int8_t RTNAME(MaxvalInteger1)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int16_t RTNAME(MaxvalInteger2)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int32_t RTNAME(MaxvalInteger4)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int64_t RTNAME(MaxvalInteger8)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +common::int128_t RTNAME(MaxvalInteger16)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(MaxvalReal2)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(MaxvalReal3)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(MaxvalReal4)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +double RTNAME(MaxvalReal8)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +long double RTNAME(MaxvalReal10)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +long 
double RTNAME(MaxvalReal16)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +/* CHARACTER MAXVAL: the result has arbitrary length, so it is allocated in the caller-supplied (non-const) descriptor instead of being returned by value. */ +void RTNAME(MaxvalCharacter)(Descriptor &, const Descriptor &, + const char *source, int line, const Descriptor *mask = nullptr); + +std::int8_t RTNAME(MinvalInteger1)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int16_t RTNAME(MinvalInteger2)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int32_t RTNAME(MinvalInteger4)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +std::int64_t RTNAME(MinvalInteger8)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +common::int128_t RTNAME(MinvalInteger16)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(MinvalReal2)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(MinvalReal3)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +float RTNAME(MinvalReal4)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +double RTNAME(MinvalReal8)(const Descriptor &, const char *source, int line, + int dim = 0, const Descriptor *mask = nullptr); +long double RTNAME(MinvalReal10)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +long double RTNAME(MinvalReal16)(const Descriptor &, const char *source, + int line, int dim = 0, const Descriptor *mask = nullptr); +/* CHARACTER MINVAL: same calling convention as MaxvalCharacter. */ +void RTNAME(MinvalCharacter)(Descriptor &, const Descriptor &, + const char *source, int line, const Descriptor *mask = nullptr); + +/* Array-valued MAXVAL/MINVAL with DIM=: the leading descriptor receives the result; dim has no default here. */ +void RTNAME(MaxvalDim)(Descriptor &, const Descriptor &, int dim, + const char *source, int line, const Descriptor *mask = nullptr); +void 
RTNAME(MinvalDim)(Descriptor &, const Descriptor &, int dim, + const char *source, int line, const Descriptor *mask = nullptr); + +// ALL, ANY, & COUNT logical reductions +bool RTNAME(All)(const Descriptor &, const char *source, int line, int dim = 0); +void RTNAME(AllDim)(Descriptor &result, const Descriptor &, int dim, + const char *source, int line); +bool RTNAME(Any)(const Descriptor &, const char *source, int line, int dim = 0); +void RTNAME(AnyDim)(Descriptor &result, const Descriptor &, int dim, + const char *source, int line); +std::int64_t RTNAME(Count)( + const Descriptor &, const char *source, int line, int dim = 0); +void RTNAME(CountDim)(Descriptor &result, const Descriptor &, int dim, int kind, + const char *source, int line); + +} // extern "C" +} // namespace Fortran::runtime +#endif // FORTRAN_RUNTIME_REDUCTION_H_ diff --git a/flang/runtime/reduction.cpp b/flang/runtime/reduction.cpp new file mode 100644 --- /dev/null +++ b/flang/runtime/reduction.cpp @@ -0,0 +1,1525 @@ +//===-- runtime/reduction.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// Implements ALL, ANY, COUNT, MAXLOC, MAXVAL, MINLOC, MINVAL, PRODUCT, and SUM +// for all required operand types and shapes and (for MAXLOC & MINLOC) kinds of +// results. +// +// * Real and complex SUM reductions attempt to reduce floating-point +// cancellation on intermediate results by adding up partial sums +// for positive and negative elements independently. +// * Partial reductions (i.e., those with DIM= arguments that are not +// required to be 1 by the rank of the argument) return arrays that +// are dynamically allocated in a caller-supplied descriptor. 
+// * Total reductions (i.e., no DIM= argument) with MAXLOC & MINLOC +// return integer vectors of some kind, not scalars; a caller-supplied +// descriptor is used +// * Character-valued reductions (MAXVAL & MINVAL) return arbitrary +// length results, dynamically allocated in a caller-supplied descriptor + +#include "reduction.h" +#include "character.h" +#include "cpp-type.h" +#include "terminator.h" +#include "tools.h" +#include "flang/Common/long-double.h" +#include +#include +#include +#include + +namespace Fortran::runtime { + +// Generic reduction templates + +// Reductions are implemented with *accumulators*, which are instances of +// classes that incrementally build up the result (or an element thereof) during +// a traversal of the unmasked elements of an array. Each accumulator class +// supports a constructor (which captures a reference to the array), an +// AccumulateAt() member function that applies supplied subscripts to the +// array and does something with a scalar element, and a GetResult() +// member function that copies a final result into its destination. + +// Total reduction of the array argument to a scalar (or to a vector in the +// cases of MAXLOC & MINLOC). These are the cases without DIM= or cases +// where the argument has rank 1 and DIM=, if present, must be 1. 
+/* Drives one total reduction: every element of x that passes the optional mask is handed to accumulator.AccumulateAt(). dim is only validated here; any value other than 0 or 1 crashes through the terminator. */ +template +inline void DoTotalReduction(const Descriptor &x, int dim, + const Descriptor *mask, ACCUMULATOR &accumulator, const char *intrinsic, + Terminator &terminator) { + if (dim < 0 || dim > 1) { + terminator.Crash( + "%s: bad DIM=%d for argument with rank %d", intrinsic, dim, x.rank()); + } + SubscriptValue xAt[maxRank]; + x.GetLowerBounds(xAt); + if (mask) { + CheckConformability(x, *mask, terminator, intrinsic, "ARRAY", "MASK"); + SubscriptValue maskAt[maxRank]; + mask->GetLowerBounds(maskAt); + if (mask->rank() > 0) { + /* array MASK=: x and mask advance together; the value returned by AccumulateAt() is ignored, so this loop never stops early */ + for (auto elements{x.Elements()}; elements--; + x.IncrementSubscripts(xAt), mask->IncrementSubscripts(maskAt)) { + if (IsLogicalElementTrue(*mask, maskAt)) { + accumulator.template AccumulateAt(xAt); + } + } + return; + } else if (!IsLogicalElementTrue(*mask, maskAt)) { + // scalar MASK=.FALSE.: return identity value + return; + } + } + // No MASK=, or scalar MASK=.TRUE. + for (auto elements{x.Elements()}; elements--; x.IncrementSubscripts(xAt)) { + if (!accumulator.template AccumulateAt(xAt)) { + break; // cut short, result is known + } + } +} + +/* Scalar-result wrapper around DoTotalReduction: RUNTIME_CHECKs that the type of x is TypeCode(CAT, KIND), runs the reduction, and returns the value that the accumulator's GetResult() stores. */ +template +inline CppTypeFor GetTotalReduction(const Descriptor &x, + const char *source, int line, int dim, const Descriptor *mask, + ACCUMULATOR &&accumulator, const char *intrinsic) { + Terminator terminator{source, line}; + RUNTIME_CHECK(terminator, TypeCode(CAT, KIND) == x.type()); + using CppType = CppTypeFor; + DoTotalReduction(x, dim, mask, accumulator, intrinsic, terminator); + CppType result; + accumulator.template GetResult(&result); + return result; +} + +// For reductions on a dimension, e.g. SUM(array,DIM=2) where the shape +// of the array is [2,3,5], the shape of the result is [2,5] and +// result(j,k) = SUM(array(j,:,k)), possibly modified if the array has +// lower bounds other than one. This utility subroutine creates an +// array of subscripts [j,_,k] for result subscripts [j,k] so that the +// elements of array(j,:,k) can be reduced. 
+inline void GetExpandedSubscripts(SubscriptValue at[], + const Descriptor &descriptor, int zeroBasedDim, + const SubscriptValue from[]) { + descriptor.GetLowerBounds(at); + int rank{descriptor.rank()}; + int j{0}; + /* dimensions before the reduced one: offset the descriptor's lower bound by the incoming subscript, less that subscript's own lower bound of 1 */ + for (; j < zeroBasedDim; ++j) { + at[j] += from[j] - 1 /*lower bound*/; + } + /* ++j steps over at[zeroBasedDim], which keeps the lower bound written above; later dimensions read from[j - 1] because "from" has one dimension fewer */ + for (++j; j < rank; ++j) { + at[j] += from[j - 1] - 1; + } +} + +/* Reduces the elements of x along zeroBasedDim, at the position selected by the result subscripts, using a freshly constructed ACCUMULATOR; the walk stops early once AccumulateAt() returns false, and GetResult() then writes through result. */ +template +inline void ReduceDimToScalar(const Descriptor &x, int zeroBasedDim, + SubscriptValue subscripts[], TYPE *result) { + ACCUMULATOR accumulator{x}; + SubscriptValue xAt[maxRank]; + GetExpandedSubscripts(xAt, x, zeroBasedDim, subscripts); + const auto &dim{x.GetDimension(zeroBasedDim)}; + SubscriptValue at{dim.LowerBound()}; + for (auto n{dim.Extent()}; n-- > 0; ++at) { + xAt[zeroBasedDim] = at; + if (!accumulator.template AccumulateAt(xAt)) { + break; + } + } + accumulator.template GetResult(result, zeroBasedDim); +} + +/* Masked variant of ReduceDimToScalar: x and mask are walked in step along zeroBasedDim, each starting from its own lower bound, and only elements whose mask entry is true are accumulated. The trip count is the extent of x. */ +template +inline void ReduceDimMaskToScalar(const Descriptor &x, int zeroBasedDim, + SubscriptValue subscripts[], const Descriptor &mask, TYPE *result) { + ACCUMULATOR accumulator{x}; + SubscriptValue xAt[maxRank], maskAt[maxRank]; + GetExpandedSubscripts(xAt, x, zeroBasedDim, subscripts); + GetExpandedSubscripts(maskAt, mask, zeroBasedDim, subscripts); + const auto &xDim{x.GetDimension(zeroBasedDim)}; + SubscriptValue xPos{xDim.LowerBound()}; + const auto &maskDim{mask.GetDimension(zeroBasedDim)}; + SubscriptValue maskPos{maskDim.LowerBound()}; + for (auto n{x.GetDimension(zeroBasedDim).Extent()}; n-- > 0; + ++xPos, ++maskPos) { + maskAt[zeroBasedDim] = maskPos; + if (IsLogicalElementTrue(mask, maskAt)) { + xAt[zeroBasedDim] = xPos; + if (!accumulator.template AccumulateAt(xAt)) { + break; + } + } + } + accumulator.template GetResult(result, zeroBasedDim); +} + +// Utility: establishes & allocates the result array for a partial +// reduction (i.e., one with DIM=). 
+static void CreatePartialReductionResult(Descriptor &result, + const Descriptor &x, int dim, Terminator &terminator, const char *intrinsic, + TypeCode typeCode) { + int xRank{x.rank()}; + if (dim < 1 || dim > xRank) { + terminator.Crash("%s: bad DIM=%d for rank %d", intrinsic, dim, xRank); + } + int zeroBasedDim{dim - 1}; + SubscriptValue resultExtent[maxRank]; + /* the result extents are the extents of x with dimension DIM removed */ + for (int j{0}; j < zeroBasedDim; ++j) { + resultExtent[j] = x.GetDimension(j).Extent(); + } + for (int j{zeroBasedDim + 1}; j < xRank; ++j) { + resultExtent[j - 1] = x.GetDimension(j).Extent(); + } + /* the result is established with rank xRank - 1 and the element byte size of x */ + result.Establish(typeCode, x.ElementBytes(), nullptr, xRank - 1, resultExtent, + CFI_attribute_allocatable); + for (int j{0}; j + 1 < xRank; ++j) { + result.GetDimension(j).SetBounds(1, resultExtent[j]); + } + /* a nonzero status from Allocate() is fatal */ + if (int stat{result.Allocate()}) { + terminator.Crash( + "%s: could not allocate memory for result; STAT=%d", intrinsic, stat); + } +} + +// Partial reductions with DIM= + +/* Establishes and allocates result via CreatePartialReductionResult, then fills each result element by reducing x along DIM=. An array MASK= goes through ReduceDimMaskToScalar; a scalar false MASK= stores the GetResult() of an accumulator that was never given an element; otherwise ReduceDimToScalar is used. */ +template +inline void PartialReduction(Descriptor &result, const Descriptor &x, int dim, + const Descriptor *mask, Terminator &terminator, const char *intrinsic) { + CreatePartialReductionResult( + result, x, dim, terminator, intrinsic, TypeCode{CAT, KIND}); + SubscriptValue at[maxRank]; + result.GetLowerBounds(at); + INTERNAL_CHECK(at[0] == 1); + using CppType = CppTypeFor; + if (mask) { + CheckConformability(x, *mask, terminator, intrinsic, "ARRAY", "MASK"); + SubscriptValue maskAt[maxRank]; // contents unused + if (mask->rank() > 0) { + for (auto n{result.Elements()}; n-- > 0; result.IncrementSubscripts(at)) { + ReduceDimMaskToScalar( + x, dim - 1, at, *mask, result.Element(at)); + } + return; + } else if (!IsLogicalElementTrue(*mask, maskAt)) { + // scalar MASK=.FALSE. + ACCUMULATOR accumulator{x}; + for (auto n{result.Elements()}; n-- > 0; result.IncrementSubscripts(at)) { + accumulator.GetResult(result.Element(at)); + } + return; + } + } + // No MASK= or scalar MASK=.TRUE. 
+ for (auto n{result.Elements()}; n-- > 0; result.IncrementSubscripts(at)) { + ReduceDimToScalar( + x, dim - 1, at, result.Element(at)); + } +} + +template