Index: streamexecutor/include/streamexecutor/DeviceMemory.h =================================================================== --- streamexecutor/include/streamexecutor/DeviceMemory.h +++ streamexecutor/include/streamexecutor/DeviceMemory.h @@ -18,9 +18,9 @@ /// and a byte count to tell how much memory is pointed to by that void*. /// /// GlobalDeviceMemory is a subclass of GlobalDeviceMemoryBase which keeps -/// track of the type of element to be stored in the device array. It is similar -/// to a pair of a T* pointer and an element count to tell how many elements of -/// type T fit in the memory pointed to by that T*. +/// track of the type of element to be stored in the device memory. It is +/// similar to a pair of a T* pointer and an element count to tell how many +/// elements of type T fit in the memory pointed to by that T*. /// /// SharedDeviceMemoryBase is just the size in bytes of a shared memory buffer. /// @@ -38,6 +38,7 @@ #ifndef STREAMEXECUTOR_DEVICEMEMORY_H #define STREAMEXECUTOR_DEVICEMEMORY_H +#include #include namespace streamexecutor { @@ -91,6 +92,71 @@ size_t ByteCount; // Size in bytes of this allocation. }; +template class GlobalDeviceMemory; + +/// Reference to a slice of device memory. +/// +/// Contains a base memory handle, an element count offset into that base +/// memory, and an element count for the size of the slice. +template class GlobalDeviceMemorySlice { +public: + /// Intentionally implicit so GlobalDeviceMemory can be passed to functions + /// expecting GlobalDeviceMemorySlice arguments. 
+ GlobalDeviceMemorySlice(const GlobalDeviceMemory &Memory) + : BaseMemory(Memory), ElementOffset(0), + ElementCount(Memory.getElementCount()) {} + + GlobalDeviceMemorySlice(const GlobalDeviceMemory &BaseMemory, + size_t ElementOffset, size_t ElementCount) + : BaseMemory(BaseMemory), ElementOffset(ElementOffset), + ElementCount(ElementCount) { + assert(ElementOffset + ElementCount <= BaseMemory.getElementCount() && + "slicing past the end of a GlobalDeviceMemory buffer"); + } + + /// Gets the GlobalDeviceMemory backing this slice. + GlobalDeviceMemory getBaseMemory() const { return BaseMemory; } + + /// Gets the offset of this slice from the base memory. + /// + /// The offset is measured in elements, not bytes. + size_t getElementOffset() const { return ElementOffset; } + + /// Gets the number of elements in this slice. + size_t getElementCount() const { return ElementCount; } + + /// Creates a slice of the memory with the first DropCount elements removed. + GlobalDeviceMemorySlice drop_front(size_t DropCount) const { + assert(DropCount <= ElementCount && + "dropping more than the size of a slice"); + return GlobalDeviceMemorySlice(BaseMemory, ElementOffset + DropCount, + ElementCount - DropCount); + } + + /// Creates a slice of the memory with the last DropCount elements removed. + GlobalDeviceMemorySlice drop_back(size_t DropCount) const { + assert(DropCount <= ElementCount && + "dropping more than the size of a slice"); + return GlobalDeviceMemorySlice(BaseMemory, ElementOffset, + ElementCount - DropCount); + } + + /// Creates a slice of the memory that chops off the first DropCount elements + /// and keeps the next TakeCount elements. 
+ GlobalDeviceMemorySlice slice(size_t DropCount, + size_t TakeCount) const { + assert(DropCount + TakeCount <= ElementCount && + "sub-slice operation overruns slice bounds"); + return GlobalDeviceMemorySlice(BaseMemory, ElementOffset + DropCount, + TakeCount); + } + +private: + GlobalDeviceMemory BaseMemory; + size_t ElementOffset; + size_t ElementCount; +}; + /// Typed wrapper around the "void *"-like GlobalDeviceMemoryBase class. /// /// For example, GlobalDeviceMemory is a simple wrapper around @@ -125,6 +191,11 @@ /// allocation. size_t getElementCount() const { return getByteCount() / sizeof(ElemT); } + /// Converts this memory object into a slice. + GlobalDeviceMemorySlice asSlice() { + return GlobalDeviceMemorySlice(*this); + } + private: /// Constructs a GlobalDeviceMemory instance from an opaque handle and an /// element count. Index: streamexecutor/include/streamexecutor/Executor.h =================================================================== --- streamexecutor/include/streamexecutor/Executor.h +++ streamexecutor/include/streamexecutor/Executor.h @@ -16,12 +16,12 @@ #define STREAMEXECUTOR_EXECUTOR_H #include "streamexecutor/KernelSpec.h" +#include "streamexecutor/PlatformInterfaces.h" #include "streamexecutor/Utils/Error.h" namespace streamexecutor { class KernelInterface; -class PlatformExecutor; class Stream; class Executor { @@ -38,6 +38,144 @@ Expected> createStream(); + /// Allocates an array of ElementCount entries of type T in device memory. + template + Expected> allocateDeviceMemory(size_t ElementCount) { + return PExecutor->allocateDeviceMemory(ElementCount * sizeof(T)); + } + + /// Frees memory previously allocated with allocateDeviceMemory. + template Error freeDeviceMemory(GlobalDeviceMemory *Memory) { + return PExecutor->freeDeviceMemory(Memory); + } + + /// Allocates an array of ElementCount entries of type T in host memory. 
+ /// + /// Host memory allocated by this function can be used for asynchronous memory + /// copies on streams. See Stream::thenMemcpyD2H and Stream::thenMemcpyH2D. + template Expected allocateHostMemory(size_t ElementCount) { + return PExecutor->allocateHostMemory(ElementCount * sizeof(T)); + } + + /// Frees memory previously allocated with allocateHostMemory. + template Error freeHostMemory(T *Memory) { + return PExecutor->freeHostMemory(Memory); + } + + /// Registers a previously allocated host array of type T for asynchronous + /// memory operations. + /// + /// Host memory registered by this function can be used for asynchronous + /// memory copies on streams. See Stream::thenMemcpyD2H and + /// Stream::thenMemcpyH2D. + template + Error registerHostMemory(T *Memory, size_t ElementCount) { + return PExecutor->registerHostMemory(Memory, ElementCount * sizeof(T)); + } + + /// Unregisters host memory previously registered by registerHostMemory. + template Error unregisterHostMemory(T *Memory) { + return PExecutor->unregisterHostMemory(Memory); + } + + /// Host-synchronously copies a slice of an array of elements of type T from + /// device to host memory. + /// + /// Returns an error if ElementCount is too large for the source slice or the + /// destination. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory, the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. 
+ template + Error synchronousCopyD2H(GlobalDeviceMemorySlice Src, + llvm::MutableArrayRef Dst, size_t ElementCount) { + if (ElementCount > Src.getElementCount()) + return make_error("copying too many elements, " + + llvm::Twine(ElementCount) + + ", from a device array sloce of element count " + + llvm::Twine(Src.getElementCount())); + if (Dst.size() < ElementCount) + return make_error( + "copying too many elements, " + llvm::Twine(ElementCount) + + ", to a host array of element count " + llvm::Twine(Dst.size())); + return PExecutor->synchronousCopyD2H( + Src.getBaseMemory(), Src.getElementOffset() * sizeof(T), Dst.data(), 0, + ElementCount * sizeof(T)); + } + + /// Host-synchronously copies a slice of an array of elements of type T from + /// device to host memory. + /// + /// Returns an error if the Src and Dst sizes do not match. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory, the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. + template + Error synchronousCopyD2H(GlobalDeviceMemorySlice Src, + llvm::MutableArrayRef Dst) { + return synchronousCopyD2H(Src, Dst, Src.getElementCount()); + } + + /// Host-synchronously copies a slice of an array of elements of type T from + /// device to host memory. + /// + /// Returns an error if ElementCount is too large for the source slice. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory, the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. + template + Error synchronousCopyD2H(GlobalDeviceMemorySlice Src, T *Dst, + size_t ElementCount) { + return synchronousCopyD2H(Src, llvm::MutableArrayRef(Dst, ElementCount), + ElementCount); + } + + /// Host-synchronously copies an array of elements of type T from device to + /// host memory. 
+ /// + /// Returns an error if ElementCount is too large for the source slice or the + /// destination. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory, the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. + template + Error synchronousCopyD2H(GlobalDeviceMemory Src, + llvm::MutableArrayRef Dst, size_t ElementCount) { + return synchronousCopyD2H(Src.asSlice(), Dst, ElementCount); + } + + /// Host-synchronously copies an array of elements of type T from device to + /// host memory. + /// + /// Returns an error if the Src and Dst sizes do not match. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory, the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. + template + Error synchronousCopyD2H(GlobalDeviceMemory Src, + llvm::MutableArrayRef Dst) { + return synchronousCopyD2H(Src.asSlice(), Dst); + } + + /// Host-synchronously copies an array of elements of type T from device to + /// host memory. + /// + /// Returns an error if ElementCount is too large for the source slice. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory, the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. + template + Error synchronousCopyD2H(GlobalDeviceMemory Src, T *Dst, + size_t ElementCount) { + return synchronousCopyD2H(Src.asSlice(), Dst, ElementCount); + } + private: PlatformExecutor *PExecutor; }; Index: streamexecutor/include/streamexecutor/PlatformInterfaces.h =================================================================== --- streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -76,22 +76,32 @@ } /// Copies data from the device to the host. 
- virtual Error memcpyD2H(PlatformStreamHandle *S, + /// + /// HostDst should have been allocated by allocateHostMemory or registered + /// with registerHostMemory. + virtual Error memcpyD2H(PlatformStreamHandle *S, void *HostDst, const GlobalDeviceMemoryBase &DeviceSrc, - void *HostDst, size_t ByteCount) { + size_t ByteCount, size_t SrcByteOffset = 0) { return make_error("memcpyD2H not implemented for platform " + getName()); } /// Copies data from the host to the device. - virtual Error memcpyH2D(PlatformStreamHandle *S, const void *HostSrc, - GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) { + /// + /// HostSrc should have been allocated by allocateHostMemory or registered + /// with registerHostMemory. + virtual Error memcpyH2D(PlatformStreamHandle *S, + GlobalDeviceMemoryBase *DeviceDst, + const void *HostSrc, size_t ByteCount, + size_t DstByteOffset = 0) { return make_error("memcpyH2D not implemented for platform " + getName()); } /// Copies data from one device location to another. virtual Error memcpyD2D(PlatformStreamHandle *S, + GlobalDeviceMemoryBase *DeviceDst, const GlobalDeviceMemoryBase &DeviceSrc, - GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) { + size_t ByteCount, size_t SrcByteOffset = 0, + size_t DstByteOffset = 0) { return make_error("memcpyD2D not implemented for platform " + getName()); } @@ -101,6 +111,81 @@ return make_error("blockHostUntilDone not implemented for platform " + getName()); } + + /// Allocates untyped device memory of a given size in bytes. + virtual Expected + allocateDeviceMemory(size_t ByteCount) { + return make_error("allocateDeviceMemory not implemented for platform " + + getName()); + } + + /// Frees device memory previously allocated by allocateDeviceMemory. + virtual Error freeDeviceMemory(GlobalDeviceMemoryBase *Memory) { + return make_error("freeDeviceMemory not implemented for platform " + + getName()); + } + + /// Allocates untyped host memory of a given size in bytes. 
+ /// + /// Host memory allocated via this method is suitable for use with memcpyH2D + /// and memcpyD2H. + virtual Expected allocateHostMemory(size_t ByteCount) { + return make_error("allocateHostMemory not implemented for platform " + + getName()); + } + + /// Frees host memory allocated by allocateHostMemory. + virtual Error freeHostMemory(void *Memory) { + return make_error("freeHostMemory not implemented for platform " + + getName()); + } + + /// Registers previously allocated host memory so it can be used with + /// memcpyH2D and memcpyD2H. + virtual Error registerHostMemory(void *Memory, size_t ByteCount) { + return make_error("registerHostMemory not implemented for platform " + + getName()); + } + + /// Unregisters host memory previously registered with registerHostMemory. + virtual Error unregisterHostMemory(void *Memory) { + return make_error("unregisterHostMemory not implemented for platform " + + getName()); + } + + /// Copies the given number of bytes from device memory to host memory. + /// + /// Blocks the calling host thread until the copy is completed. Can operate on + /// any host memory, not just registered host memory or host memory allocated + /// by allocateHostMemory. + virtual Error synchronousCopyD2H(const GlobalDeviceMemoryBase &DeviceSrc, + size_t SrcByteOffset, void *HostDst, + size_t DstByteOffset, size_t ByteCount) { + return make_error("synchronousCopyD2H not implemented for platform " + + getName()); + } + + /// Copies the given number of bytes from host memory to device memory. + /// + /// Blocks the calling host thread until the copy is completed. Can operate on + /// any host memory, not just registered host memory or host memory allocated + /// by allocateHostMemory. 
+ virtual Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset, + GlobalDeviceMemoryBase DeviceDst, + size_t DstByteOffset, size_t ByteCount) { + return make_error("synchronousCopyH2D not implemented for platform " + + getName()); + } + + /// Copies the given number of bytes from one location to another in device + /// memory. + virtual Error synchronousCopyD2D(GlobalDeviceMemoryBase DeviceDst, + size_t DstByteOffset, + const GlobalDeviceMemoryBase &DeviceSrc, + size_t SrcByteOffset, size_t ByteCount) { + return make_error("synchronousCopyD2D not implemented for platform " + + getName()); + } }; } // namespace streamexecutor Index: streamexecutor/include/streamexecutor/Stream.h =================================================================== --- streamexecutor/include/streamexecutor/Stream.h +++ streamexecutor/include/streamexecutor/Stream.h @@ -106,27 +106,38 @@ /// Executor::allocateHostMemory or otherwise allocated and then /// registered with Executor::registerHostMemory. 
template - Stream &thenMemcpyD2H(const GlobalDeviceMemory &DeviceSrc, - llvm::MutableArrayRef HostDst, size_t ElementCount) { - if (ElementCount > DeviceSrc.getElementCount()) + Stream &thenMemcpyD2H(llvm::MutableArrayRef HostDst, + const GlobalDeviceMemory &DeviceSrc, + size_t ElementCount, size_t SrcElementOffset = 0) { + if (ElementCount + SrcElementOffset > DeviceSrc.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", from device memory array of size " + + ", at element offset " + llvm::Twine(SrcElementOffset) + + " from device memory array of element count " + llvm::Twine(DeviceSrc.getElementCount())); else if (ElementCount > HostDst.size()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", to host array of size " + llvm::Twine(HostDst.size())); + ", to host array of element count " + + llvm::Twine(HostDst.size())); else - setError(PExecutor->memcpyD2H(ThePlatformStream.get(), DeviceSrc, - HostDst.data(), ElementCount * sizeof(T))); + setError(PExecutor->memcpyD2H(ThePlatformStream.get(), HostDst.data(), + DeviceSrc, ElementCount * sizeof(T), + SrcElementOffset * sizeof(T))); return *this; } /// Same as thenMemcpyD2H above, but copies the entire source to the /// destination. 
template - Stream &thenMemcpyD2H(const GlobalDeviceMemory &DeviceSrc, - llvm::MutableArrayRef HostDst) { - return thenMemcpyD2H(DeviceSrc, HostDst, DeviceSrc.getElementCount()); + Stream &thenMemcpyD2H(llvm::MutableArrayRef HostDst, + const GlobalDeviceMemory &DeviceSrc) { + if (DeviceSrc.getElementCount() != HostDst.size()) { + setError("thenMemcpyD2H device source element count, " + + llvm::Twine(DeviceSrc.getElementCount()) + + " ,does not match host destination element count, " + + llvm::Twine(HostDst.size())); + return *this; + } + return thenMemcpyD2H(HostDst, DeviceSrc, DeviceSrc.getElementCount()); } /// Entrain onto the stream a memcpy of a given number of elements from a host @@ -136,54 +147,78 @@ /// Executor::allocateHostMemory or otherwise allocated and then /// registered with Executor::registerHostMemory. template - Stream &thenMemcpyH2D(llvm::ArrayRef HostSrc, - GlobalDeviceMemory *DeviceDst, size_t ElementCount) { + Stream &thenMemcpyH2D(GlobalDeviceMemory *DeviceDst, + llvm::ArrayRef HostSrc, size_t ElementCount, + size_t DstElementOffset = 0) { if (ElementCount > HostSrc.size()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", from host array of size " + llvm::Twine(HostSrc.size())); - else if (ElementCount > DeviceDst->getElementCount()) + ", from host array of element count " + + llvm::Twine(HostSrc.size())); + else if (DstElementOffset + ElementCount > DeviceDst->getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", to device memory array of size " + - llvm::Twine(DeviceDst->getElementCount())); + ", to device memory array of element count " + + llvm::Twine(DeviceDst->getElementCount()) + + " at element offset " + llvm::Twine(DstElementOffset)); else - setError(PExecutor->memcpyH2D(ThePlatformStream.get(), HostSrc.data(), - DeviceDst, ElementCount * sizeof(T))); + setError(PExecutor->memcpyH2D(ThePlatformStream.get(), DeviceDst, + HostSrc.data(), ElementCount * sizeof(T), + 
DstElementOffset * sizeof(T))); return *this; } /// Same as thenMemcpyH2D above, but copies the entire source to the /// destination. template - Stream &thenMemcpyH2D(llvm::ArrayRef HostSrc, - GlobalDeviceMemory *DeviceDst) { - return thenMemcpyH2D(HostSrc, DeviceDst, HostSrc.size()); + Stream &thenMemcpyH2D(GlobalDeviceMemory *DeviceDst, + llvm::ArrayRef HostSrc) { + if (HostSrc.size() != DeviceDst->getElementCount()) { + setError("thenMemcpyH2D host source element count, " + + llvm::Twine(HostSrc.size()) + + " ,does not match device destination element count, " + + llvm::Twine(DeviceDst->getElementCount())); + return *this; + } + return thenMemcpyH2D(DeviceDst, HostSrc, HostSrc.size()); } /// Entrain onto the stream a memcpy of a given number of elements from a /// device source to a device destination. template - Stream &thenMemcpyD2D(const GlobalDeviceMemory &DeviceSrc, - GlobalDeviceMemory *DeviceDst, size_t ElementCount) { - if (ElementCount > DeviceSrc.getElementCount()) + Stream &thenMemcpyD2D(GlobalDeviceMemory *DeviceDst, + const GlobalDeviceMemory &DeviceSrc, + size_t ElementCount, size_t SrcElementOffset = 0, + size_t DstElementOffset = 0) { + if (SrcElementOffset + ElementCount > DeviceSrc.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", from device memory array of size " + + ", at element offset " + llvm::Twine(SrcElementOffset) + + " from device memory array of element count " + llvm::Twine(DeviceSrc.getElementCount())); - else if (ElementCount > DeviceDst->getElementCount()) + else if (DstElementOffset + ElementCount > DeviceDst->getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", to device memory array of size " + - llvm::Twine(DeviceDst->getElementCount())); + ", to device memory array of element count " + + llvm::Twine(DeviceDst->getElementCount()) + + " at element offset " + llvm::Twine(DstElementOffset)); else - 
setError(PExecutor->memcpyD2D(ThePlatformStream.get(), DeviceSrc, - DeviceDst, ElementCount * sizeof(T))); + setError(PExecutor->memcpyD2D(ThePlatformStream.get(), DeviceDst, + DeviceSrc, ElementCount * sizeof(T), + SrcElementOffset * sizeof(T), + DstElementOffset * sizeof(T))); return *this; } /// Same as thenMemcpyD2D above, but copies the entire source to the /// destination. template - Stream &thenMemcpyD2D(const GlobalDeviceMemory &DeviceSrc, - GlobalDeviceMemory *DeviceDst) { - return thenMemcpyD2D(DeviceSrc, DeviceDst, DeviceSrc.getElementCount()); + Stream &thenMemcpyD2D(GlobalDeviceMemory *DeviceDst, + const GlobalDeviceMemory &DeviceSrc) { + if (DeviceSrc.getElementCount() != DeviceDst->getElementCount()) { + setError("thenMemcpyH2D device source element count, " + + llvm::Twine(DeviceSrc.getElementCount()) + + " ,does not match device destination element count, " + + llvm::Twine(DeviceDst->getElementCount())); + return *this; + } + return thenMemcpyD2D(DeviceDst, DeviceSrc, DeviceSrc.getElementCount()); } /// Blocks the host code, waiting for the operations entrained on the stream Index: streamexecutor/include/streamexecutor/Utils/Error.h =================================================================== --- streamexecutor/include/streamexecutor/Utils/Error.h +++ streamexecutor/include/streamexecutor/Utils/Error.h @@ -169,10 +169,10 @@ using llvm::consumeError; using llvm::Error; using llvm::Expected; -using llvm::StringRef; +using llvm::Twine; // Makes an Error object from an error message. -Error make_error(StringRef Message); +Error make_error(Twine Message); // Consumes the input error and returns its error message. 
// Index: streamexecutor/lib/Utils/Error.cpp =================================================================== --- streamexecutor/lib/Utils/Error.cpp +++ streamexecutor/lib/Utils/Error.cpp @@ -44,8 +44,8 @@ namespace streamexecutor { -Error make_error(StringRef Message) { - return llvm::make_error(Message); +Error make_error(Twine Message) { + return llvm::make_error(Message.str()); } std::string consumeAndGetMessage(Error &&E) { Index: streamexecutor/lib/unittests/CMakeLists.txt =================================================================== --- streamexecutor/lib/unittests/CMakeLists.txt +++ streamexecutor/lib/unittests/CMakeLists.txt @@ -1,4 +1,14 @@ add_executable( + executor_test + ExecutorTest.cpp) +target_link_libraries( + executor_test + streamexecutor + ${GTEST_BOTH_LIBRARIES} + ${CMAKE_THREAD_LIBS_INIT}) +add_test(ExecutorTest executor_test) + +add_executable( kernel_test KernelTest.cpp) target_link_libraries( Index: streamexecutor/lib/unittests/ExecutorTest.cpp =================================================================== --- /dev/null +++ streamexecutor/lib/unittests/ExecutorTest.cpp @@ -0,0 +1,177 @@ +//===-- ExecutorTest.cpp - Tests for Executor -----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the unit tests for Executor code. 
+/// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "streamexecutor/Executor.h" +#include "streamexecutor/PlatformInterfaces.h" + +#include "gtest/gtest.h" + +namespace { + +namespace se = ::streamexecutor; + +class MockPlatformExecutor : public se::PlatformExecutor { +public: + ~MockPlatformExecutor() override {} + + std::string getName() const override { return "MockPlatformExecutor"; } + + se::Expected> + createStream() override { + return se::make_error("not implemented"); + } + + se::Expected + allocateDeviceMemory(size_t ByteCount) override { + return se::GlobalDeviceMemoryBase(std::malloc(ByteCount)); + } + + se::Error freeDeviceMemory(se::GlobalDeviceMemoryBase *Memory) override { + std::free(const_cast(Memory->getHandle())); + return se::Error::success(); + } + + se::Expected allocateHostMemory(size_t ByteCount) override { + return std::malloc(ByteCount); + } + + se::Error freeHostMemory(void *Memory) override { + std::free(Memory); + return se::Error::success(); + } + + se::Error synchronousCopyD2H(const se::GlobalDeviceMemoryBase &DeviceSrc, + size_t SrcByteOffset, void *HostDst, + size_t DstByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(HostDst) + DstByteOffset, + static_cast(DeviceSrc.getHandle()) + + SrcByteOffset, + ByteCount); + return se::Error::success(); + } + + se::Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset, + se::GlobalDeviceMemoryBase DeviceDst, + size_t DstByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(const_cast(DeviceDst.getHandle())) + + DstByteOffset, + static_cast(HostSrc) + SrcByteOffset, ByteCount); + return se::Error::success(); + } + + se::Error synchronousCopyD2D(se::GlobalDeviceMemoryBase DeviceDst, + size_t DstByteOffset, + const se::GlobalDeviceMemoryBase &DeviceSrc, + size_t SrcByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(const_cast(DeviceDst.getHandle())) + + 
DstByteOffset, + static_cast(DeviceSrc.getHandle()) + + SrcByteOffset, + ByteCount); + return se::Error::success(); + } +}; + +/// Test fixture to hold objects used by tests. +class ExecutorTest : public ::testing::Test { +public: + ExecutorTest() + : HostA{0, 1, 2, 3, 4}, + DeviceA(se::GlobalDeviceMemory::makeFromElementCount(HostA, 5)), + HostB{5, 6, 7, 8, 9}, + DeviceB(se::GlobalDeviceMemory::makeFromElementCount(HostB, 5)), + Host{10, 11, 12, 13, 14}, Executor(&PExecutor) {} + + // Device memory is backed by host arrays. + int HostA[5]; + se::GlobalDeviceMemory DeviceA; + int HostB[5]; + se::GlobalDeviceMemory DeviceB; + + // Host memory to be used as actual host memory. + int Host[5]; + + MockPlatformExecutor PExecutor; + se::Executor Executor; +}; + +TEST_F(ExecutorTest, SyncCopyD2HFullArrayEqualSize) { + EXPECT_FALSE(static_cast( + Executor.synchronousCopyD2H(DeviceA, llvm::MutableArrayRef(Host)))); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(HostA[i], Host[i]); + } +} + +TEST_F(ExecutorTest, SyncCopyD2HSliceEqualSize) { + EXPECT_FALSE(static_cast(Executor.synchronousCopyD2H( + DeviceA.asSlice().slice(1, 2), llvm::MutableArrayRef(Host + 1, 2)))); + for (int i = 1; i < 3; ++i) { + EXPECT_EQ(HostA[i], Host[i]); + } +} + +TEST_F(ExecutorTest, SyncCopyD2HArraySpecifiedSize) { + EXPECT_FALSE(static_cast(Executor.synchronousCopyD2H( + DeviceA, llvm::MutableArrayRef(Host), 2))); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(HostA[i], Host[i]); + } +} + +TEST_F(ExecutorTest, SyncCopyD2HSliceSpecifiedSize) { + EXPECT_FALSE(static_cast(Executor.synchronousCopyD2H( + DeviceA.asSlice().slice(0, 2), llvm::MutableArrayRef(Host), 2))); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(HostA[i], Host[i]); + } +} + +TEST_F(ExecutorTest, SyncCopyD2HArrayToPointerSpecifiedSize) { + EXPECT_FALSE( + static_cast(Executor.synchronousCopyD2H(DeviceA, Host, 2))); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(HostA[i], Host[i]); + } +} + +TEST_F(ExecutorTest, 
SyncCopyD2HSliceToPointerSpecifiedSize) { + EXPECT_FALSE(static_cast( + Executor.synchronousCopyD2H(DeviceA.asSlice().slice(0, 2), Host, 2))); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(HostA[i], Host[i]); + } +} + +TEST_F(ExecutorTest, SyncCopyD2HArrayTooManyElements) { + se::Error E = Executor.synchronousCopyD2H( + DeviceA, llvm::MutableArrayRef(Host), 20); + EXPECT_TRUE(static_cast(E)); + se::consumeError(std::move(E)); +} + +TEST_F(ExecutorTest, SyncCopyD2HSliceTooManyElements) { + se::Error E = Executor.synchronousCopyD2H( + DeviceA.asSlice(), llvm::MutableArrayRef(Host), 20); + EXPECT_TRUE(static_cast(E)); + se::consumeError(std::move(E)); +} + +} // namespace Index: streamexecutor/lib/unittests/StreamTest.cpp =================================================================== --- streamexecutor/lib/unittests/StreamTest.cpp +++ streamexecutor/lib/unittests/StreamTest.cpp @@ -40,26 +40,36 @@ return nullptr; } - se::Error memcpyD2H(se::PlatformStreamHandle *, + se::Error memcpyD2H(se::PlatformStreamHandle *, void *HostDst, const se::GlobalDeviceMemoryBase &DeviceSrc, - void *HostDst, size_t ByteCount) override { - std::memcpy(HostDst, DeviceSrc.getHandle(), ByteCount); + size_t ByteCount, size_t SrcByteOffset) override { + std::memcpy(HostDst, static_cast(DeviceSrc.getHandle()) + + SrcByteOffset, + ByteCount); return se::Error::success(); } - se::Error memcpyH2D(se::PlatformStreamHandle *, const void *HostSrc, + se::Error memcpyH2D(se::PlatformStreamHandle *, se::GlobalDeviceMemoryBase *DeviceDst, - size_t ByteCount) override { - std::memcpy(const_cast(DeviceDst->getHandle()), HostSrc, ByteCount); + const void *HostSrc, size_t ByteCount, + size_t DstByteOffset) override { + std::memcpy( + static_cast(const_cast(DeviceDst->getHandle())) + + DstByteOffset, + HostSrc, ByteCount); return se::Error::success(); } se::Error memcpyD2D(se::PlatformStreamHandle *, - const se::GlobalDeviceMemoryBase &DeviceSrc, se::GlobalDeviceMemoryBase *DeviceDst, - size_t ByteCount) 
override { - std::memcpy(const_cast(DeviceDst->getHandle()), - DeviceSrc.getHandle(), ByteCount); + const se::GlobalDeviceMemoryBase &DeviceSrc, + size_t ByteCount, size_t SrcByteOffset, + size_t DstByteOffset) override { + std::memcpy( + static_cast(const_cast(DeviceDst->getHandle())) + + SrcByteOffset, + static_cast(DeviceSrc.getHandle()) + DstByteOffset, + ByteCount); return se::Error::success(); } }; @@ -87,28 +97,28 @@ }; TEST_F(StreamTest, MemcpyCorrectSize) { - Stream.thenMemcpyH2D(llvm::ArrayRef(Host), &DeviceA); + Stream.thenMemcpyH2D(&DeviceA, llvm::ArrayRef(Host)); EXPECT_TRUE(Stream.isOK()); - Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef(Host)); + Stream.thenMemcpyD2H(llvm::MutableArrayRef(Host), DeviceA); EXPECT_TRUE(Stream.isOK()); - Stream.thenMemcpyD2D(DeviceA, &DeviceB); + Stream.thenMemcpyD2D(&DeviceB, DeviceA); EXPECT_TRUE(Stream.isOK()); } TEST_F(StreamTest, MemcpyH2DTooManyElements) { - Stream.thenMemcpyH2D(llvm::ArrayRef(Host), &DeviceA, 20); + Stream.thenMemcpyH2D(&DeviceA, llvm::ArrayRef(Host), 20); EXPECT_FALSE(Stream.isOK()); } TEST_F(StreamTest, MemcpyD2HTooManyElements) { - Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef(Host), 20); + Stream.thenMemcpyD2H(llvm::MutableArrayRef(Host), DeviceA, 20); EXPECT_FALSE(Stream.isOK()); } TEST_F(StreamTest, MemcpyD2DTooManyElements) { - Stream.thenMemcpyD2D(DeviceA, &DeviceB, 20); + Stream.thenMemcpyD2D(&DeviceB, DeviceA, 20); EXPECT_FALSE(Stream.isOK()); }