Index: flang/runtime/CMakeLists.txt =================================================================== --- flang/runtime/CMakeLists.txt +++ flang/runtime/CMakeLists.txt @@ -53,6 +53,7 @@ io-error.cpp io-stmt.cpp main.cpp + matmul.cpp memory.cpp misc-intrinsic.cpp namelist.cpp Index: flang/runtime/dot-product.cpp =================================================================== --- flang/runtime/dot-product.cpp +++ flang/runtime/dot-product.cpp @@ -15,9 +15,33 @@ namespace Fortran::runtime { -template -static inline auto DoDotProduct(const Descriptor &x, const Descriptor &y, - Terminator &terminator) -> typename ACCUMULATOR::Result { +template +class Accumulator { +public: + using Result = RESULT; + Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} + void Accumulate(SubscriptValue xAt, SubscriptValue yAt) { + if constexpr (XCAT == TypeCategory::Complex) { + sum_ += std::conj(static_cast(*x_.Element(&xAt))) * + static_cast(*y_.Element(&yAt)); + } else if constexpr (XCAT == TypeCategory::Logical) { + sum_ = sum_ || + (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); + } else { + sum_ += static_cast(*x_.Element(&xAt)) * + static_cast(*y_.Element(&yAt)); + } + } + Result GetResult() const { return sum_; } + +private: + const Descriptor &x_, &y_; + Result sum_{}; +}; + +template +static inline RESULT DoDotProduct( + const Descriptor &x, const Descriptor &y, Terminator &terminator) { RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); SubscriptValue n{x.GetDimension(0).Extent()}; if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { @@ -25,18 +49,27 @@ "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", static_cast(n), static_cast(yN)); } + if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + // TODO: call BLAS-1 SDOT or SDSDOT + } else if constexpr (std::is_same_v) { + // TODO: call BLAS-1 DDOT + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-1 CDOTC + } else if constexpr 
(std::is_same_v>) { + // TODO: call BLAS-1 ZDOTC + } + } SubscriptValue xAt{x.GetDimension(0).LowerBound()}; SubscriptValue yAt{y.GetDimension(0).LowerBound()}; - ACCUMULATOR accumulator{x, y}; + Accumulator accumulator{x, y}; for (SubscriptValue j{0}; j < n; ++j) { accumulator.Accumulate(xAt++, yAt++); } return accumulator.GetResult(); } -template class ACCUM> -struct DotProduct { +template struct DotProduct { using Result = CppTypeFor; template struct DP1 { template struct DP2 { @@ -46,9 +79,8 @@ GetResultType(XCAT, XKIND, YCAT, YKIND)}) { if constexpr (resultType->first == RCAT && resultType->second <= RKIND) { - using Accum = ACCUM, - CppTypeFor>; - return DoDotProduct(x, y, terminator); + return DoDotProduct, + CppTypeFor>(x, y, terminator); } } terminator.Crash( @@ -73,127 +105,76 @@ } }; -template -class NumericAccumulator { -public: - using Result = RESULT; - NumericAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} - void Accumulate(SubscriptValue xAt, SubscriptValue yAt) { - if constexpr (XCAT == TypeCategory::Complex) { - sum_ += std::conj(static_cast(*x_.Element(&xAt))) * - static_cast(*y_.Element(&yAt)); - } else { - sum_ += static_cast(*x_.Element(&xAt)) * - static_cast(*y_.Element(&yAt)); - } - } - Result GetResult() const { return sum_; } - -private: - const Descriptor &x_, &y_; - Result sum_{0}; -}; - -template -class LogicalAccumulator { -public: - using Result = bool; - LogicalAccumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} - void Accumulate(SubscriptValue xAt, SubscriptValue yAt) { - result_ = result_ || - (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); - } - bool GetResult() const { return result_; } - -private: - const Descriptor &x_, &y_; - bool result_{false}; -}; - extern "C" { std::int8_t RTNAME(DotProductInteger1)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); 
} std::int16_t RTNAME(DotProductInteger2)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } std::int32_t RTNAME(DotProductInteger4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } std::int64_t RTNAME(DotProductInteger8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } #ifdef __SIZEOF_INT128__ common::int128_t RTNAME(DotProductInteger16)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } #endif // TODO: REAL/COMPLEX(2 & 3) float RTNAME(DotProductReal4)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } double RTNAME(DotProductReal8)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } #if LONG_DOUBLE == 80 long double RTNAME(DotProductReal10)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } #elif LONG_DOUBLE == 128 long double RTNAME(DotProductReal16)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } #endif void RTNAME(CppDotProductComplex4)(std::complex &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { - auto z{DotProduct{}( - x, y, source, line)}; + auto z{DotProduct{}(x, y, source, line)}; result = std::complex{ 
static_cast(z.real()), static_cast(z.imag())}; } void RTNAME(CppDotProductComplex8)(std::complex &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { - result = DotProduct{}( - x, y, source, line); + result = DotProduct{}(x, y, source, line); } #if LONG_DOUBLE == 80 void RTNAME(CppDotProductComplex10)(std::complex &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { - result = DotProduct{}( - x, y, source, line); + result = DotProduct{}(x, y, source, line); } #elif LONG_DOUBLE == 128 void RTNAME(CppDotProductComplex16)(std::complex &result, const Descriptor &x, const Descriptor &y, const char *source, int line) { - result = DotProduct{}( - x, y, source, line); + result = DotProduct{}(x, y, source, line); } #endif bool RTNAME(DotProductLogical)( const Descriptor &x, const Descriptor &y, const char *source, int line) { - return DotProduct{}( - x, y, source, line); + return DotProduct{}(x, y, source, line); } } // extern "C" } // namespace Fortran::runtime Index: flang/runtime/matmul.h =================================================================== --- /dev/null +++ flang/runtime/matmul.h @@ -0,0 +1,29 @@ +//===-- runtime/matmul.h ----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// API for the transformational intrinsic function MATMUL. + +#ifndef FORTRAN_RUNTIME_MATMUL_H_ +#define FORTRAN_RUNTIME_MATMUL_H_ +#include "entry-names.h" +namespace Fortran::runtime { +class Descriptor; +extern "C" { + +// The most general MATMUL. All type and shape information is taken from the +// arguments' descriptors, and the result is dynamically allocated. 
+void RTNAME(Matmul)(Descriptor &, const Descriptor &, const Descriptor &, + const char *sourceFile = nullptr, int line = 0); + +// A non-allocating variant; the result's descriptor must be established +// and have a valid base address. +void RTNAME(MatmulDirect)(const Descriptor &, const Descriptor &, + const Descriptor &, const char *sourceFile = nullptr, int line = 0); +} // extern "C" +} // namespace Fortran::runtime +#endif // FORTRAN_RUNTIME_MATMUL_H_ Index: flang/runtime/matmul.cpp =================================================================== --- /dev/null +++ flang/runtime/matmul.cpp @@ -0,0 +1,220 @@ +//===-- runtime/matmul.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// Implements all forms of MATMUL (Fortran 2018 16.9.124) +// +// There are two main entry points; one establishes a descriptor for the +// result and allocates it, and the other expects a result descriptor that +// points to existing storage. +// +// This implementation must handle all combinations of numeric types and +// kinds (100 - 165 cases depending on the target), plus all combinations +// of logical kinds (16). A single template undergoes many instantiations +// to cover all of the valid possibilities. +// +// Places where BLAS routines could be called are marked as TODO items. 
+ +#include "matmul.h" +#include "cpp-type.h" +#include "descriptor.h" +#include "terminator.h" +#include "tools.h" + +namespace Fortran::runtime { + +template +class Accumulator { +public: + // Accumulate floating-point results in (at least) double precision + using Result = CppTypeFor(sizeof(double))) + : RKIND>; + Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} + void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) { + if constexpr (RCAT == TypeCategory::Logical) { + sum_ = sum_ || + (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); + } else { + sum_ += static_cast(*x_.Element(xAt)) * + static_cast(*y_.Element(yAt)); + } + } + Result GetResult() const { return sum_; } + +private: + const Descriptor &x_, &y_; + Result sum_{}; +}; + +// Implements an instance of MATMUL for given argument types. +template +static inline void DoMatmul( + std::conditional_t &result, + const Descriptor &x, const Descriptor &y, Terminator &terminator) { + int xRank{x.rank()}; + int yRank{y.rank()}; + int resRank{xRank + yRank - 2}; + if (xRank * yRank != 2 * resRank) { + terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); + } + SubscriptValue extent[2]{ + xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), + resRank == 2 ? 
y.GetDimension(1).Extent() : 0}; + if constexpr (IS_ALLOCATING) { + result.Establish( + RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); + for (int j{0}; j < resRank; ++j) { + result.GetDimension(j).SetBounds(1, extent[j]); + } + if (int stat{result.Allocate()}) { + terminator.Crash( + "MATMUL: could not allocate memory for result; STAT=%d", stat); + } + } else { + RUNTIME_CHECK(terminator, resRank == result.rank()); + RUNTIME_CHECK(terminator, result.type() == (TypeCode{RCAT, RKIND})); + RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); + RUNTIME_CHECK(terminator, + resRank == 1 || result.GetDimension(1).Extent() == extent[1]); + } + using WriteResult = + CppTypeFor; + SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; + if (n != y.GetDimension(0).Extent()) { + terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)", + static_cast(n), + static_cast(y.GetDimension(0).Extent())); + } + SubscriptValue xAt[2], yAt[2], resAt[2]; + x.GetLowerBounds(xAt); + y.GetLowerBounds(yAt); + result.GetLowerBounds(resAt); + if (resRank == 2) { // M*M -> M + if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + // TODO: call BLAS-3 SGEMM + } else if constexpr (std::is_same_v) { + // TODO: call BLAS-3 DGEMM + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-3 CGEMM + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-3 ZGEMM + } + } + SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; + for (SubscriptValue i{0}; i < extent[0]; ++i) { + for (SubscriptValue j{0}; j < extent[1]; ++j) { + Accumulator accumulator{x, y}; + yAt[1] = y1 + j; + for (SubscriptValue k{0}; k < n; ++k) { + xAt[1] = x1 + k; + yAt[0] = y0 + k; + accumulator.Accumulate(xAt, yAt); + } + resAt[1] = res1 + j; + *result.template Element(resAt) = accumulator.GetResult(); + } + ++resAt[0]; + ++xAt[0]; + } + } else { + if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + // TODO: call BLAS-2 SGEMV + } else 
if constexpr (std::is_same_v) { + // TODO: call BLAS-2 DGEMV + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-2 CGEMV + } else if constexpr (std::is_same_v>) { + // TODO: call BLAS-2 ZGEMV + } + } + if (xRank == 2) { // M*V -> V + SubscriptValue x1{xAt[1]}, y0{yAt[0]}; + for (SubscriptValue j{0}; j < extent[0]; ++j) { + Accumulator accumulator{x, y}; + for (SubscriptValue k{0}; k < n; ++k) { + xAt[1] = x1 + k; + yAt[0] = y0 + k; + accumulator.Accumulate(xAt, yAt); + } + *result.template Element(resAt) = accumulator.GetResult(); + ++resAt[0]; + ++xAt[0]; + } + } else { // V*M -> V + SubscriptValue x0{xAt[0]}, y0{yAt[0]}; + for (SubscriptValue j{0}; j < extent[0]; ++j) { + Accumulator accumulator{x, y}; + for (SubscriptValue k{0}; k < n; ++k) { + xAt[0] = x0 + k; + yAt[0] = y0 + k; + accumulator.Accumulate(xAt, yAt); + } + *result.template Element(resAt) = accumulator.GetResult(); + ++resAt[0]; + ++yAt[1]; + } + } + } +} + +// Maps the dynamic type information from the arguments' descriptors +// to the right instantiation of DoMatmul() for valid combinations of +// types. 
+template struct Matmul { + using ResultDescriptor = + std::conditional_t; + template struct MM1 { + template struct MM2 { + void operator()(ResultDescriptor &result, const Descriptor &x, + const Descriptor &y, Terminator &terminator) const { + if constexpr (constexpr auto resultType{ + GetResultType(XCAT, XKIND, YCAT, YKIND)}) { + if constexpr (common::IsNumericTypeCategory(resultType->first) || + resultType->first == TypeCategory::Logical) { + return DoMatmulfirst, + resultType->second, CppTypeFor, + CppTypeFor>(result, x, y, terminator); + } + } + terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", + static_cast(XCAT), XKIND, static_cast(YCAT), YKIND); + } + }; + void operator()(ResultDescriptor &result, const Descriptor &x, + const Descriptor &y, Terminator &terminator, TypeCategory yCat, + int yKind) const { + ApplyType(yCat, yKind, terminator, result, x, y, terminator); + } + }; + void operator()(ResultDescriptor &result, const Descriptor &x, + const Descriptor &y, const char *sourceFile, int line) const { + Terminator terminator{sourceFile, line}; + auto xCatKind{x.type().GetCategoryAndKind()}; + auto yCatKind{y.type().GetCategoryAndKind()}; + RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); + ApplyType(xCatKind->first, xCatKind->second, terminator, result, + x, y, terminator, yCatKind->first, yCatKind->second); + } +}; + +extern "C" { +void RTNAME(Matmul)(Descriptor &result, const Descriptor &x, + const Descriptor &y, const char *sourceFile, int line) { + Matmul{}(result, x, y, sourceFile, line); +} +void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x, + const Descriptor &y, const char *sourceFile, int line) { + Matmul{}(result, x, y, sourceFile, line); +} +} // extern "C" +} // namespace Fortran::runtime Index: flang/runtime/reduction.h =================================================================== --- flang/runtime/reduction.h +++ flang/runtime/reduction.h @@ -7,9 +7,6 @@ 
//===----------------------------------------------------------------------===// // Defines the API for the reduction transformational intrinsic functions. -// (Except the complex-valued DOT_PRODUCT and the complex-valued total reduction -// forms of SUM & PRODUCT; the API for those is in complex-reduction.h so that -// C's _Complex can be used for their return types.) #ifndef FORTRAN_RUNTIME_REDUCTION_H_ #define FORTRAN_RUNTIME_REDUCTION_H_ @@ -36,10 +33,10 @@ // results in a caller-supplied descriptor, which is assumed to // be large enough. // -// Complex-valued SUM and PRODUCT reductions have their API -// entry points defined in complex-reduction.h; these are C wrappers -// around C++ implementations so as to keep usage of C's _Complex -// types out of C++ code. +// Complex-valued SUM and PRODUCT reductions and complex-valued +// DOT_PRODUCT have their API entry points defined in complex-reduction.h; +// these here are C wrappers around C++ implementations so as to keep +// usage of C's _Complex types out of C++ code. // SUM() Index: flang/unittests/RuntimeGTest/CMakeLists.txt =================================================================== --- flang/unittests/RuntimeGTest/CMakeLists.txt +++ flang/unittests/RuntimeGTest/CMakeLists.txt @@ -2,6 +2,7 @@ CharacterTest.cpp CrashHandlerFixture.cpp Format.cpp + Matmul.cpp MiscIntrinsic.cpp Namelist.cpp Numeric.cpp Index: flang/unittests/RuntimeGTest/Matmul.cpp =================================================================== --- /dev/null +++ flang/unittests/RuntimeGTest/Matmul.cpp @@ -0,0 +1,98 @@ +//===-- flang/unittests/RuntimeGTest/Matmul.cpp---- -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../../runtime/matmul.h" +#include "gtest/gtest.h" +#include "tools.h" +#include "../../runtime/allocatable.h" +#include "../../runtime/cpp-type.h" +#include "../../runtime/descriptor.h" +#include "../../runtime/type-code.h" + +using namespace Fortran::runtime; +using Fortran::common::TypeCategory; + +TEST(Matmul, Basic) { + // X 0 2 4 Y 6 9 V -1 -2 + // 1 3 5 7 10 + // 8 11 + auto x{MakeArray( + std::vector{2, 3}, std::vector{0, 1, 2, 3, 4, 5})}; + auto y{MakeArray( + std::vector{3, 2}, std::vector{6, 7, 8, 9, 10, 11})}; + auto v{MakeArray( + std::vector{2}, std::vector{-1, -2})}; + StaticDescriptor<2> statDesc; + Descriptor &result{statDesc.descriptor()}; + + RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + + std::memset( + result.raw().base_addr, 0, result.Elements() * result.ElementBytes()); + result.GetDimension(0).SetLowerBound(0); + result.GetDimension(1).SetLowerBound(2); + RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 1); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 
1); + EXPECT_EQ(result.GetDimension(0).Extent(), 3); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), -2); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), -8); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), -14); + result.Destroy(); + + RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 1); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 3); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), -24); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), -27); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), -30); + result.Destroy(); + + // X F F T Y F T + // F T F F T + // F F + auto xLog{MakeArray(std::vector{2, 3}, + std::vector{false, false, false, true, true, false})}; + auto yLog{MakeArray(std::vector{3, 2}, + std::vector{false, false, false, true, true, false})}; + RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2})); + EXPECT_FALSE( + static_cast(*result.ZeroBasedIndexedElement(0))); + EXPECT_FALSE( + static_cast(*result.ZeroBasedIndexedElement(1))); + EXPECT_FALSE( + static_cast(*result.ZeroBasedIndexedElement(2))); + EXPECT_TRUE( + static_cast(*result.ZeroBasedIndexedElement(3))); +}