Make memref begin()/end() honor the descriptor offset.

The begin() definitions of StridedMemRefType (hunks at 149, 181, 202)
and DynamicMemRefType (hunk at 364) built their iterator as {*this},
i.e. with the iterator's default start (presumably 0), so a memref with
a non-zero `offset` was iterated from the wrong element. They now pass
`offset`, consistent with the operator[] shown in the hunk at 181, which
adds `offset`. In the hunk at 202, end() moves from {*this, 1} to
{*this, offset + 1} so the range stays exactly one position long.

NOTE(review): the ranked end() overloads still return the -1 sentinel;
confirm the iterator (not shown in this patch) yields an empty range
when some size is 0, since begin() starts at `offset`, not at -1.

diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -149,7 +149,7 @@ return data[curOffset]; } - StridedMemrefIterator begin() { return {*this}; } + StridedMemrefIterator begin() { return {*this, offset}; } StridedMemrefIterator end() { return {*this, -1}; } // This operator[] is extremely slow and only for sugaring purposes. @@ -181,7 +181,7 @@ return (*this)[*indices.begin()]; } - StridedMemrefIterator begin() { return {*this}; } + StridedMemrefIterator begin() { return {*this, offset}; } StridedMemrefIterator end() { return {*this, -1}; } T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); } @@ -202,8 +202,8 @@ return data[offset]; } - StridedMemrefIterator begin() { return {*this}; } - StridedMemrefIterator end() { return {*this, 1}; } + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, offset + 1}; } }; /// Iterate over all elements in a strided memref. @@ -364,7 +364,7 @@ return data[curOffset]; } - DynamicMemRefIterator begin() { return {*this}; } + DynamicMemRefIterator begin() { return {*this, offset}; } DynamicMemRefIterator end() { return {*this, -1}; } // This operator[] is extremely slow and only for sugaring purposes. 
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt --- a/mlir/unittests/ExecutionEngine/CMakeLists.txt +++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRExecutionEngineTests DynamicMemRef.cpp Invoke.cpp + StridedMemRef.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp --- a/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp +++ b/mlir/unittests/ExecutionEngine/DynamicMemRef.cpp @@ -96,4 +96,31 @@ llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); EXPECT_THAT(values, ElementsAreArray(data)); -} \ No newline at end of file +} + +TEST(DynamicMemRef, rankOneWithOffset) { + constexpr int offset = 4; + std::array buffer; + + for (size_t i = 0; i < buffer.size(); ++i) { + buffer[i] = i; + } + + StridedMemRefType memRef; + memRef.basePtr = buffer.data(); + memRef.data = buffer.data(); + memRef.offset = offset; + memRef.sizes[0] = 3; + memRef.strides[0] = 1; + + DynamicMemRefType dynamicMemRef(memRef); + + // Iteration must honor memRef.offset: element i is buffer[offset + i]. + llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); + ASSERT_EQ(values.size(), 3u); + + for (int64_t i = 0; i < 3; ++i) { + EXPECT_EQ(values[i], buffer[offset + i]); + EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]); + } +} diff --git a/mlir/unittests/ExecutionEngine/StridedMemRef.cpp b/mlir/unittests/ExecutionEngine/StridedMemRef.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/ExecutionEngine/StridedMemRef.cpp @@ -0,0 +1,46 @@ +//===- StridedMemRef.cpp ----------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "llvm/ADT/SmallVector.h" + +#include "gmock/gmock.h" + +#include <array> + +using namespace ::mlir; +using namespace ::testing; + +TEST(StridedMemRef, rankOneWithOffset) { + std::array data; + + for (size_t i = 0; i < data.size(); ++i) { + data[i] = i; + } + + StridedMemRefType memRefA; + memRefA.basePtr = data.data(); + memRefA.data = data.data(); + memRefA.offset = 0; + memRefA.sizes[0] = 10; + memRefA.strides[0] = 1; + + // memRefB views data[5, 15), so data must hold >= 15 elements. + StridedMemRefType memRefB = memRefA; + memRefB.offset = 5; + + llvm::SmallVector valuesA(memRefA.begin(), memRefA.end()); + llvm::SmallVector valuesB(memRefB.begin(), memRefB.end()); + ASSERT_EQ(valuesA.size(), 10u); + ASSERT_EQ(valuesB.size(), 10u); + + for (int64_t i = 0; i < 10; ++i) { + EXPECT_EQ(valuesA[i], i); + EXPECT_EQ(valuesA[i] + 5, valuesB[i]); + } +}