diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Runtime/matmul-transpose.h @@ -0,0 +1,30 @@ +//===-- include/flang/Runtime/matmul-transpose.h ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// API for optimised MATMUL(TRANSPOSE(a), b) + +#ifndef FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_ +#define FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_ +#include "flang/Runtime/entry-names.h" +namespace Fortran::runtime { +class Descriptor; +extern "C" { + +// The most general MATMUL(TRANSPOSE(x), y). All type and shape information +// is taken from the arguments' descriptors, and the result is dynamically +// allocated. sourceFile and line are handed to the runtime Terminator. +void RTNAME(MatmulTranspose)(Descriptor &, const Descriptor &, + const Descriptor &, const char *sourceFile = nullptr, int line = 0); + +// A non-allocating variant; the result's descriptor must be established, +// have a valid base address, and match the product's rank and extents.
+void RTNAME(MatmulTransposeDirect)(const Descriptor &, const Descriptor &, + const Descriptor &, const char *sourceFile = nullptr, int line = 0); +} // extern "C" +} // namespace Fortran::runtime +#endif // FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_ diff --git a/flang/runtime/CMakeLists.txt b/flang/runtime/CMakeLists.txt --- a/flang/runtime/CMakeLists.txt +++ b/flang/runtime/CMakeLists.txt @@ -125,6 +125,7 @@ io-error.cpp io-stmt.cpp main.cpp + matmul-transpose.cpp matmul.cpp memory.cpp misc-intrinsic.cpp diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp new file mode 100644 --- /dev/null +++ b/flang/runtime/matmul-transpose.cpp @@ -0,0 +1,296 @@ +//===-- runtime/matmul-transpose.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// Implements a fused matmul-transpose operation +// +// There are two main entry points; one establishes a descriptor for the +// result and allocates it, and the other expects a result descriptor that +// points to existing storage. +// +// This implementation must handle all combinations of numeric types and +// kinds (100 - 165 cases depending on the target), plus all combinations +// of logical kinds (16). A single template undergoes many instantiations +// to cover all of the valid possibilities. +// +// The usefulness of this optimization should be reviewed once Matmul is swapped +// to use the faster BLAS routines.
+ +#include "flang/Runtime/matmul-transpose.h" +#include "terminator.h" +#include "tools.h" +#include "flang/Runtime/c-or-cpp.h" +#include "flang/Runtime/cpp-type.h" +#include "flang/Runtime/descriptor.h" +#include + +namespace { +using namespace Fortran::runtime; + +// Contiguous numeric TRANSPOSE(matrix)*matrix multiplication +// TRANSPOSE(matrix(n, rows)) * matrix(n,cols) -> +// matrix(rows, n) * matrix(n,cols) -> matrix(rows,cols) +// The transpose is implemented by swapping the indices of accesses into the LHS +// +// Straightforward algorithm: +// DO 1 I = 1, NROWS +// DO 1 J = 1, NCOLS +// RES(I,J) = 0 +// DO 1 K = 1, N +// 1 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) +// +// With loop distribution and interchange so that the innermost loop +// runs over K with unit strides in both operands: +// DO 1 I = 1, NROWS +// DO 1 J = 1, NCOLS +// 1 RES(I,J) = 0 +// DO 2 J = 1, NCOLS +// DO 2 I = 1, NROWS +// DO 2 K = 1, N +// 2 RES(I,J) = RES(I,J) + X(K,I)*Y(K,J) +template +inline static void MatrixTransposedTimesMatrix( + CppTypeFor *RESTRICT product, SubscriptValue rows, + SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, + SubscriptValue n) { + using ResultType = CppTypeFor; + + std::memset(product, 0, rows * cols * sizeof *product); + for (SubscriptValue j{0}; j < cols; ++j) { + for (SubscriptValue i{0}; i < rows; ++i) { + for (SubscriptValue k{0}; k < n; ++k) { + ResultType x_ki = static_cast(x[i * n + k]); + ResultType y_kj = static_cast(y[j * n + k]); + product[j * rows + i] += x_ki * y_kj; + } + } + } +} + +// Contiguous numeric TRANSPOSE(matrix)*vector multiplication +// TRANSPOSE(matrix(n,rows)) * column vector(n) -> column vector(rows) +// Straightforward algorithm: +// DO 1 I = 1, NROWS +// RES(I) = 0 +// DO 1 K = 1, N +// 1 RES(I) = RES(I) + X(K,I)*Y(K) +// With loop distribution, keeping K innermost for unit-stride +// accesses to both operands: +// DO 1 I = 1, NROWS +// 1 RES(I) = 0 +// DO 2 I = 1, NROWS +// DO 2 K = 1, N +// 2
RES(I) = RES(I) + X(K,I)*Y(K) +template +inline static void MatrixTransposedTimesVector( + CppTypeFor *RESTRICT product, SubscriptValue rows, + SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) { + using ResultType = CppTypeFor; + std::memset(product, 0, rows * sizeof *product); + for (SubscriptValue i{0}; i < rows; ++i) { + for (SubscriptValue k{0}; k < n; ++k) { + ResultType x_ki = static_cast(x[i * n + k]); + ResultType y_k = static_cast(y[k]); + product[i] += x_ki * y_k; + } + } +} + +// Implements an instance of MATMUL(TRANSPOSE()) for given argument types. +template +inline static void DoMatmulTranspose( + std::conditional_t &result, + const Descriptor &x, const Descriptor &y, Terminator &terminator) { + int xRank{x.rank()}; + int yRank{y.rank()}; + int resRank{xRank + yRank - 2}; + if (xRank != 2 || yRank < 1 || yRank > 2) { // x.GetDimension(1) is read next + terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); + } + SubscriptValue extent[2]{x.GetDimension(1).Extent(), + resRank == 2 ? y.GetDimension(1).Extent() : 0}; + if constexpr (IS_ALLOCATING) { + result.Establish( + RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); + for (int j{0}; j < resRank; ++j) { + result.GetDimension(j).SetBounds(1, extent[j]); + } + if (int stat{result.Allocate()}) { + terminator.Crash( + "MATMUL: could not allocate memory for result; STAT=%d", stat); + } + } else { + RUNTIME_CHECK(terminator, resRank == result.rank()); + RUNTIME_CHECK( + terminator, result.ElementBytes() == static_cast(RKIND)); + RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); + RUNTIME_CHECK(terminator, + resRank == 1 || result.GetDimension(1).Extent() == extent[1]); + } + SubscriptValue n{x.GetDimension(0).Extent()}; + if (n != y.GetDimension(0).Extent()) { + terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", + static_cast(x.GetDimension(0).Extent()), + static_cast(x.GetDimension(1).Extent()), + static_cast(y.GetDimension(0).Extent()), +
static_cast(y.GetDimension(1).Extent())); + } + using WriteResult = + CppTypeFor; + const SubscriptValue rows{extent[0]}; + const SubscriptValue cols{extent[1]}; + if constexpr (RCAT != TypeCategory::Logical) { + if (x.IsContiguous() && y.IsContiguous() && + (IS_ALLOCATING || result.IsContiguous())) { + // Contiguous numeric matrices + if (resRank == 2) { // M*M -> M + MatrixTransposedTimesMatrix( + result.template OffsetElement(), rows, cols, + x.OffsetElement(), y.OffsetElement(), n); + return; + } + if (xRank == 2) { // M*V -> V + MatrixTransposedTimesVector( + result.template OffsetElement(), rows, n, + x.OffsetElement(), y.OffsetElement()); + return; + } + // else V*M -> V (not allowed because TRANSPOSE() is only defined for rank + // 2 matrices) + terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", + static_cast(x.GetDimension(0).Extent()), + static_cast(n), + static_cast(y.GetDimension(0).Extent()), + static_cast(y.GetDimension(1).Extent())); + return; + } + } + // General algorithms for LOGICAL and noncontiguity + SubscriptValue xLB[2], yLB[2], resLB[2]; + x.GetLowerBounds(xLB); + y.GetLowerBounds(yLB); + result.GetLowerBounds(resLB); + using ResultType = CppTypeFor; + if (resRank == 2) { // M*M -> M + for (SubscriptValue i{0}; i < rows; ++i) { + for (SubscriptValue j{0}; j < cols; ++j) { + ResultType res_ij; + if constexpr (RCAT == TypeCategory::Logical) { + res_ij = false; + } else { + res_ij = 0; + } + + for (SubscriptValue k{0}; k < n; ++k) { + SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; + SubscriptValue yAt[2]{k + yLB[0], j + yLB[1]}; + if constexpr (RCAT == TypeCategory::Logical) { + ResultType x_ki = IsLogicalElementTrue(x, xAt); + ResultType y_kj = IsLogicalElementTrue(y, yAt); + res_ij = res_ij || (x_ki && y_kj); + } else { + ResultType x_ki = static_cast(*x.Element(xAt)); + ResultType y_kj = static_cast(*y.Element(yAt)); + res_ij += x_ki * y_kj; + } + } + SubscriptValue resAt[2]{i + resLB[0], j + resLB[1]}; +
*result.template Element(resAt) = res_ij; + } + } + } else if (xRank == 2) { // M*V -> V + for (SubscriptValue i{0}; i < rows; ++i) { + ResultType res_i; + if constexpr (RCAT == TypeCategory::Logical) { + res_i = false; + } else { + res_i = 0; + } + + for (SubscriptValue k{0}; k < n; ++k) { + SubscriptValue xAt[2]{k + xLB[0], i + xLB[1]}; + SubscriptValue yAt[1]{k + yLB[0]}; + if constexpr (RCAT == TypeCategory::Logical) { + ResultType x_ki = IsLogicalElementTrue(x, xAt); + ResultType y_k = IsLogicalElementTrue(y, yAt); + res_i = res_i || (x_ki && y_k); + } else { + ResultType x_ki = static_cast(*x.Element(xAt)); + ResultType y_k = static_cast(*y.Element(yAt)); + res_i += x_ki * y_k; + } + } + SubscriptValue resAt[1]{i + resLB[0]}; + *result.template Element(resAt) = res_i; + } + } else { // V*M -> V + // TRANSPOSE(V) not allowed by Fortran standard + terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", + static_cast(x.GetDimension(0).Extent()), + static_cast(n), + static_cast(y.GetDimension(0).Extent()), + static_cast(y.GetDimension(1).Extent())); + } +} + +// Maps the dynamic type information from the arguments' descriptors +// to the right instantiation of DoMatmulTranspose() for valid combinations of +// types.
+template struct MatmulTranspose { + using ResultDescriptor = + std::conditional_t; + template struct MM1 { + template struct MM2 { + void operator()(ResultDescriptor &result, const Descriptor &x, + const Descriptor &y, Terminator &terminator) const { + if constexpr (constexpr auto resultType{ + GetResultType(XCAT, XKIND, YCAT, YKIND)}) { + if constexpr (Fortran::common::IsNumericTypeCategory( + resultType->first) || + resultType->first == TypeCategory::Logical) { + return DoMatmulTransposefirst, + resultType->second, CppTypeFor, + CppTypeFor>(result, x, y, terminator); + } + } // otherwise: no numeric or logical result type for this operand pair + terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", + static_cast(XCAT), XKIND, static_cast(YCAT), YKIND); + } + }; + void operator()(ResultDescriptor &result, const Descriptor &x, + const Descriptor &y, Terminator &terminator, TypeCategory yCat, + int yKind) const { // second step: dispatch on y's category and kind + ApplyType(yCat, yKind, terminator, result, x, y, terminator); + } + }; + void operator()(ResultDescriptor &result, const Descriptor &x, + const Descriptor &y, const char *sourceFile, int line) const { + Terminator terminator{sourceFile, line}; // carries the caller's position + auto xCatKind{x.type().GetCategoryAndKind()}; + auto yCatKind{y.type().GetCategoryAndKind()}; + RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); + ApplyType(xCatKind->first, xCatKind->second, terminator, result, + x, y, terminator, yCatKind->first, yCatKind->second); + } +}; +} // namespace + +namespace Fortran::runtime { +extern "C" { +void RTNAME(MatmulTranspose)(Descriptor &result, const Descriptor &x, + const Descriptor &y, const char *sourceFile, int line) { + MatmulTranspose{}(result, x, y, sourceFile, line); +} +void RTNAME(MatmulTransposeDirect)(const Descriptor &result, + const Descriptor &x, const Descriptor &y, const char *sourceFile, + int line) { + MatmulTranspose{}(result, x, y, sourceFile, line); +} +} // extern "C" +} // namespace Fortran::runtime diff --git a/flang/unittests/Runtime/CMakeLists.txt
b/flang/unittests/Runtime/CMakeLists.txt --- a/flang/unittests/Runtime/CMakeLists.txt +++ b/flang/unittests/Runtime/CMakeLists.txt @@ -12,6 +12,7 @@ Inquiry.cpp ListInputTest.cpp Matmul.cpp + MatmulTranspose.cpp MiscIntrinsic.cpp Namelist.cpp Numeric.cpp diff --git a/flang/unittests/Runtime/MatmulTranspose.cpp b/flang/unittests/Runtime/MatmulTranspose.cpp new file mode 100644 --- /dev/null +++ b/flang/unittests/Runtime/MatmulTranspose.cpp @@ -0,0 +1,128 @@ +//===-- flang/unittests/Runtime/MatmulTranspose.cpp -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gtest/gtest.h" +#include "tools.h" +#include "flang/Runtime/allocatable.h" +#include "flang/Runtime/cpp-type.h" +#include "flang/Runtime/descriptor.h" +#include "flang/Runtime/matmul-transpose.h" +#include "flang/Runtime/type-code.h" + +using namespace Fortran::runtime; +using Fortran::common::TypeCategory; + +TEST(MatmulTranspose, Basic) { + // X 0 1 Y 6 9 Z 6 7 8 M 0 0 1 1 V -1 -2 + // 2 3 7 10 9 10 11 0 1 0 1 + // 4 5 8 11 + + auto x{MakeArray( + std::vector{3, 2}, std::vector{0, 2, 4, 1, 3, 5})}; + auto y{MakeArray( + std::vector{3, 2}, std::vector{6, 7, 8, 9, 10, 11})}; + auto z{MakeArray( + std::vector{2, 3}, std::vector{6, 9, 7, 10, 8, 11})}; + auto m{MakeArray(std::vector{2, 4}, + std::vector{0, 0, 0, 1, 1, 0, 1, 1})}; + auto v{MakeArray( + std::vector{2}, std::vector{-1, -2})}; + StaticDescriptor<2, true> statDesc; + Descriptor &result{statDesc.descriptor()}; + + RTNAME(MatmulTranspose)(result, *x, *y, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); +
EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + + std::memset( + result.raw().base_addr, 0, result.Elements() * result.ElementBytes()); + result.GetDimension(0).SetLowerBound(0); + result.GetDimension(1).SetLowerBound(2); + RTNAME(MatmulTransposeDirect)(result, *x, *y, __FILE__, __LINE__); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 46); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 67); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 64); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 94); + result.Destroy(); + + RTNAME(MatmulTranspose)(result, *z, *v, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 1); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 3); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), -24); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), -27); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), -30); + result.Destroy(); + + RTNAME(MatmulTranspose)(result, *m, *z, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + ASSERT_EQ(result.GetDimension(0).LowerBound(), 1); + ASSERT_EQ(result.GetDimension(0).UpperBound(), 4); + ASSERT_EQ(result.GetDimension(1).LowerBound(), 1); + ASSERT_EQ(result.GetDimension(1).UpperBound(), 3); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 2})); + EXPECT_EQ(*result.ZeroBasedIndexedElement(0), 0); + EXPECT_EQ(*result.ZeroBasedIndexedElement(1), 9); + EXPECT_EQ(*result.ZeroBasedIndexedElement(2), 6); + EXPECT_EQ(*result.ZeroBasedIndexedElement(3), 15); + EXPECT_EQ(*result.ZeroBasedIndexedElement(4), 0); + EXPECT_EQ(*result.ZeroBasedIndexedElement(5), 10); + EXPECT_EQ(*result.ZeroBasedIndexedElement(6), 7); +
EXPECT_EQ(*result.ZeroBasedIndexedElement(7), 17); + EXPECT_EQ(*result.ZeroBasedIndexedElement(8), 0); + EXPECT_EQ(*result.ZeroBasedIndexedElement(9), 11); + EXPECT_EQ(*result.ZeroBasedIndexedElement(10), 8); + EXPECT_EQ(*result.ZeroBasedIndexedElement(11), 19); + result.Destroy(); + + // X F F Y F T V T F T + // T F F T + // T T F F + auto xLog{MakeArray(std::vector{3, 2}, + std::vector{false, true, true, false, false, true})}; + auto yLog{MakeArray(std::vector{3, 2}, + std::vector{false, false, false, true, true, false})}; + auto vLog{MakeArray( + std::vector{3}, std::vector{true, false, true})}; + RTNAME(MatmulTranspose)(result, *xLog, *yLog, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 2); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(1).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2})); + EXPECT_FALSE( + static_cast(*result.ZeroBasedIndexedElement(0))); + EXPECT_FALSE( + static_cast(*result.ZeroBasedIndexedElement(1))); + EXPECT_TRUE( + static_cast(*result.ZeroBasedIndexedElement(2))); + EXPECT_FALSE( + static_cast(*result.ZeroBasedIndexedElement(3))); + result.Destroy(); // free before result is established again + + RTNAME(MatmulTranspose)(result, *yLog, *vLog, __FILE__, __LINE__); + ASSERT_EQ(result.rank(), 1); + EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); + EXPECT_EQ(result.GetDimension(0).Extent(), 2); + ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2})); + EXPECT_FALSE( + static_cast(*result.ZeroBasedIndexedElement(0))); + EXPECT_TRUE( + static_cast(*result.ZeroBasedIndexedElement(1))); + result.Destroy(); // release the last allocated result +}