diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt
--- a/flang/runtime/CMakeLists.txt
+++ b/flang/runtime/CMakeLists.txt
@@ -39,6 +39,7 @@
   connection.cpp
   derived.cpp
   descriptor.cpp
+  dot-product.cpp
   edit-input.cpp
   edit-output.cpp
   environment.cpp
diff --git a/flang/runtime/complex-reduction.h b/flang/runtime/complex-reduction.h
--- a/flang/runtime/complex-reduction.h
+++ b/flang/runtime/complex-reduction.h
@@ -49,4 +49,17 @@
 long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
 long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
 
+/* DOT_PRODUCT(VECTOR_A, VECTOR_B) takes no DIM= or MASK= argument. */
+#define DOT_PRODUCT_ARGS \
+  const struct CppDescriptor *x, const struct CppDescriptor *y, \
+      const char *source, int line
+#define DOT_PRODUCT_ARG_NAMES x, y, source, line
+
+float_Complex_t RTNAME(DotProductComplex2)(DOT_PRODUCT_ARGS);
+float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
+float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
+double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
+long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
+long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
+
 #endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
diff --git a/flang/runtime/complex-reduction.c b/flang/runtime/complex-reduction.c
--- a/flang/runtime/complex-reduction.c
+++ b/flang/runtime/complex-reduction.c
@@ -75,34 +75,54 @@
  */
 
 #define CPP_NAME(name) Cpp##name
-#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro) \
-  struct cpptype RTNAME(CPP_NAME(name))(struct cpptype *, REDUCTION_ARGS); \
-  cComplex RTNAME(name)(REDUCTION_ARGS) { \
+/* The Cpp* helpers return void and deliver their result through the first
+ * argument, so declare them that way: a struct return type would not match
+ * the C++ definitions and can imply a hidden result pointer on some ABIs. */
+#define ADAPT_REDUCTION(name, cComplex, cpptype, cmplxMacro, ARGS, ARG_NAMES) \
+  void RTNAME(CPP_NAME(name))(struct cpptype *, ARGS); \
+  cComplex RTNAME(name)(ARGS) { \
     struct cpptype result; \
-    RTNAME(CPP_NAME(name))(&result, REDUCTION_ARG_NAMES); \
+    RTNAME(CPP_NAME(name))(&result, ARG_NAMES); \
     return cmplxMacro(result.r, result.i); \
   }
 
 /* TODO: COMPLEX(2 & 3) */
 
 /* SUM() */
-ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
-ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX)
+ADAPT_REDUCTION(SumComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #if LONG_DOUBLE == 80
-ADAPT_REDUCTION(
-    SumComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #elif LONG_DOUBLE == 128
-ADAPT_REDUCTION(
-    SumComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #endif
 
 /* PRODUCT() */
-ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF)
-ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX)
+ADAPT_REDUCTION(ProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+    REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #if LONG_DOUBLE == 80
-ADAPT_REDUCTION(
-    ProductComplex10, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
 #elif LONG_DOUBLE == 128
-ADAPT_REDUCTION(
-    ProductComplex16, long_double_Complex_t, CppComplexLongDouble, CMPLXL)
+ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
+    CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
+#endif
+
+/* DOT_PRODUCT() */
+ADAPT_REDUCTION(DotProductComplex4, float_Complex_t, CppComplexFloat, CMPLXF,
+    DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
+    DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#if LONG_DOUBLE == 80
+ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
+    CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
+#elif LONG_DOUBLE == 128
+ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
+    CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
 #endif
diff --git a/flang/runtime/dot-product.cpp b/flang/runtime/dot-product.cpp
new file mode 100644
--- /dev/null
+++ b/flang/runtime/dot-product.cpp
@@ -0,0 +1,214 @@
+//===-- runtime/dot-product.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Implements the DOT_PRODUCT intrinsic function for INTEGER, REAL, COMPLEX,
+// and LOGICAL vectors.
+
+#include "cpp-type.h"
+#include "descriptor.h"
+#include "reduction.h"
+#include "terminator.h"
+#include "tools.h"
+#include
+
+namespace Fortran::runtime {
+
+// Walks two rank-1 operands of equal size from their lower bounds, passing
+// each pair of corresponding subscripts to an ACCUMULATOR built from both
+// descriptors; crashes if either operand is not rank 1 or the sizes differ.
+template
+static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y,
+    Terminator &terminator) -> typename ACCUMULATOR::Result {
+  RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
+  SubscriptValue n{x.GetDimension(0).Extent()};
+  if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
+    terminator.Crash(
+        "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
+        static_cast(n), static_cast(yN));
+  }
+  SubscriptValue xAt{x.GetDimension(0).LowerBound()};
+  SubscriptValue yAt{y.GetDimension(0).LowerBound()};
+  ACCUMULATOR accumulator{x, y};
+  for (SubscriptValue j{0}; j < n; ++j) {
+    accumulator.Accumulate(xAt++, yAt++);
+  }
+  return accumulator.GetResult();
+}
+
+// Function object that uses ApplyType to dispatch on the dynamic category and
+// kind of VECTOR_A (DP1) and then of VECTOR_B (DP2).  When their combined
+// result type has category RCAT and a kind no greater than RKIND, the dot
+// product is accumulated with ACCUM; any other combination crashes.
+template class ACCUM>
+struct DotProduct {
+  using Result = CppTypeFor;
+  template struct DP1 {
+    template struct DP2 {
+      Result operator()(const Descriptor &x, const Descriptor &y,
+          Terminator &terminator) const {
+        if constexpr (constexpr auto resultType{
+                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+          if constexpr (resultType->first == RCAT &&
+              resultType->second <= RKIND) {
+            using Accum = ACCUM,
+                CppTypeFor;
+            return DoDotProduct(x, y, terminator);
+          }
+        }
+        terminator.Crash(
+            "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
+            static_cast(RCAT), RKIND, static_cast(XCAT), XKIND,
+            static_cast(YCAT), YKIND);
+      }
+    };
+    Result operator()(const Descriptor &x, const Descriptor &y,
+        Terminator &terminator, TypeCategory yCat, int yKind) const {
+      return ApplyType(yCat, yKind, terminator, x, y, terminator);
+    }
+  };
+  Result operator()(const Descriptor &x, const Descriptor &y,
+      const char *source, int line) const {
+    Terminator terminator{source, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y.type().GetCategoryAndKind()};
+    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+    return ApplyType(xCatKind->first, xCatKind->second, terminator,
+        x, y, terminator, yCatKind->first, yCatKind->second);
+  }
+};
+
+// Accumulates the sum of x(j)*y(j) in Result; when VECTOR_A is COMPLEX, each
+// x(j) is conjugated first, as the standard requires for DOT_PRODUCT.
+template
+class NumericAccumulator {
+public:
+  using Result = RESULT;
+  NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+    if constexpr (XCAT == TypeCategory::Complex) {
+      sum_ += std::conj(static_cast(*x_.Element(&xAt))) *
+          static_cast(*y_.Element(&yAt));
+    } else {
+      sum_ += static_cast(*x_.Element(&xAt)) *
+          static_cast(*y_.Element(&yAt));
+    }
+  }
+  Result GetResult() const { return sum_; }
+
+private:
+  const Descriptor &x_, &y_;
+  Result sum_{0};
+};
+
+// Computes ANY(x(j) .AND. y(j)) over corresponding LOGICAL elements.
+template
+class LogicalAccumulator {
+public:
+  using Result = bool;
+  LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
+  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
+    result_ = result_ ||
+        (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt));
+  }
+  bool GetResult() const { return result_; }
+
+private:
+  const Descriptor &x_, &y_;
+  bool result_{false};
+};
+
+// Entry points declared in reduction.h.  COMPLEX results are returned through
+// the first (reference) argument; complex-reduction.c wraps those for C.
+extern "C" {
+std::int8_t RTNAME(DotProductInteger1)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+std::int16_t RTNAME(DotProductInteger2)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+std::int32_t RTNAME(DotProductInteger4)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+std::int64_t RTNAME(DotProductInteger8)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+#ifdef __SIZEOF_INT128__
+common::int128_t RTNAME(DotProductInteger16)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+#endif
+
+// TODO: REAL/COMPLEX(2 & 3)
+float RTNAME(DotProductReal4)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+double RTNAME(DotProductReal8)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+#if LONG_DOUBLE == 80
+long double RTNAME(DotProductReal10)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+#elif LONG_DOUBLE == 128
+long double RTNAME(DotProductReal16)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+#endif
+
+void RTNAME(CppDotProductComplex4)(std::complex &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  auto z{DotProduct{}(
+      x, y, source, line)};
+  result = std::complex{
+      static_cast(z.real()), static_cast(z.imag())};
+}
+void RTNAME(CppDotProductComplex8)(std::complex &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  result = DotProduct{}(
+      x, y, source, line);
+}
+#if LONG_DOUBLE == 80
+void RTNAME(CppDotProductComplex10)(std::complex &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  result = DotProduct{}(
+      x, y, source, line);
+}
+#elif LONG_DOUBLE == 128
+void RTNAME(CppDotProductComplex16)(std::complex &result,
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  result = DotProduct{}(
+      x, y, source, line);
+}
+#endif
+
+bool RTNAME(DotProductLogical)(
+    const Descriptor &x, const Descriptor &y, const char *source, int line) {
+  return DotProduct{}(
+      x, y, source, line);
+}
+} // extern "C"
+} // namespace Fortran::runtime
diff --git a/flang/runtime/reduction.h b/flang/runtime/reduction.h
--- a/flang/runtime/reduction.h
+++ b/flang/runtime/reduction.h
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 // Defines the API for the reduction transformational intrinsic functions.
-// (Except the complex-valued total reduction forms of SUM and PRODUCT;
-// the API for those is in complex-reduction.h so that C's _Complex can
-// be used for their return types.)
+// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction
+// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that
+// C's _Complex can be used for their return types.)
 
 #ifndef FORTRAN_RUNTIME_REDUCTION_H_
 #define FORTRAN_RUNTIME_REDUCTION_H_
@@ -275,6 +275,49 @@
 void RTNAME(ParityDim)(Descriptor &result, const Descriptor &, int dim,
     const char *source, int line);
 
+// DOT_PRODUCT: both arguments must be rank-1 arrays of the same size.
+// (The REAL/COMPLEX(2 & 3) entry points below are not defined yet.)
+std::int8_t RTNAME(DotProductInteger1)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+std::int16_t RTNAME(DotProductInteger2)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+std::int32_t RTNAME(DotProductInteger4)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+std::int64_t RTNAME(DotProductInteger8)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+#ifdef __SIZEOF_INT128__
+common::int128_t RTNAME(DotProductInteger16)(const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+#endif
+float RTNAME(DotProductReal2)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+float RTNAME(DotProductReal3)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+float RTNAME(DotProductReal4)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+double RTNAME(DotProductReal8)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+long double RTNAME(DotProductReal10)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+long double RTNAME(DotProductReal16)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex2)(std::complex &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex3)(std::complex &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex4)(std::complex &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex8)(std::complex &, const Descriptor &,
+    const Descriptor &, const char *source = nullptr, int line = 0);
+void RTNAME(CppDotProductComplex10)(std::complex &,
+    const Descriptor &, const Descriptor &, const char *source = nullptr,
+    int line = 0);
+void RTNAME(CppDotProductComplex16)(std::complex &,
+    const Descriptor &, const Descriptor &, const char *source = nullptr,
+    int line = 0);
+bool RTNAME(DotProductLogical)(const Descriptor &, const Descriptor &,
+    const char *source = nullptr, int line = 0);
+
 } // extern "C"
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_REDUCTION_H_
diff --git a/flang/runtime/reduction.cpp b/flang/runtime/reduction.cpp
--- a/flang/runtime/reduction.cpp
+++ b/flang/runtime/reduction.cpp
@@ -9,8 +9,8 @@
 // Implements ALL, ANY, COUNT, IPARITY, & PARITY for all required operand
 // types and shapes.
 //
-// FINDLOC, SUM, and PRODUCT are in their own eponymous source files;
-// NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
+// DOT_PRODUCT, FINDLOC, SUM, and PRODUCT are in their own eponymous source
+// files; NORM2, MAXLOC, MINLOC, MAXVAL, and MINVAL are in extrema.cpp.
 
 #include "reduction.h"
 #include "reduction-templates.h"
diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -102,6 +102,104 @@
   }
 }
 
+// Maps intrinsic runtime type category and kind values to the appropriate
+// instantiation of a function object template and calls it with the supplied
+// arguments.
+template