Index: streamexecutor/include/streamexecutor/Executor.h =================================================================== --- streamexecutor/include/streamexecutor/Executor.h +++ streamexecutor/include/streamexecutor/Executor.h @@ -16,12 +16,12 @@ #define STREAMEXECUTOR_EXECUTOR_H #include "streamexecutor/KernelSpec.h" +#include "streamexecutor/PlatformInterfaces.h" #include "streamexecutor/Utils/Error.h" namespace streamexecutor { class KernelInterface; -class PlatformExecutor; class Stream; class Executor { @@ -38,6 +38,137 @@ Expected> createStream(); + /// Allocates an array of ElementCount entries of type T in device memory. + template <typename T> + Expected<GlobalDeviceMemory<T>> allocateDeviceMemory(size_t ElementCount) { + return PExecutor->allocateDeviceMemory(ElementCount * sizeof(T)); + } + + /// Frees memory previously allocated with allocateDeviceMemory. + template <typename T> Error freeDeviceMemory(GlobalDeviceMemory<T> *Memory) { + return PExecutor->freeDeviceMemory(Memory); + } + + /// Allocates an array of ElementCount entries of type T in host memory. + /// + /// Host memory allocated by this function can be used for asynchronous memory + /// copies on streams. See Stream::thenMemcpyD2H and Stream::thenMemcpyH2D. + template <typename T> Expected<T *> allocateHostMemory(size_t ElementCount) { + return PExecutor->allocateHostMemory(ElementCount * sizeof(T)); + } + + /// Frees memory previously allocated with allocateHostMemory. + template <typename T> Error freeHostMemory(T *Memory) { + return PExecutor->freeHostMemory(Memory); + } + + /// Registers a previously allocated host array of type T for asynchronous + /// memory operations. + /// + /// Host memory registered by this function can be used for asynchronous + /// memory copies on streams. See Stream::thenMemcpyD2H and + /// Stream::thenMemcpyH2D. + template <typename T> Error registerHostMemory(T *Memory) { + return PExecutor->registerHostMemory(Memory); + } + + /// Unregisters host memory previously registered by registerHostMemory. 
+ template <typename T> Error unregisterHostMemory(T *Memory) { + return PExecutor->unregisterHostMemory(Memory); + } + + /// Host-synchronously copies a slice of an array of elements of type T from + /// device memory to host memory. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory; the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. + template <typename T> + Error synchronousMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc, + llvm::MutableArrayRef<T> HostDst, + size_t ElementCount) { + if (ElementCount > DeviceSrc.getElementCount()) + return make_error("copying too many elements, " + + llvm::Twine(ElementCount) + + ", from device memory array of size " + + llvm::Twine(DeviceSrc.getElementCount())); + else if (ElementCount > HostDst.size()) + return make_error("copying too many elements, " + + llvm::Twine(ElementCount) + ", to host array of size " + + llvm::Twine(HostDst.size())); + return PExecutor->synchronousMemcpyD2H(DeviceSrc, HostDst.data(), + ElementCount * sizeof(T)); + } + + /// Just like synchronousMemcpyD2H above, but copies the entire source array + /// to the destination. + template <typename T> + Error synchronousMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc, + llvm::MutableArrayRef<T> HostDst) { + return synchronousMemcpyD2H(DeviceSrc, HostDst, + DeviceSrc.getElementCount()); + } + + /// Host-synchronously copies a slice of an array of elements of type T from + /// host memory to device memory. + /// + /// The calling host thread is blocked until the copy completes. Can be used + /// with any host memory; the host memory does not have to be allocated with + /// allocateHostMemory or registered with registerHostMemory. 
+ template <typename T> + Error synchronousMemcpyH2D(llvm::ArrayRef<T> HostSrc, + GlobalDeviceMemory<T> *DeviceDst, + size_t ElementCount) { + if (ElementCount > HostSrc.size()) + return make_error( + "copying too many elements, " + llvm::Twine(ElementCount) + + ", from host array of size " + llvm::Twine(HostSrc.size())); + else if (ElementCount > DeviceDst->getElementCount()) + return make_error("copying too many elements, " + + llvm::Twine(ElementCount) + + ", to device memory array of size " + + llvm::Twine(DeviceDst->getElementCount())); + return PExecutor->synchronousMemcpyH2D(HostSrc.data(), DeviceDst, + ElementCount * sizeof(T)); + } + + /// Just like synchronousMemcpyH2D above, but copies the entire source array + /// to the destination. + template <typename T> + Error synchronousMemcpyH2D(llvm::ArrayRef<T> HostSrc, + GlobalDeviceMemory<T> *DeviceDst) { + return synchronousMemcpyH2D(HostSrc, DeviceDst, HostSrc.size()); + } + + /// Host-synchronously copies a slice of an array of elements of type T from + /// one place in device memory to another. + template <typename T> + Error synchronousMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc, + GlobalDeviceMemory<T> *DeviceDst, + size_t ElementCount) { + if (ElementCount > DeviceSrc.getElementCount()) + return make_error("copying too many elements, " + + llvm::Twine(ElementCount) + + ", from device memory array of size " + + llvm::Twine(DeviceSrc.getElementCount())); + else if (ElementCount > DeviceDst->getElementCount()) + return make_error("copying too many elements, " + + llvm::Twine(ElementCount) + + ", to device memory array of size " + + llvm::Twine(DeviceDst->getElementCount())); + return PExecutor->synchronousMemcpyD2D(DeviceSrc, DeviceDst, + ElementCount * sizeof(T)); + } + + /// Just like synchronousMemcpyD2D above, but copies the entire source array + /// to the destination. 
+ template <typename T> + Error synchronousMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc, + GlobalDeviceMemory<T> *DeviceDst) { + return synchronousMemcpyD2D(DeviceSrc, DeviceDst, + DeviceSrc.getElementCount()); + } + private: PlatformExecutor *PExecutor; }; Index: streamexecutor/include/streamexecutor/PlatformInterfaces.h =================================================================== --- streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -76,6 +76,9 @@ } /// Copies data from the device to the host. + /// + /// HostDst should have been allocated by allocateHostMemory or registered + /// with registerHostMemory. virtual Error memcpyD2H(PlatformStreamHandle *S, const GlobalDeviceMemoryBase &DeviceSrc, void *HostDst, size_t ByteCount) { @@ -83,6 +86,9 @@ } /// Copies data from the host to the device. + /// + /// HostSrc should have been allocated by allocateHostMemory or registered + /// with registerHostMemory. virtual Error memcpyH2D(PlatformStreamHandle *S, const void *HostSrc, GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) { return make_error("memcpyH2D not implemented for platform " + getName()); @@ -101,6 +107,79 @@ return make_error("blockHostUntilDone not implemented for platform " + getName()); } + + /// Allocates untyped device memory of a given size in bytes. + virtual Expected<GlobalDeviceMemoryBase> + allocateDeviceMemory(size_t ByteCount) { + return make_error("allocateDeviceMemory not implemented for platform " + + getName()); + } + + /// Frees device memory previously allocated by allocateDeviceMemory. + virtual Error freeDeviceMemory(GlobalDeviceMemoryBase *Memory) { + return make_error("freeDeviceMemory not implemented for platform " + + getName()); + } + + /// Allocates untyped host memory of a given size in bytes. + /// + /// Host memory allocated via this method is suitable for use with memcpyH2D + /// and memcpyD2H. 
+ virtual Expected<void *> allocateHostMemory(size_t ByteCount) { + return make_error("allocateHostMemory not implemented for platform " + + getName()); + } + + /// Frees host memory allocated by allocateHostMemory. + virtual Error freeHostMemory(void *Memory) { + return make_error("freeHostMemory not implemented for platform " + + getName()); + } + + /// Registers previously allocated host memory so it can be used with + /// memcpyH2D and memcpyD2H. + virtual Error registerHostMemory(void *Memory) { + return make_error("registerHostMemory not implemented for platform " + + getName()); + } + + /// Unregisters host memory previously registered with registerHostMemory. + virtual Error unregisterHostMemory(void *Memory) { + return make_error("unregisterHostMemory not implemented for platform " + + getName()); + } + + /// Copies the given number of bytes from device memory to host memory. + /// + /// Blocks the calling host thread until the copy is completed. Can operate on + /// any host memory, not just registered host memory or host memory allocated + /// by allocateHostMemory. + virtual Error synchronousMemcpyD2H(const GlobalDeviceMemoryBase &DeviceSrc, + void *HostDst, size_t ByteCount) { + return make_error("synchronousMemcpyD2H not implemented for platform " + + getName()); + } + + /// Copies the given number of bytes from host memory to device memory. + /// + /// Blocks the calling host thread until the copy is completed. Can operate on + /// any host memory, not just registered host memory or host memory allocated + /// by allocateHostMemory. + virtual Error synchronousMemcpyH2D(const void *HostSrc, + GlobalDeviceMemoryBase *DeviceDst, + size_t ByteCount) { + return make_error("synchronousMemcpyH2D not implemented for platform " + + getName()); + } + + /// Copies the given number of bytes from one location to another in device + /// memory. 
+ virtual Error synchronousMemcpyD2D(const GlobalDeviceMemoryBase &DeviceSrc, + GlobalDeviceMemoryBase *DeviceDst, + size_t ByteCount) { + return make_error("synchronousMemcpyD2D not implemented for platform " + + getName()); + } }; } // namespace streamexecutor Index: streamexecutor/include/streamexecutor/Utils/Error.h =================================================================== --- streamexecutor/include/streamexecutor/Utils/Error.h +++ streamexecutor/include/streamexecutor/Utils/Error.h @@ -169,10 +169,10 @@ using llvm::consumeError; using llvm::Error; using llvm::Expected; -using llvm::StringRef; +using llvm::Twine; // Makes an Error object from an error message. -Error make_error(StringRef Message); +Error make_error(const Twine &Message); // Consumes the input error and returns its error message. // Index: streamexecutor/lib/Utils/Error.cpp =================================================================== --- streamexecutor/lib/Utils/Error.cpp +++ streamexecutor/lib/Utils/Error.cpp @@ -44,8 +44,8 @@ namespace streamexecutor { -Error make_error(StringRef Message) { - return llvm::make_error(Message); +Error make_error(const Twine &Message) { + return llvm::make_error(Message.str()); } std::string consumeAndGetMessage(Error &&E) { Index: streamexecutor/lib/unittests/CMakeLists.txt =================================================================== --- streamexecutor/lib/unittests/CMakeLists.txt +++ streamexecutor/lib/unittests/CMakeLists.txt @@ -1,4 +1,14 @@ add_executable( + executor_test + ExecutorTest.cpp) +target_link_libraries( + executor_test + streamexecutor + ${GTEST_BOTH_LIBRARIES} + ${CMAKE_THREAD_LIBS_INIT}) +add_test(ExecutorTest executor_test) + +add_executable( kernel_test KernelTest.cpp) target_link_libraries( Index: streamexecutor/lib/unittests/ExecutorTest.cpp =================================================================== --- /dev/null +++ streamexecutor/lib/unittests/ExecutorTest.cpp @@ -0,0 +1,129 @@ +//===-- ExecutorTest.cpp - Tests for 
Executor -----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the unit tests for Executor code. +/// +//===----------------------------------------------------------------------===// + +#include <cstdlib> +#include <cstring> + +#include "streamexecutor/Executor.h" +#include "streamexecutor/PlatformInterfaces.h" + +#include "gtest/gtest.h" + +namespace { + +namespace se = ::streamexecutor; + +class MockPlatformExecutor : public se::PlatformExecutor { +public: + ~MockPlatformExecutor() override {} + + std::string getName() const override { return "MockPlatformExecutor"; } + + se::Expected> + createStream() override { + return se::make_error("not implemented"); + } + + se::Expected<se::GlobalDeviceMemoryBase> + allocateDeviceMemory(size_t ByteCount) override { + return se::GlobalDeviceMemoryBase(std::malloc(ByteCount)); + } + + se::Error freeDeviceMemory(se::GlobalDeviceMemoryBase *Memory) override { + std::free(const_cast<void *>(Memory->getHandle())); + return se::Error::success(); + } + + se::Expected<void *> allocateHostMemory(size_t ByteCount) override { + return std::malloc(ByteCount); + } + + se::Error freeHostMemory(void *Memory) override { + std::free(Memory); + return se::Error::success(); + } + + se::Error synchronousMemcpyD2H(const se::GlobalDeviceMemoryBase &DeviceSrc, + void *HostDst, size_t ByteCount) override { + std::memcpy(HostDst, DeviceSrc.getHandle(), ByteCount); + return se::Error::success(); + } + + se::Error synchronousMemcpyH2D(const void *HostSrc, + se::GlobalDeviceMemoryBase *DeviceDst, + size_t ByteCount) override { + std::memcpy(const_cast<void *>(DeviceDst->getHandle()), HostSrc, ByteCount); + return se::Error::success(); + } + + se::Error synchronousMemcpyD2D(const se::GlobalDeviceMemoryBase &DeviceSrc, + se::GlobalDeviceMemoryBase *DeviceDst, 
+ size_t ByteCount) override { + std::memcpy(const_cast<void *>(DeviceDst->getHandle()), + DeviceSrc.getHandle(), ByteCount); + return se::Error::success(); + } +}; + +/// Test fixture to hold objects used by tests. +class ExecutorTest : public ::testing::Test { +public: + ExecutorTest() + : DeviceA(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA, 10)), + DeviceB(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB, 10)), + Executor(&PExecutor) {} + + // Device memory is backed by host arrays. + int HostA[10]; + se::GlobalDeviceMemory<int> DeviceA; + int HostB[10]; + se::GlobalDeviceMemory<int> DeviceB; + + // Host memory to be used as actual host memory. + int Host[10]; + + MockPlatformExecutor PExecutor; + se::Executor Executor; +}; + +TEST_F(ExecutorTest, MemcpyCorrectSize) { + EXPECT_FALSE(static_cast<bool>( + Executor.synchronousMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA))); + EXPECT_FALSE(static_cast<bool>(Executor.synchronousMemcpyD2H( + DeviceA, llvm::MutableArrayRef<int>(Host)))); + EXPECT_FALSE( + static_cast<bool>(Executor.synchronousMemcpyD2D(DeviceA, &DeviceB))); +} + +TEST_F(ExecutorTest, MemcpyH2DTooManyElements) { + se::Error E = + Executor.synchronousMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA, 20); + EXPECT_TRUE(static_cast<bool>(E)); + se::consumeError(std::move(E)); +} + +TEST_F(ExecutorTest, MemcpyD2HTooManyElements) { + se::Error E = Executor.synchronousMemcpyD2H( + DeviceA, llvm::MutableArrayRef<int>(Host), 20); + EXPECT_TRUE(static_cast<bool>(E)); + se::consumeError(std::move(E)); +} + +TEST_F(ExecutorTest, MemcpyD2DTooManyElements) { + se::Error E = Executor.synchronousMemcpyD2D(DeviceA, &DeviceB, 20); + EXPECT_TRUE(static_cast<bool>(E)); + se::consumeError(std::move(E)); +} + +} // namespace