diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -31,6 +31,9 @@
 #define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
 #endif // _WIN32
 
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+
 #include <array>
 #include <cstdint>
 #include <initializer_list>
@@ -80,6 +83,39 @@
   T vector[Dim];
   char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])];
 };
+
+/// Update the `indices` array, which represents the subscript for a memref, to
+/// point to the next element. `sizes` is the shape of the memref and `strides`
+/// holds the physical stride for each dimension. `offset` is the physical
+/// offset of the current element in the buffer. Returns the physical offset of
+/// the next element in the buffer, or -1 if the end of the memref has been
+/// reached.
+inline int64_t offsetForNextElement(llvm::MutableArrayRef<int64_t> indices,
+                                    llvm::ArrayRef<int64_t> sizes,
+                                    llvm::ArrayRef<int64_t> strides,
+                                    int64_t offset) {
+  int dim = indices.size() - 1;
+  // Start from the end of the indices and find the outermost dimension where
+  // the index can be incremented. For each dimension where the index is at the
+  // end, roll the offset back to the beginning of that dimension (the next
+  // outer dimension will increment the offset by its stride).
+  while (dim >= 0 && indices[dim] == (sizes[dim] - 1)) {
+    offset -= indices[dim] * strides[dim];
+    indices[dim] = 0;
+    --dim;
+  }
+  if (dim < 0) {
+    // All indices are at the end in every dimension: end of the memref.
+    offset = -1;
+  } else {
+    // Increment the index for the current dimension, and move the offset
+    // forward by the current stride.
+    ++indices[dim];
+    offset += strides[dim];
+  }
+  return offset;
+}
+
 } // end namespace detail
 } // end namespace mlir
 
@@ -212,18 +248,8 @@
                         int64_t offset = 0)
       : offset(offset), descriptor(descriptor) {}
   StridedMemrefIterator &operator++() {
-    int dim = Rank - 1;
-    while (dim >= 0 && indices[dim] == (descriptor.sizes[dim] - 1)) {
-      offset -= indices[dim] * descriptor.strides[dim];
-      indices[dim] = 0;
-      --dim;
-    }
-    if (dim < 0) {
-      offset = -1;
-      return *this;
-    }
-    ++indices[dim];
-    offset += descriptor.strides[dim];
+    offset = mlir::detail::offsetForNextElement(indices, descriptor.sizes,
+                                                descriptor.strides, offset);
     return *this;
   }
 
@@ -290,11 +316,102 @@
 //===----------------------------------------------------------------------===//
 // Codegen-compatible structure for UnrankedMemRef type.
 //===----------------------------------------------------------------------===//
+
+template <typename T>
+class UnrankedMemrefIterator;
+
 // Unranked MemRef
 template <typename T>
 struct UnrankedMemRefType {
   int64_t rank;
+  /// This is a pointer to an instantiation of StridedMemRefType<T, Rank>.
   void *descriptor;
+
+  // Return a pointer to the `basePtr` member of the strided descriptor.
+  T *basePtr() { return *reinterpret_cast<T **>(descriptor); }
+
+  // Return a pointer to the `data` member of the strided descriptor.
+  T *data() {
+    return *reinterpret_cast<T **>(reinterpret_cast<char *>(descriptor) +
+                                   sizeof(T *));
+  }
+
+  // Return the `offset` member of the strided descriptor.
+  int64_t offset() {
+    return *reinterpret_cast<int64_t *>(reinterpret_cast<char *>(descriptor) +
+                                        sizeof(T *) + sizeof(T *));
+  }
+
+  // Return the `sizes` member of the strided descriptor.
+  llvm::ArrayRef<int64_t> sizes() {
+    return llvm::makeArrayRef(
+        reinterpret_cast<int64_t *>(reinterpret_cast<char *>(descriptor) +
+                                    sizeof(T *) + sizeof(T *) +
+                                    sizeof(int64_t)),
+        rank);
+  }
+
+  // Return the `strides` member of the strided descriptor.
+  llvm::ArrayRef<int64_t> strides() {
+    return llvm::makeArrayRef(
+        reinterpret_cast<int64_t *>(reinterpret_cast<char *>(descriptor) +
+                                    sizeof(T *) + sizeof(T *) +
+                                    sizeof(int64_t) + rank * sizeof(int64_t)),
+        rank);
+  }
+
+  template <typename Range,
+            typename sfinae = decltype(std::declval<Range>().begin())>
+  T &operator[](Range indices) {
+    assert(static_cast<int64_t>(indices.size()) == rank);
+    int64_t offset = 0;
+    // This is a very ugly hard-coded way to access the data and strides
+    // inside a StridedMemRefType!!
+    llvm::ArrayRef<int64_t> strides = this->strides();
+
+    for (int dim = rank - 1; dim >= 0; --dim)
+      offset += *(indices.begin() + dim) * strides[dim];
+    return data()[offset];
+  }
+
+  UnrankedMemrefIterator<T> begin() { return {*this}; }
+  UnrankedMemrefIterator<T> end() { return {*this, -1}; }
+};
+
+/// Iterate over all elements in an unranked strided memref.
+template <typename T>
+class UnrankedMemrefIterator {
+public:
+  UnrankedMemrefIterator(UnrankedMemRefType<T> &descriptor, int64_t offset = 0)
+      : offset(offset), descriptor(descriptor) {
+    indices.resize(descriptor.rank);
+  }
+  UnrankedMemrefIterator &operator++() {
+    offset = mlir::detail::offsetForNextElement(indices, descriptor.sizes(),
+                                                descriptor.strides(), offset);
+    return *this;
+  }
+
+  T &operator*() { return descriptor.data()[offset]; }
+  T *operator->() { return &descriptor.data()[offset]; }
+
+  bool operator==(const UnrankedMemrefIterator &other) const {
+    return &other.descriptor == &descriptor && other.offset == offset;
+  }
+
+  bool operator!=(const UnrankedMemrefIterator &other) const {
+    return !(*this == other);
+  }
+
+  llvm::ArrayRef<int64_t> getIndices() { return indices; }
+
+private:
+  /// Offset in the buffer. This can be derived from the indices and the
+  /// descriptor.
+  int64_t offset = 0;
+  /// Array of indices into the multi-dimensional memref.
+  llvm::SmallVector<int64_t> indices = {};
+  /// Descriptor for the unranked memref.
+  UnrankedMemRefType<T> &descriptor;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -37,6 +37,22 @@
 
 namespace detail {
 
+/// Given a shape with sizes greater than 0 along all dimensions, populate the
+/// provided `strides` array with the distance, in number of elements, between
+/// a slice in a dimension and the next slice in the same dimension.
+/// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
+inline void makeStrides(ArrayRef<int64_t> shape,
+                        MutableArrayRef<int64_t> strides) {
+  assert(shape.size() == strides.size() &&
+         "expect shapes and strides size to match");
+  int64_t running = 1;
+  for (int64_t idx = shape.size() - 1; idx >= 0; --idx) {
+    assert(shape[idx] && "size must be non-negative for all shape dimensions");
+    strides[idx] = running;
+    running *= shape[idx];
+  }
+}
+
 /// Given a shape with sizes greater than 0 along all dimensions, returns the
 /// distance, in number of elements, between a slice in a dimension and the
 /// next slice in the same dimension.
@@ -44,14 +60,9 @@
 template <size_t N>
 inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) {
   assert(shape.size() == N && "expect shape specification to match rank");
-  std::array<int64_t, N> res;
-  int64_t running = 1;
-  for (int64_t idx = N - 1; idx >= 0; --idx) {
-    assert(shape[idx] && "size must be non-negative for all shape dimensions");
-    res[idx] = running;
-    running *= shape[idx];
-  }
-  return res;
+  std::array<int64_t, N> strides;
+  makeStrides(shape, strides);
+  return strides;
 }
 
 /// Build a `StridedMemRefDescriptor` that matches the MLIR ABI.
@@ -93,6 +104,50 @@
   return descriptor;
 }
 
+// Mallocs an UnrankedMemRefType<T> that contains a ranked
+// StridedMemRefDescriptor* and matches the MLIR ABI. This is an
+// implementation detail that is kept in sync with MLIR codegen conventions.
+template <typename T>
+::UnrankedMemRefType<T>
+makeUnrankedDescriptor(T *data, T *alignedData, ArrayRef<int64_t> shape,
+                       AllocFunType allocFun = &::malloc) {
+  ::UnrankedMemRefType<T> res{};
+  res.rank = shape.size();
+  if (res.rank == 0) {
+    res.descriptor = allocFun(sizeof(StridedMemRefType<T, 0>));
+    *static_cast<StridedMemRefType<T, 0> *>(res.descriptor) =
+        makeStridedMemRefDescriptor<0>(data, alignedData, shape, shape);
+  } else {
+    // Allocate and build a StridedMemRefType descriptor for Rank >= 1.
+    // The allocated size is the size of the rank-1 descriptor plus two extra
+    // int64_t per additional rank: one for the size and one for the stride.
+    // The reinterpret_cast computations below compute the offset of each
+    // individual field based on the rank.
+    res.descriptor = allocFun(sizeof(StridedMemRefType<T, 1>) +
+                              (res.rank - 1) * 2 * sizeof(int64_t));
+    *reinterpret_cast<T **>(res.descriptor) = data;
+    *reinterpret_cast<T **>(reinterpret_cast<char *>(res.descriptor) +
+                            sizeof(T *)) = alignedData;
+    *reinterpret_cast<int64_t *>(reinterpret_cast<char *>(res.descriptor) +
+                                 2 * sizeof(T *)) = 0;
+    std::copy(shape.begin(), shape.end(),
+              const_cast<int64_t *>(res.sizes().data()));
+    llvm::SmallVector<int64_t> strides;
+    strides.resize(res.rank);
+    makeStrides(shape, strides);
+    std::copy(strides.begin(), strides.end(),
+              const_cast<int64_t *>(res.strides().data()));
+  }
+  return res;
+}
+
+// Frees an UnrankedMemRefType<T>*.
+template <typename T>
+void freeUnrankedDescriptor(::UnrankedMemRefType<T> *desc) {
+  free(desc->descriptor);
+  free(desc);
+}
+
 /// Align `nElements` of type T with an optional `alignment`.
 /// This replaces a portable `posix_memalign`.
 /// `alignment` must be a power of 2 and greater than the size of T. By default
@@ -209,6 +264,65 @@
   DescriptorType descriptor;
 };
 
+/// Owning Unranked MemRef type that abstracts over the runtime type for
+/// memref.
+template <typename T>
+class OwningUnrankedMemRef {
+public:
+  using DescriptorType = ::UnrankedMemRefType<T>;
+  using FreeFunType = std::function<void(DescriptorType)>;
+
+  // Allocate a new UnrankedMemRef with the given `shape` and an initializer
+  // of type ElementWiseVisitor. Can optionally take specific `alloc` and
+  // `free` functions.
+  OwningUnrankedMemRef(
+      ArrayRef<int64_t> shape, ElementWiseVisitor<T> init = {},
+      llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
+      AllocFunType alloc = &::malloc,
+      FreeFunType freeFun =
+          [](DescriptorType descriptor) {
+            auto *stridedDescriptor =
+                reinterpret_cast<StridedMemRefType<T, 0> *>(
+                    descriptor.descriptor);
+            ::free(stridedDescriptor->data);
+          })
+      : freeFunc(freeFun) {
+    int64_t nElements = 1;
+    for (int64_t s : shape)
+      nElements *= s;
+    T *data, *alignedData;
+    std::tie(data, alignedData) =
+        detail::allocAligned<T>(nElements, alloc, alignment);
+    descriptor =
+        detail::makeUnrankedDescriptor(data, alignedData, shape, alloc);
+    if (init) {
+      for (UnrankedMemrefIterator<T> it = descriptor.begin(),
+                                     end = descriptor.end();
+           it != end; ++it)
+        init(*it, it.getIndices());
+    } else {
+      memset(descriptor.data(), 0,
+             nElements * sizeof(T) +
+                 alignment.getValueOr(detail::nextPowerOf2(sizeof(T))));
+    }
+  }
+
+  /// Take ownership of an existing UnrankedMemRef descriptor with a custom
+  /// deleter.
+  OwningUnrankedMemRef(FreeFunType freeFunc, DescriptorType descriptor)
+      : freeFunc(freeFunc), descriptor(descriptor) {}
+  ~OwningUnrankedMemRef() { freeFunc(descriptor); }
+  T &operator[](std::initializer_list<int64_t> indices) {
+    return descriptor[std::move(indices)];
+  }
+
+  DescriptorType &operator*() { return descriptor; }
+  DescriptorType *operator->() { return &descriptor; }
+
+private:
+  FreeFunType freeFunc;
+  DescriptorType descriptor;
+};
+
 } // namespace mlir
 
 #endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -79,7 +79,6 @@
   EXCLUDE_FROM_LIBMLIR
   )
 
-set_property(TARGET mlir_c_runner_utils PROPERTY CXX_STANDARD 11)
 target_compile_definitions(mlir_c_runner_utils PRIVATE mlir_c_runner_utils_EXPORTS)
 
 add_mlir_library(mlir_runner_utils
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -264,4 +264,175 @@
     ASSERT_EQ(elt, coefficient * count++);
 }
 
+TEST(NativeMemRefJit, UnrankedMemref_Rank0) {
+  OwningUnrankedMemRef<float> A({});
+  A[{}] = 42.;
+  ASSERT_EQ(*A->data(), 42);
+  A[{}] = 0;
+  std::string moduleStr = R"mlir(
+  func @unranked_zero(%arg0 : memref<*xf32>) attributes { llvm.emit_c_interface } {
+    %cst42 = constant 42.0 : f32
+    %casted = memref_cast %arg0 : memref<*xf32> to memref<f32>
+    store %cst42, %casted[] : memref<f32>
+    return
+  }
+  )mlir";
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+
+  llvm::Error error = jit->invoke("unranked_zero", &*A);
+  ASSERT_TRUE(!error);
+  EXPECT_EQ((A[{}]), 42.);
+  for (float &elt : *A)
+    EXPECT_EQ(&elt, &(A[{}]));
+}
+
+TEST(NativeMemRefJit, UnrankedMemref_Rank1) {
+  int64_t shape[] = {9};
+  OwningUnrankedMemRef<float> A(shape);
+  int count = 1;
+  for (float &elt : *A) {
+    EXPECT_EQ(&elt, &(A[{count - 1}]));
+    elt = count++;
+  }
+
+  std::string moduleStr = R"mlir(
+  func @unranked_one(%arg0 : memref<*xf32>) attributes { llvm.emit_c_interface } {
+    %cst42 = constant 42.0 : f32
+    %cst5 = constant 5 : index
+    %casted = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
+    store %cst42, %casted[%cst5] : memref<?xf32>
+    return
+  }
+  )mlir";
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+
+  llvm::Error error = jit->invoke("unranked_one", &*A);
+  ASSERT_TRUE(!error);
+  count = 1;
+  for (float elt : *A) {
+    if (count == 6)
+      EXPECT_EQ(elt, 42.);
+    else
+      EXPECT_EQ(elt, count);
+    count++;
+  }
+}
+
+TEST(NativeMemRefJit, UnrankedMemref_Rank2) {
+  constexpr int K = 3;
+  constexpr int M = 7;
+  // Prepare arguments beforehand.
+  auto init = [=](float &elt, ArrayRef<int64_t> indices) {
+    assert(indices.size() == 2);
+    elt = M * indices[0] + indices[1];
+  };
+  int64_t shape[] = {K, M};
+  OwningUnrankedMemRef<float> A(shape, init);
+
+  for (int i = 0; i < K; ++i)
+    for (int j = 0; j < M; ++j)
+      EXPECT_EQ((A[{i, j}]), i * M + j);
+
+  std::string moduleStr = R"mlir(
+  func @unranked_memref(%arg0 : memref<*xf32>, %arg1 : memref<*xf32>) attributes { llvm.emit_c_interface } {
+    %x = constant 2 : index
+    %y = constant 1 : index
+    %cst42 = constant 42.0 : f32
+    %arg0_cast = memref_cast %arg0 : memref<*xf32> to memref<?x?xf32>
+    %arg1_cast = memref_cast %arg1 : memref<*xf32> to memref<?x?xf32>
+    store %cst42, %arg0_cast[%y, %x] : memref<?x?xf32>
+    store %cst42, %arg1_cast[%x, %y] : memref<?x?xf32>
+    return
+  }
+  )mlir";
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+  llvm::Error error = jit->invoke("unranked_memref", &*A, &*A);
+  ASSERT_TRUE(!error);
+  EXPECT_EQ((A[{1, 2}]), 42.);
+  EXPECT_EQ((A[{2, 1}]), 42.);
+}
+
+// A helper function operating on an unranked memref; it will be called from
+// the JIT-compiled code.
+static void unranked_memref_multiply(::UnrankedMemRefType<float> *memref,
+                                     int32_t coefficient) {
+  for (float &elt : *memref)
+    elt *= coefficient;
+}
+
+TEST(NativeMemRefJit, UnrankedMemrefCallback) {
+  constexpr int K = 3;
+  constexpr int M = 7;
+  // Prepare arguments beforehand.
+  int64_t shape[] = {K, M};
+  OwningUnrankedMemRef<float> A(shape);
+  int count = 0;
+  for (float &elt : *A)
+    elt = count++;
+
+  for (int i = 0; i < K; ++i)
+    for (int j = 0; j < M; ++j)
+      EXPECT_EQ((A[{i, j}]), i * M + j);
+
+  std::string moduleStr = R"mlir(
+  func private @callback(%arg0: memref<*xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
+  func @unrankedcaller_for_callback(%arg0: memref<*xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
+    call @callback(%arg0, %coefficient) : (memref<*xf32>, i32) -> ()
+    return
+  }
+  )mlir";
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
+  auto module = parseSourceString(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  auto jitOrError = ExecutionEngine::create(*module);
+  ASSERT_TRUE(!!jitOrError);
+  auto jit = std::move(jitOrError.get());
+  // Define any extra symbols so they're available at runtime.
+  jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
+    llvm::orc::SymbolMap symbolMap;
+    symbolMap[interner("_mlir_ciface_callback")] =
+        llvm::JITEvaluatedSymbol::fromPointer(unranked_memref_multiply);
+    return symbolMap;
+  });
+
+  int32_t coefficient = 7;
+  llvm::Error error =
+      jit->invoke("unrankedcaller_for_callback", &*A, coefficient);
+  ASSERT_TRUE(!error);
+  count = 0;
+  for (float elt : *A)
+    ASSERT_EQ(elt, coefficient * count++);
+}
+
 #endif // _WIN32
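
Note on the iteration helper introduced in CRunnerUtils.h: `offsetForNextElement` advances the multi-dimensional subscript like an odometer. Trailing dimensions that have reached their last index are rolled back to zero (subtracting their contribution from the linear offset), then the next outer index is bumped and the offset moves forward by that dimension's stride. The following standalone sketch reproduces the same logic outside of LLVM/MLIR; it uses plain `std::vector` instead of `llvm::ArrayRef`, and the `nextOffset` name, the 2x3 shape, and the row-major strides {3, 1} are illustrative, not part of the patch.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Same odometer-style increment as offsetForNextElement, expressed with
// std::vector so the sketch is standalone. Returns -1 when iteration ends.
static int64_t nextOffset(std::vector<int64_t> &indices,
                          const std::vector<int64_t> &sizes,
                          const std::vector<int64_t> &strides, int64_t offset) {
  int dim = static_cast<int>(indices.size()) - 1;
  // Roll back every trailing dimension that is already at its last index.
  while (dim >= 0 && indices[dim] == sizes[dim] - 1) {
    offset -= indices[dim] * strides[dim];
    indices[dim] = 0;
    --dim;
  }
  if (dim < 0)
    return -1; // Every dimension wrapped around: end of the memref.
  ++indices[dim];
  return offset + strides[dim];
}

int main() {
  // Walk a 2x3 row-major buffer; strides are {3, 1}.
  std::vector<int64_t> sizes = {2, 3}, strides = {3, 1}, indices = {0, 0};
  for (int64_t offset = 0; offset != -1;
       offset = nextOffset(indices, sizes, strides, offset))
    std::printf("indices=(%lld, %lld) offset=%lld\n", (long long)indices[0],
                (long long)indices[1], (long long)offset);
  return 0;
}
```

For this 2x3 example the sketch prints offsets 0 through 5 with the matching (i, j) subscripts, which is the traversal order `UnrankedMemrefIterator` produces over a contiguous row-major buffer.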
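Note on the raw pointer arithmetic in the `UnrankedMemRefType<T>` accessors and in `makeUnrankedDescriptor`: both assume the standard MLIR strided descriptor layout behind the `void *descriptor` field, namely two pointers (`basePtr`, `data`), an `int64_t` offset, then `rank` sizes followed by `rank` strides, with no padding between fields. Below is a minimal sketch of that assumption for a rank-2 `float` descriptor; the `ExampleDescriptorRank2` struct is an illustrative stand-in for `StridedMemRefType<float, 2>` from CRunnerUtils.h, and the no-padding assumption holds on the usual 64-bit targets.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Illustrative mirror of the rank-2 strided descriptor layout that the
// accessors above index into with raw byte offsets.
struct ExampleDescriptorRank2 {
  float *basePtr;     // byte offset 0
  float *data;        // byte offset sizeof(T *)
  int64_t offset;     // byte offset 2 * sizeof(T *)
  int64_t sizes[2];   // byte offset 2 * sizeof(T *) + sizeof(int64_t)
  int64_t strides[2]; // ... + rank * sizeof(int64_t)
};

int main() {
  // The byte offsets computed in UnrankedMemRefType<T>'s accessors line up
  // with the member offsets of the ranked descriptor.
  static_assert(offsetof(ExampleDescriptorRank2, data) == sizeof(float *), "");
  static_assert(offsetof(ExampleDescriptorRank2, offset) == 2 * sizeof(float *),
                "");
  static_assert(offsetof(ExampleDescriptorRank2, sizes) ==
                    2 * sizeof(float *) + sizeof(int64_t),
                "");
  static_assert(offsetof(ExampleDescriptorRank2, strides) ==
                    2 * sizeof(float *) + sizeof(int64_t) + 2 * sizeof(int64_t),
                "");
  std::printf("layout offsets match the accessor arithmetic\n");
  return 0;
}
```

If a target inserted padding between these fields, the byte-offset computations in the accessors would no longer line up; this is why the comment in MemRefUtils.h stresses that the helpers are kept in sync with MLIR codegen conventions.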