diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -31,11 +31,15 @@
 #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
 #endif // _WIN32
 
+#include <array>
+#include <cassert>
 #include <cstdint>
+#include <initializer_list>
 
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for Vector type.
 //===----------------------------------------------------------------------===//
+namespace mlir {
 namespace detail {
 
 constexpr bool isPowerOf2(int N) { return (!(N & (N - 1))); }
@@ -65,9 +69,8 @@
 template <typename T, int Dim>
 struct Vector1D<T, Dim, /*IsPowerOf2=*/false> {
   Vector1D() {
-    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]),
-                  "size error");
-    static_assert(detail::nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
+    static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error");
+    static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]),
                   "size error");
   }
   inline T &operator[](unsigned i) { return vector[i]; }
@@ -75,9 +78,10 @@
 
 private:
   T vector[Dim];
-  char padding[detail::nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
+  char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
 };
 } // end namespace detail
+} // end namespace mlir
 
 // N-D vectors recurse down to 1-D.
 template <typename T, int Dim, int... Dims>
@@ -95,7 +99,9 @@
 // We insert explicit padding in to account for this.
 template <typename T, int Dim>
 struct Vector<T, Dim>
-    : public detail::Vector1D<T, Dim, detail::isPowerOf2(sizeof(T[Dim]))> {};
+    : public mlir::detail::Vector1D<T, Dim,
+                                    mlir::detail::isPowerOf2(sizeof(T[Dim]))> {
+};
 
 template <typename T, int Dim> using Vector1D = Vector<T, Dim>;
@@ -115,6 +121,9 @@
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structures for StridedMemRef type.
 //===----------------------------------------------------------------------===//
+template <typename T, int Rank>
+class StridedMemrefIterator;
+
 /// StridedMemRef descriptor type with static rank.
 template <typename T, int N>
 struct StridedMemRefType {
@@ -123,6 +132,18 @@
   int64_t offset;
   int64_t sizes[N];
   int64_t strides[N];
+
+  T &operator[](std::initializer_list<int64_t> indices) {
+    assert(indices.size() == N);
+    int64_t curOffset = offset;
+    for (int dim = N - 1; dim >= 0; --dim)
+      curOffset += *(indices.begin() + dim) * strides[dim];
+    return data[curOffset];
+  }
+
+  StridedMemrefIterator<T, N> begin() { return {*this}; }
+  StridedMemrefIterator<T, N> end() { return {*this, -1}; }
+
   // This operator[] is extremely slow and only for sugaring purposes.
   StridedMemRefType<T, N - 1> operator[](int64_t idx) {
     StridedMemRefType<T, N - 1> res;
@@ -143,6 +164,12 @@
   int64_t offset;
   int64_t sizes[1];
   int64_t strides[1];
+
+  T &operator[](std::initializer_list<int64_t> indices) {
+    assert(indices.size() == 1);
+    return (*this)[*indices.begin()];
+  }
+
   T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
 };
 
@@ -152,6 +179,53 @@
   T *basePtr;
   T *data;
   int64_t offset;
+
+  T &operator[](std::initializer_list<int64_t> indices) {
+    assert(indices.size() == 0);
+    return data[offset];
+  }
+};
+
+/// Iterate over all elements in a strided memref.
+template <typename T, int Rank>
+class StridedMemrefIterator {
+public:
+  StridedMemrefIterator(StridedMemRefType<T, Rank> &descriptor,
+                        int64_t offset = 0)
+      : offset(offset), descriptor(descriptor) {}
+
+  StridedMemrefIterator &operator++() {
+    // Advance the fastest-varying dimension first and carry over into slower
+    // dimensions, odometer-style; when every dimension wraps, mark the end
+    // of iteration with offset == -1.
+    int dim = Rank - 1;
+    while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
+      offset -= indices[dim] * descriptor.strides[dim];
+      indices[dim] = 0;
+      --dim;
+    }
+    if (dim < 0) {
+      offset = -1;
+      return *this;
+    }
+    ++indices[dim];
+    offset += descriptor.strides[dim];
+    return *this;
+  }
+
+  T &operator*() { return descriptor.data[offset]; }
+  T *operator->() { return &descriptor.data[offset]; }
+
+  std::array<int64_t, Rank> &getIndices() { return indices; }
+
+  bool operator==(const StridedMemrefIterator &other) const {
+    return other.offset == offset && &other.descriptor == &descriptor;
+  }
+
+  bool operator!=(const StridedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  int64_t offset = 0;
+  std::array<int64_t, Rank> indices = {};
+  StridedMemRefType<T, Rank> &descriptor;
 };
 
 //===----------------------------------------------------------------------===//
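To make the new CRunnerUtils.h API concrete, here is a small usage sketch that is not part of the patch: it wraps a plain 2x3 row-major buffer in a StridedMemRefType by hand, then exercises the new initializer_list indexing and the iterator. All sizes, strides, and values are made up for illustration, and the snippet assumes the patched header is on the include path.

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include <cstdio>

int main() {
  // Wrap an existing 2x3 row-major buffer in a descriptor by hand.
  float storage[6] = {0, 1, 2, 3, 4, 5};
  StridedMemRefType<float, 2> memref;
  memref.basePtr = storage;
  memref.data = storage;
  memref.offset = 0;
  memref.sizes[0] = 2;
  memref.sizes[1] = 3;
  memref.strides[0] = 3; // row-major: one step in dim 0 skips a full row
  memref.strides[1] = 1;

  // New coordinate-based element access.
  memref[{1, 2}] = 42.0f;

  // New iteration support; getIndices() exposes the current coordinates.
  for (StridedMemrefIterator<float, 2> it = memref.begin(), e = memref.end();
       it != e; ++it)
    std::printf("(%lld, %lld) = %f\n", (long long)it.getIndices()[0],
                (long long)it.getIndices()[1], *it);
  return 0;
}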
diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -0,0 +1,213 @@
+//===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utils for MLIR ABI interfacing with frameworks.
+//
+// The templated free functions below make it possible to allocate dense
+// contiguous buffers with shapes that interoperate properly with the MLIR
+// codegen ABI.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <climits>
+#include <functional>
+#include <memory>
+
+#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
+
+namespace mlir {
+using AllocFunType = llvm::function_ref<void *(size_t)>;
+
+namespace detail {
+
+// Given a shape with sizes greater than 0 along all dimensions,
+// returns the distance, in number of elements, between a slice in a dimension
+// and the next slice in the same dimension.
+// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
+template <size_t N>
+inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) {
+  assert(shape.size() == N);
+  std::array<int64_t, N> res;
+  int64_t running = 1;
+  for (int64_t idx = N - 1; idx >= 0; --idx) {
+    assert(shape[idx] && "size must be greater than 0 for all shape dimensions");
+    res[idx] = running;
+    running *= shape[idx];
+  }
+  return res;
+}
+
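As a sanity check on the stride rule spelled out in the comment above (each stride is the suffix product of the trailing sizes), here is a small standalone snippet; it is not part of the patch, assumes this header is applied and on the include path, and the function name `checkMakeStrides` is illustrative only.

#include "mlir/ExecutionEngine/MemRefUtils.h"
#include <array>
#include <cassert>
#include <cstdint>

static void checkMakeStrides() {
  int64_t shape[] = {3, 4, 5};
  // The rank is an explicit template argument: it cannot be deduced from the
  // runtime-sized ArrayRef.
  std::array<int64_t, 3> strides = mlir::detail::makeStrides<3>(shape);
  assert(strides[2] == 1);     // innermost dimension is contiguous
  assert(strides[1] == 5);     // one step in dim 1 skips a full dim-2 row
  assert(strides[0] == 4 * 5); // one step in dim 0 skips a 4x5 slice
}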
+// Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
+// This is an implementation detail that is kept in sync with MLIR codegen
+// conventions. Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+template <typename T, int N>
+typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type
+makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
+                            ArrayRef<int64_t> shape,
+                            ArrayRef<int64_t> shapeAlloc) {
+  assert(shape.size() == N);
+  assert(shapeAlloc.size() == N);
+  StridedMemRefType<T, N> descriptor;
+  descriptor.basePtr = static_cast<T *>(ptr);
+  descriptor.data = static_cast<T *>(alignedPtr);
+  descriptor.offset = 0;
+  std::copy(shape.begin(), shape.end(), descriptor.sizes);
+  auto strides = makeStrides<N>(shapeAlloc);
+  std::copy(strides.begin(), strides.end(), descriptor.strides);
+  return descriptor;
+}
+
+// Mallocs a StridedMemRefDescriptor<T, 0>* that matches the MLIR ABI.
+// This is an implementation detail that is kept in sync with MLIR codegen
+// conventions. Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+template <typename T, int N>
+typename std::enable_if<(N == 0), StridedMemRefType<T, 0>>::type
+makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
+                            ArrayRef<int64_t> shape = {},
+                            ArrayRef<int64_t> shapeAlloc = {}) {
+  assert(shape.size() == N);
+  assert(shapeAlloc.size() == N);
+  StridedMemRefType<T, 0> descriptor;
+  descriptor.basePtr = static_cast<T *>(ptr);
+  descriptor.data = static_cast<T *>(alignedPtr);
+  descriptor.offset = 0;
+  return descriptor;
+}
+
+// Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
+// This is an implementation detail that is kept in sync with MLIR codegen
+// conventions.
+template <typename T, int N>
+typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type
+makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
+                            ArrayRef<int64_t> shape) {
+  return makeStridedMemRefDescriptor<T, N>(ptr, alignedPtr, shape, shape);
+}
+
+// No such thing as a portable posix_memalign, roll our own.
+// `alignment` allows specifying an arbitrary alignment. It must be a power of
+// 2 and at least sizeof(T). By default the alignment is the next power of 2
+// of sizeof(T).
+template <typename T>
+std::pair<void *, void *>
+allocAligned(size_t nElements, AllocFunType allocFun = &::malloc,
+             llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>()) {
+  assert(sizeof(T) < (1ul << 32) && "Elemental type overflows");
+  auto size = nElements * sizeof(T);
+  auto desiredAlignment = alignment.getValueOr(nextPowerOf2(sizeof(T)));
+  assert((desiredAlignment & (desiredAlignment - 1)) == 0);
+  assert(desiredAlignment >= sizeof(T));
+  // Over-allocate by `desiredAlignment` bytes so the pointer can be bumped to
+  // the next aligned address without running past the allocation.
+  void *data = allocFun(size + desiredAlignment);
+  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
+  uintptr_t rem = addr % desiredAlignment;
+  void *alignedData =
+      (rem == 0) ? data
+                 : reinterpret_cast<void *>(addr + (desiredAlignment - rem));
+  assert(reinterpret_cast<uintptr_t>(alignedData) % desiredAlignment == 0);
+  return std::pair<void *, void *>(data, alignedData);
+}
+
+} // namespace detail
+
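To make the pointer bumping in `allocAligned` concrete, here is the same round-up arithmetic as a standalone sketch with made-up numbers (a 100-byte request at 64-byte alignment); it is not part of the patch and the helper names are purely illustrative.

#include <cassert>
#include <cstdint>
#include <cstdlib>

// Round `data` up to the next multiple of `alignment` (a power of 2), exactly
// as allocAligned does after over-allocating by `alignment` bytes.
static void *alignUpSketch(void *data, uint64_t alignment) {
  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
  uintptr_t rem = addr % alignment;
  return rem == 0 ? data : reinterpret_cast<void *>(addr + (alignment - rem));
}

static void demoAlignUp() {
  // Over-allocate by the alignment so the bumped pointer still leaves at
  // least 100 usable bytes inside the allocation.
  void *raw = std::malloc(100 + 64);
  void *aligned = alignUpSketch(raw, 64);
  assert(reinterpret_cast<uintptr_t>(aligned) % 64 == 0);
  std::free(raw); // free the pointer that malloc returned
}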
+//===----------------------------------------------------------------------===//
+// Public API
+//===----------------------------------------------------------------------===//
+
+// Convenient callback to "visit" a memref element by element.
+// This takes a reference to an individual element as well as the coordinates.
+// It can be used in conjunction with a StridedMemrefIterator.
+template <typename T>
+using ElementWiseVisitor = llvm::function_ref<void(T &ptr, ArrayRef<int64_t>)>;
+
+/// Owning MemRef type that abstracts over the runtime type for ranked memref.
+template <typename T, unsigned Rank>
+class OwningMemRef {
+public:
+  using DescriptorType = StridedMemRefType<T, Rank>;
+  using FreeFunType = std::function<void(DescriptorType)>;
+
+  // Allocate a new dense StridedMemRef with a given `shape`. An optional
+  // `shapeAlloc` array can be supplied to "pad" every dimension individually.
+  // If an ElementWiseVisitor is provided, it will be used to initialize the
+  // data, else the memory will be zero-initialized. The alloc and free methods
+  // used to manage the data allocation can be optionally provided, and default
+  // to malloc/free.
+  OwningMemRef(
+      ArrayRef<int64_t> shape, ArrayRef<int64_t> shapeAlloc = {},
+      ElementWiseVisitor<T> init = {},
+      llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
+      AllocFunType allocFun = &::malloc,
+      std::function<void(StridedMemRefType<T, Rank>)> freeFun =
+          [](StridedMemRefType<T, Rank> descriptor) {
+            ::free(descriptor.data);
+          })
+      : freeFunc(freeFun) {
+    if (shapeAlloc.empty())
+      shapeAlloc = shape;
+    assert(shape.size() == Rank);
+    assert(shapeAlloc.size() == Rank);
+    for (unsigned i = 0; i < Rank; ++i)
+      assert(shape[i] <= shapeAlloc[i] &&
+             "shapeAlloc must be greater than or equal to shape");
+    int64_t nElements = 1;
+    for (int64_t s : shapeAlloc)
+      nElements *= s;
+    auto allocated = detail::allocAligned<T>(nElements, allocFun, alignment);
+    T *data = static_cast<T *>(allocated.first);
+    auto *alignedData = static_cast<T *>(allocated.second);
+    descriptor = detail::makeStridedMemRefDescriptor<T, Rank>(
+        data, alignedData, shape, shapeAlloc);
+    if (init) {
+      for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
+                                          end = descriptor.end();
+           it != end; ++it)
+        init(*it, it.getIndices());
+    } else {
+      memset(descriptor.data, 0,
+             nElements * sizeof(T) +
+                 alignment.getValueOr(detail::nextPowerOf2(sizeof(T))));
+    }
+  }
+  /// Take ownership of an existing descriptor with a custom deleter.
+  OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc)
+      : freeFunc(freeFunc), descriptor(descriptor) {}
+  ~OwningMemRef() {
+    if (freeFunc)
+      freeFunc(descriptor);
+  }
+
+  DescriptorType &operator*() { return descriptor; }
+  DescriptorType *operator->() { return &descriptor; }
+  T &operator[](std::initializer_list<int64_t> indices) {
+    return descriptor[std::move(indices)];
+  }
+
+private:
+  FreeFunType freeFunc;
+  // The descriptor is an instance of StridedMemRefType<T, Rank>.
+  DescriptorType descriptor;
+};
+
+} // namespace mlir
+
+#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
\ No newline at end of file
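Between the header above and the unit test below, here is a compact usage sketch of OwningMemRef that is not part of the patch; the 4x8 shape and the `fillRowMajor` lambda are made up for illustration, and the snippet assumes the new header is on the include path.

#include "mlir/ExecutionEngine/MemRefUtils.h"

static void owningMemRefSketch() {
  // Initialize each element from its coordinates via an ElementWiseVisitor.
  auto fillRowMajor = [](float &elt, llvm::ArrayRef<int64_t> coords) {
    elt = coords[0] * 8 + coords[1]; // linearized row-major index
  };
  int64_t shape[] = {4, 8};
  // Empty shapeAlloc means "allocate exactly shape", so strides are dense.
  mlir::OwningMemRef<float, 2> buffer(shape, /*shapeAlloc=*/{}, fillRowMajor);

  float x = buffer[{2, 3}]; // reads 2 * 8 + 3 == 19
  (void)x;

  // operator-> exposes the raw StridedMemRefType descriptor, which is what
  // gets passed to JIT-compiled MLIR entry points (see the unit test below).
  int64_t innerStride = buffer->strides[1]; // 1 for a dense buffer
  (void)innerStride;
}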
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/MemRefUtils.h"
 #include "mlir/ExecutionEngine/RunnerUtils.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/InitAllDialects.h"
@@ -89,4 +90,52 @@
   ASSERT_EQ(result, 42.f);
 }
 
+TEST(NativeMemRefJit, BasicMemref) {
+  constexpr int K = 3;
+  constexpr int M = 7;
+  // Prepare arguments beforehand.
+  auto init = [=](float &elt, ArrayRef<int64_t> indices) {
+    assert(indices.size() == 2);
+    elt = M * indices[0] + indices[1];
+  };
+  int64_t shape[] = {K, M};
+  int64_t shapeAlloc[] = {K + 1, M + 1};
+  OwningMemRef<float, 2> A(shape, shapeAlloc, init);
+  ASSERT_EQ(A->sizes[0], K);
+  ASSERT_EQ(A->sizes[1], M);
+  ASSERT_EQ(A->strides[0], M + 1);
+  ASSERT_EQ(A->strides[1], 1);
+  for (int i = 0; i < K; ++i) {
+    for (int j = 0; j < M; ++j) {
+      auto value = A[{i, j}];
+      EXPECT_EQ(value, i * M + j);
+      EXPECT_EQ((A[{i, j}]), i * M + j);
+    }
+  }
+
+  std::string moduleStr = R"mlir(
+  func @foo(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
+    %x = constant 2 : index
+    %y = constant 1 : index
+    %cst42 = constant 42.0 : f32
+    store %cst42, %arg0[%y, %x] : memref<?x?xf32>
+    store %cst42, %arg1[%x, %y] : memref<?x?xf32>
+    return
+  }
+  )mlir";
+  MLIRContext context;
+  registerAllDialects(context.getDialectRegistry());
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+
+  llvm::Error error = jit->invoke("foo", &*A, &*A);
+  ASSERT_TRUE(!error);
+  EXPECT_EQ((A[{1, 2}]), 42.);
+  EXPECT_EQ((A[{2, 1}]), 42.);
+}
+
 #endif // _WIN32
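If one wanted to tighten the test above, a possible extra check (illustrative only, not part of the patch) would verify that every element other than the two stored ones still holds its initial value; it would go just before the closing brace of the test body, reusing the K, M, and A already in scope.

  // Hypothetical extra assertions for the same test body: all elements except
  // [1][2] and [2][1] should keep the value written by `init`.
  for (int i = 0; i < K; ++i)
    for (int j = 0; j < M; ++j) {
      if ((i == 1 && j == 2) || (i == 2 && j == 1))
        continue;
      EXPECT_EQ((A[{i, j}]), i * M + j);
    }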