diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -31,11 +31,15 @@ #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS #endif // _WIN32 +#include +#include #include +#include //===----------------------------------------------------------------------===// // Codegen-compatible structures for Vector type. //===----------------------------------------------------------------------===// +namespace mlir { namespace detail { constexpr bool isPowerOf2(int N) { return (!(N & (N - 1))); } @@ -65,9 +69,8 @@ template struct Vector1D { Vector1D() { - static_assert(detail::nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), - "size error"); - static_assert(detail::nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]), + static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error"); // Unqualified lookup finds mlir::detail::nextPowerOf2 from inside this namespace. + static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]), "size error"); } inline T &operator[](unsigned i) { return vector[i]; } @@ -75,9 +78,10 @@ private: T vector[Dim]; - char padding[detail::nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])]; + char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])]; }; } // end namespace detail +} // end namespace mlir // N-D vectors recurse down to 1-D. template @@ -95,7 +99,9 @@ // We insert explicit padding in to account for this. template struct Vector - : public detail::Vector1D {}; + : public mlir::detail::Vector1D { +}; template using Vector1D = Vector; @@ -115,6 +121,9 @@ //===----------------------------------------------------------------------===// // Codegen-compatible structures for StridedMemRef type. //===----------------------------------------------------------------------===// +template +class StridedMemrefIterator; // Defined below; begin()/end() of the StridedMemRefType specializations return it. + /// StridedMemRef descriptor type with static rank.
template struct StridedMemRefType { @@ -123,6 +132,23 @@ int64_t offset; int64_t sizes[N]; int64_t strides[N]; + + template + T &operator[](Range indices) { + assert(indices.size() == N && + "indices should match rank in memref subscript"); + int64_t curOffset = offset; + for (int dim = N - 1; dim >= 0; --dim) { + int64_t currentIndex = *(indices.begin() + dim); + assert(currentIndex >= 0 && currentIndex < sizes[dim] && + "Index out of bounds"); + curOffset += currentIndex * strides[dim]; + } + return data[curOffset]; + } + + // Start at `offset` so iteration visits the same elements as operator[]. + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, -1}; } + // This operator[] is extremely slow and only for sugaring purposes. StridedMemRefType operator[](int64_t idx) { StridedMemRefType res; @@ -143,6 +169,17 @@ int64_t offset; int64_t sizes[1]; int64_t strides[1]; + + template + T &operator[](Range indices) { + assert(indices.size() == 1 && + "indices should match rank in memref subscript"); + return (*this)[*indices.begin()]; + } + + // Start at `offset`, matching operator[](int64_t) below. + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, -1}; } + T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); } }; @@ -152,6 +189,99 @@ T *basePtr; T *data; int64_t offset; + + template + T &operator[](Range indices) { + assert(indices.size() == 0 && + "Expect no indices for 0-rank memref subscript"); + return data[offset]; + } + + // The single element lives at data[offset]; end() is one past it. + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, offset + 1}; } +}; + +/// Iterate over all elements in a strided memref.
+template +class StridedMemrefIterator { +public: + StridedMemrefIterator(StridedMemRefType &descriptor, + int64_t offset = 0) + : offset(offset), descriptor(descriptor) {} + StridedMemrefIterator &operator++() { + int dim = Rank - 1; + while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) { + offset -= indices[dim] * descriptor.strides[dim]; + indices[dim] = 0; + --dim; + } + if (dim < 0) { + offset = -1; + return *this; + } + ++indices[dim]; + offset += descriptor.strides[dim]; + return *this; + } + + T &operator*() { return descriptor.data[offset]; } + T *operator->() { return &descriptor.data[offset]; } + + const std::array &getIndices() const { return indices; } + + bool operator==(const StridedMemrefIterator &other) const { + return other.offset == offset && &other.descriptor == &descriptor; + } + + bool operator!=(const StridedMemrefIterator &other) const { + return !(*this == other); + } + +private: + /// Offset in the buffer; -1 marks the past-the-end iterator. This can be + /// derived from the indices and the descriptor. + int64_t offset = 0; + /// Array of indices in the multi-dimensional memref. + std::array indices = {}; + /// Descriptor for the strided memref. + StridedMemRefType &descriptor; +}; + +/// Iterate over all elements in a 0-ranked strided memref. +template +class StridedMemrefIterator { +public: + StridedMemrefIterator(StridedMemRefType &descriptor, int64_t offset = 0) + : elt(descriptor.data + offset) {} + + StridedMemrefIterator &operator++() { + ++elt; + return *this; + } + + T &operator*() { return *elt; } + T *operator->() { return elt; } + + // There are no indices for a 0-ranked memref, but this API is provided for + // consistency with the general case. + const std::array &getIndices() const { + // Since this is a 0-array of indices we can keep a single function-local + // static copy.
+ static const std::array indices = {}; + return indices; + } + + bool operator==(const StridedMemrefIterator &other) const { + return other.elt == elt; + } + + bool operator!=(const StridedMemrefIterator &other) const { + return !(*this == other); + } + +private: + /// Pointer to the single element in the zero-ranked memref. + T *elt; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h @@ -0,0 +1,214 @@ +//===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utils for MLIR ABI interfacing with frameworks. +// +// The templated free functions below make it possible to allocate dense +// contiguous buffers with shapes that interoperate properly with the MLIR +// codegen ABI. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ +#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" + +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include +#include +#include + +namespace mlir { +using AllocFunType = llvm::function_ref; + +namespace detail { + +/// Given a shape with sizes greater than 0 along all dimensions, returns the +/// distance, in number of elements, between a slice in a dimension and the next +/// slice in the same dimension. +/// e.g.
shape[3, 4, 5] -> strides[20, 5, 1] +template +inline std::array makeStrides(ArrayRef shape) { + assert(shape.size() == N && "expect shape specification to match rank"); + std::array res; + int64_t running = 1; + for (int64_t idx = N - 1; idx >= 0; --idx) { + assert(shape[idx] > 0 && "size must be positive for all shape dimensions"); + res[idx] = running; + running *= shape[idx]; + } + return res; +} + +/// Build a `StridedMemRefDescriptor` that matches the MLIR ABI. +/// This is an implementation detail that is kept in sync with MLIR codegen +/// conventions. Additionally takes a `shapeAlloc` array which +/// is used instead of `shape` to allocate "more aligned" data and compute the +/// corresponding strides. +template +typename std::enable_if<(N >= 1), StridedMemRefType>::type +makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef shape, + ArrayRef shapeAlloc) { + assert(shape.size() == N); + assert(shapeAlloc.size() == N); + StridedMemRefType descriptor; + descriptor.basePtr = static_cast(ptr); + descriptor.data = static_cast(alignedPtr); + descriptor.offset = 0; + std::copy(shape.begin(), shape.end(), descriptor.sizes); + auto strides = makeStrides(shapeAlloc); + std::copy(strides.begin(), strides.end(), descriptor.strides); + return descriptor; +} + +/// Build a `StridedMemRefDescriptor` that matches the MLIR ABI. +/// This is an implementation detail that is kept in sync with MLIR codegen +/// conventions. Additionally takes a `shapeAlloc` array which +/// is used instead of `shape` to allocate "more aligned" data and compute the +/// corresponding strides.
+template +typename std::enable_if<(N == 0), StridedMemRefType>::type +makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef shape = {}, + ArrayRef shapeAlloc = {}) { + assert(shape.size() == N); + assert(shapeAlloc.size() == N); + StridedMemRefType descriptor; + descriptor.basePtr = static_cast(ptr); + descriptor.data = static_cast(alignedPtr); + descriptor.offset = 0; + return descriptor; +} + +/// Allocate `nElements` of type T through `allocFun` and return a pair of +/// (pointer returned by `allocFun`, pointer into that buffer aligned to +/// `alignment`). This is a portable replacement for `posix_memalign`; the +/// memory must be released through the first (unaligned) pointer. +/// `alignment` must be a power of 2 and at least sizeof(T). By default the +/// alignment is nextPowerOf2(sizeof(T)). +template +std::pair +allocAligned(size_t nElements, AllocFunType allocFun = &::malloc, + llvm::Optional alignment = llvm::Optional()) { + assert(sizeof(T) < (1ul << 32) && "Elemental type overflows"); + auto size = nElements * sizeof(T); + auto desiredAlignment = alignment.getValueOr(nextPowerOf2(sizeof(T))); + assert((desiredAlignment & (desiredAlignment - 1)) == 0); + assert(desiredAlignment >= sizeof(T)); + // Over-allocate by `desiredAlignment` bytes so that an aligned address + // followed by `size` bytes always fits inside the buffer. + T *data = reinterpret_cast(allocFun(size + desiredAlignment)); + uintptr_t addr = reinterpret_cast(data); + uintptr_t rem = addr % desiredAlignment; + T *alignedData = (rem == 0) + ? data + : reinterpret_cast(addr + (desiredAlignment - rem)); + assert(reinterpret_cast(alignedData) % desiredAlignment == 0); + return std::make_pair(data, alignedData); +} + +} // namespace detail + +//===----------------------------------------------------------------------===// +// Public API +//===----------------------------------------------------------------------===// + +/// Convenient callback to "visit" a memref element by element. +/// This takes a reference to an individual element as well as the coordinates. +/// It can be used in conjunction with a StridedMemrefIterator. +template +using ElementWiseVisitor = llvm::function_ref)>; + +/// Owning MemRef type that abstracts over the runtime type for ranked strided +/// memref.
+template +class OwningMemRef { +public: + using DescriptorType = StridedMemRefType; + using FreeFunType = std::function; + + /// Allocate a new dense StridedMemRef with a given `shape`. An optional + /// `shapeAlloc` array can be supplied to "pad" every dimension individually. + /// If an ElementWiseVisitor is provided, it will be used to initialize the + /// data, else the memory will be zero-initialized. The alloc and free method + /// used to manage the data allocation can be optionally provided, and default + /// to malloc/free. The default deleter frees `basePtr` (the pointer returned + /// by `allocFun`), not the aligned `data` pointer. + OwningMemRef( + ArrayRef shape, ArrayRef shapeAlloc = {}, + ElementWiseVisitor init = {}, + llvm::Optional alignment = llvm::Optional(), + AllocFunType allocFun = &::malloc, + std::function)> freeFun = + [](StridedMemRefType descriptor) { + ::free(descriptor.basePtr); + }) + : freeFunc(std::move(freeFun)) { + if (shapeAlloc.empty()) + shapeAlloc = shape; + assert(shape.size() == Rank); + assert(shapeAlloc.size() == Rank); + for (unsigned i = 0; i < Rank; ++i) + assert(shape[i] <= shapeAlloc[i] && + "shapeAlloc must be greater than or equal to shape"); + int64_t nElements = 1; + for (int64_t s : shapeAlloc) + nElements *= s; + T *data, *alignedData; + std::tie(data, alignedData) = + detail::allocAligned(nElements, allocFun, alignment); + descriptor = detail::makeStridedMemRefDescriptor(data, alignedData, + shape, shapeAlloc); + if (init) { + for (StridedMemrefIterator it = descriptor.begin(), + end = descriptor.end(); + it != end; ++it) + init(*it, it.getIndices()); + } else { + // Zero exactly the nElements * sizeof(T) bytes that follow the aligned + // pointer. Part of the alignment slack may precede `data`, so writing + // size + alignment bytes from `data` would overrun the allocation. + memset(descriptor.data, 0, nElements * sizeof(T)); + } + } + /// Take ownership of an existing descriptor with a custom deleter.
+ OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc) + : freeFunc(freeFunc), descriptor(descriptor) {} + ~OwningMemRef() { + if (freeFunc) + freeFunc(descriptor); + } + OwningMemRef(const OwningMemRef &) = delete; + OwningMemRef &operator=(const OwningMemRef &) = delete; + /// Release the currently owned buffer (if any), then take over `other`'s + /// descriptor and deleter, leaving `other` empty. + OwningMemRef &operator=(OwningMemRef &&other) { + if (this == &other) + return *this; + if (freeFunc) + freeFunc(descriptor); + freeFunc = std::move(other.freeFunc); + descriptor = other.descriptor; + // Clear explicitly: a moved-from std::function is in an unspecified state. + other.freeFunc = nullptr; + memset(&other.descriptor, 0, sizeof(other.descriptor)); + return *this; + } + // Start with no deleter and a zeroed descriptor, then steal from `other`. + OwningMemRef(OwningMemRef &&other) : descriptor() { + *this = std::move(other); + } + + DescriptorType &operator*() { return descriptor; } + DescriptorType *operator->() { return &descriptor; } + T &operator[](std::initializer_list indices) { + return descriptor[std::move(indices)]; + } + +private: + /// Custom deleter used to release the data buffer managed by the descriptor + /// below. + FreeFunType freeFunc; + /// The descriptor is an instance of StridedMemRefType. + DescriptorType descriptor; +}; + +} // namespace mlir + +#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/MemRefUtils.h" #include "mlir/ExecutionEngine/RunnerUtils.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" @@ -89,4 +90,163 @@ ASSERT_EQ(result, 42.f); } +TEST(NativeMemRefJit, ZeroRankMemref) { + OwningMemRef A({}); + A[{}] = 42.; + ASSERT_EQ(*A->data, 42); + A[{}] = 0; + std::string moduleStr = R"mlir( + func @zero_ranked(%arg0 : memref) attributes { llvm.emit_c_interface } { + %cst42 = constant 42.0 : f32 + store %cst42, %arg0[] : memref + return + } + )mlir"; + MLIRContext context;
+ registerAllDialects(context.getDialectRegistry()); + auto module = parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); + auto jitOrError = ExecutionEngine::create(*module); + ASSERT_TRUE(!!jitOrError); + auto jit = std::move(jitOrError.get()); + + llvm::Error error = jit->invoke("zero_ranked", &*A); + ASSERT_TRUE(!error); + EXPECT_EQ((A[{}]), 42.); + for (float &elt : *A) + EXPECT_EQ(&elt, &(A[{}])); +} + +TEST(NativeMemRefJit, RankOneMemref) { + int64_t shape[] = {9}; + OwningMemRef A(shape); + int count = 1; + for (float &elt : *A) { + EXPECT_EQ(&elt, &(A[{count - 1}])); + elt = count++; + } + + std::string moduleStr = R"mlir( + func @one_ranked(%arg0 : memref) attributes { llvm.emit_c_interface } { + %cst42 = constant 42.0 : f32 + %cst5 = constant 5 : index + store %cst42, %arg0[%cst5] : memref + return + } + )mlir"; + MLIRContext context; + registerAllDialects(context.getDialectRegistry()); + auto module = parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); + auto jitOrError = ExecutionEngine::create(*module); + ASSERT_TRUE(!!jitOrError); + auto jit = std::move(jitOrError.get()); + + llvm::Error error = jit->invoke("one_ranked", &*A); + ASSERT_TRUE(!error); + count = 1; + for (float &elt : *A) { + if (count == 6) + EXPECT_EQ(elt, 42.); + else + EXPECT_EQ(elt, count); + count++; + } +} + +TEST(NativeMemRefJit, BasicMemref) { + constexpr int K = 3; + constexpr int M = 7; + // Prepare arguments beforehand.
+ auto init = [=](float &elt, ArrayRef indices) { + assert(indices.size() == 2); + elt = M * indices[0] + indices[1]; + }; + int64_t shape[] = {K, M}; + int64_t shapeAlloc[] = {K + 1, M + 1}; + OwningMemRef A(shape, shapeAlloc, init); + ASSERT_EQ(A->sizes[0], K); + ASSERT_EQ(A->sizes[1], M); + ASSERT_EQ(A->strides[0], M + 1); + ASSERT_EQ(A->strides[1], 1); + for (int i = 0; i < K; ++i) + for (int j = 0; j < M; ++j) + EXPECT_EQ((A[{i, j}]), i * M + j); + + std::string moduleStr = R"mlir( + func @rank2_memref(%arg0 : memref, %arg1 : memref) attributes { llvm.emit_c_interface } { + %x = constant 2 : index + %y = constant 1 : index + %cst42 = constant 42.0 : f32 + store %cst42, %arg0[%y, %x] : memref + store %cst42, %arg1[%x, %y] : memref + return + } + )mlir"; + MLIRContext context; + registerAllDialects(context.getDialectRegistry()); + OwningModuleRef module = parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); + auto jitOrError = ExecutionEngine::create(*module); + ASSERT_TRUE(!!jitOrError); + std::unique_ptr jit = std::move(jitOrError.get()); + + llvm::Error error = jit->invoke("rank2_memref", &*A, &*A); + ASSERT_TRUE(!error); + EXPECT_EQ((A[{1, 2}]), 42.); + EXPECT_EQ((A[{2, 1}]), 42.); +} + +// Host helper invoked from JIT-compiled code: it is registered below as the +// `_mlir_ciface_callback` symbol and scales every element by `coefficient`. +static void memref_multiply(::StridedMemRefType *memref, + int32_t coefficient) { + for (float &elt : *memref) + elt *= coefficient; +} + +TEST(NativeMemRefJit, JITCallback) { + constexpr int K = 2; + constexpr int M = 2; + int64_t shape[] = {K, M}; + int64_t shapeAlloc[] = {K + 1, M + 1}; + OwningMemRef A(shape, shapeAlloc); + int count = 1; + for (float &elt : *A) + elt = count++; + + std::string moduleStr = R"mlir( + func private @callback(%arg0: memref, %coefficient: i32) attributes { llvm.emit_c_interface } + func @caller_for_callback(%arg0: memref, %coefficient: i32) attributes { llvm.emit_c_interface } { + %unranked = memref_cast %arg0: memref to
memref<*xf32> + call @callback(%arg0, %coefficient) : (memref, i32) -> () + return + } + )mlir"; + MLIRContext context; + registerAllDialects(context.getDialectRegistry()); + auto module = parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); + auto jitOrError = ExecutionEngine::create(*module); + ASSERT_TRUE(!!jitOrError); + auto jit = std::move(jitOrError.get()); + // Define any extra symbols so they're available at runtime. + jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { + llvm::orc::SymbolMap symbolMap; + symbolMap[interner("_mlir_ciface_callback")] = + llvm::JITEvaluatedSymbol::fromPointer(memref_multiply); + return symbolMap; + }); + + int32_t coefficient = 3; + llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient); + ASSERT_TRUE(!error); + count = 1; + for (float &elt : *A) + ASSERT_EQ(elt, coefficient * count++); +} + #endif // _WIN32