Index: streamexecutor/include/streamexecutor/Device.h =================================================================== --- streamexecutor/include/streamexecutor/Device.h +++ streamexecutor/include/streamexecutor/Device.h @@ -55,16 +55,17 @@ /// Allocates an array of ElementCount entries of type T in device memory. template Expected> allocateDeviceMemory(size_t ElementCount) { - Expected MaybeBase = + Expected MaybeMemory = PDevice->allocateDeviceMemory(ElementCount * sizeof(T)); - if (!MaybeBase) - return MaybeBase.takeError(); - return GlobalDeviceMemory(*MaybeBase); + if (!MaybeMemory) + return MaybeMemory.takeError(); + return GlobalDeviceMemory::makeFromElementCount(*MaybeMemory, + ElementCount); } /// Frees memory previously allocated with allocateDeviceMemory. template Error freeDeviceMemory(GlobalDeviceMemory Memory) { - return PDevice->freeDeviceMemory(Memory); + return PDevice->freeDeviceMemory(Memory.getHandle()); } /// Allocates an array of ElementCount entries of type T in host memory. 
@@ -139,7 +140,7 @@ return make_error( "copying too many elements, " + llvm::Twine(ElementCount) + ", to a host array of element count " + llvm::Twine(Dst.size())); - return PDevice->synchronousCopyD2H(Src.getBaseMemory(), + return PDevice->synchronousCopyD2H(Src.getBaseMemory().getHandle(), Src.getElementOffset() * sizeof(T), Dst.data(), 0, ElementCount * sizeof(T)); } @@ -193,9 +194,9 @@ llvm::Twine(ElementCount) + ", to a device array of element count " + llvm::Twine(Dst.getElementCount())); - return PDevice->synchronousCopyH2D(Src.data(), 0, Dst.getBaseMemory(), - Dst.getElementOffset() * sizeof(T), - ElementCount * sizeof(T)); + return PDevice->synchronousCopyH2D( + Src.data(), 0, Dst.getBaseMemory().getHandle(), + Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)); } template @@ -249,8 +250,8 @@ ", to a device array of element count " + llvm::Twine(Dst.getElementCount())); return PDevice->synchronousCopyD2D( - Src.getBaseMemory(), Src.getElementOffset() * sizeof(T), - Dst.getBaseMemory(), Dst.getElementOffset() * sizeof(T), + Src.getBaseMemory().getHandle(), Src.getElementOffset() * sizeof(T), + Dst.getBaseMemory().getHandle(), Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)); } Index: streamexecutor/include/streamexecutor/PlatformInterfaces.h =================================================================== --- streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -94,8 +94,7 @@ /// /// HostDst should have been allocated by allocateHostMemory or registered /// with registerHostMemory.
- virtual Error copyD2H(PlatformStreamHandle *S, - const GlobalDeviceMemoryBase &DeviceSrc, + virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { return make_error("copyD2H not implemented for platform " + getName()); @@ -106,15 +105,14 @@ /// HostSrc should have been allocated by allocateHostMemory or registered /// with registerHostMemory. virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc, - size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst, + size_t SrcByteOffset, const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { return make_error("copyH2D not implemented for platform " + getName()); } /// Copies data from one device location to another. - virtual Error copyD2D(PlatformStreamHandle *S, - const GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, GlobalDeviceMemoryBase DeviceDst, + virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle, + size_t SrcByteOffset, const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { return make_error("copyD2D not implemented for platform " + getName()); } @@ -127,14 +125,13 @@ } /// Allocates untyped device memory of a given size in bytes. - virtual Expected - allocateDeviceMemory(size_t ByteCount) { + virtual Expected allocateDeviceMemory(size_t ByteCount) { return make_error("allocateDeviceMemory not implemented for platform " + getName()); } /// Frees device memory previously allocated by allocateDeviceMemory. - virtual Error freeDeviceMemory(GlobalDeviceMemoryBase Memory) { + virtual Error freeDeviceMemory(const void *Handle) { return make_error("freeDeviceMemory not implemented for platform " + getName()); } @@ -172,29 +169,29 @@ /// Blocks the calling host thread until the copy is completed. Can operate on /// any host memory, not just registered host memory or host memory allocated /// by allocateHostMemory.
Does not block any ongoing device calls. - virtual Error synchronousCopyD2H(const GlobalDeviceMemoryBase &DeviceSrc, + virtual Error synchronousCopyD2H(const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { return make_error("synchronousCopyD2H not implemented for platform " + getName()); } - /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void - /// *, size_t, size_t), but copies memory from host to device rather than + /// Similar to synchronousCopyD2H(const void *, size_t, void *, size_t, + /// size_t), but copies memory from host to device rather than /// device to host. virtual Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset, - GlobalDeviceMemoryBase DeviceDst, + const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { return make_error("synchronousCopyH2D not implemented for platform " + getName()); } - /// Similar to synchronousCopyD2H(const GlobalDeviceMemoryBase &, size_t, void - /// *, size_t, size_t), but copies memory from one location in device memory + /// Similar to synchronousCopyD2H(const void *, size_t, void *, size_t, + /// size_t), but copies memory from one location in device memory /// to another rather than from device to host.
- virtual Error synchronousCopyD2D(GlobalDeviceMemoryBase DeviceDst, - size_t DstByteOffset, - const GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, size_t ByteCount) { + virtual Error synchronousCopyD2D(const void *DeviceSrcHandle, + size_t SrcByteOffset, + const void *DeviceDstHandle, + size_t DstByteOffset, size_t ByteCount) { return make_error("synchronousCopyD2D not implemented for platform " + getName()); Index: streamexecutor/include/streamexecutor/Stream.h =================================================================== --- streamexecutor/include/streamexecutor/Stream.h +++ streamexecutor/include/streamexecutor/Stream.h @@ -133,7 +133,8 @@ setError("copying too many elements, " + llvm::Twine(ElementCount) + ", to a host array of element count " + llvm::Twine(Dst.size())); else - setError(PDevice->copyD2H(ThePlatformStream.get(), Src.getBaseMemory(), + setError(PDevice->copyD2H(ThePlatformStream.get(), + Src.getBaseMemory().getHandle(), Src.getElementOffset() * sizeof(T), Dst.data(), 0, ElementCount * sizeof(T))); return *this; @@ -190,9 +191,10 @@ ", to a device array of element count " + llvm::Twine(Dst.getElementCount())); else - setError(PDevice->copyH2D( - ThePlatformStream.get(), Src.data(), 0, Dst.getBaseMemory(), - Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); + setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0, + Dst.getBaseMemory().getHandle(), + Dst.getElementOffset() * sizeof(T), + ElementCount * sizeof(T))); return *this; } @@ -247,8 +249,8 @@ llvm::Twine(Dst.getElementCount())); else setError(PDevice->copyD2D( - ThePlatformStream.get(), Src.getBaseMemory(), - Src.getElementOffset() * sizeof(T), Dst.getBaseMemory(), + ThePlatformStream.get(), Src.getBaseMemory().getHandle(), + Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(), Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); return *this; } Index: streamexecutor/lib/unittests/DeviceTest.cpp =================================================================== --- 
streamexecutor/lib/unittests/DeviceTest.cpp +++ streamexecutor/lib/unittests/DeviceTest.cpp @@ -15,6 +15,7 @@ #include #include +#include "SimpleHostPlatformDevice.h" #include "streamexecutor/Device.h" #include "streamexecutor/PlatformInterfaces.h" @@ -24,79 +25,6 @@ namespace se = ::streamexecutor; -class MockPlatformDevice : public se::PlatformDevice { -public: - ~MockPlatformDevice() override {} - - std::string getName() const override { return "MockPlatformDevice"; } - - se::Expected> - createStream() override { - return se::make_error("not implemented"); - } - - se::Expected - allocateDeviceMemory(size_t ByteCount) override { - return se::GlobalDeviceMemoryBase(std::malloc(ByteCount)); - } - - se::Error freeDeviceMemory(se::GlobalDeviceMemoryBase Memory) override { - std::free(const_cast(Memory.getHandle())); - return se::Error::success(); - } - - se::Expected allocateHostMemory(size_t ByteCount) override { - return std::malloc(ByteCount); - } - - se::Error freeHostMemory(void *Memory) override { - std::free(Memory); - return se::Error::success(); - } - - se::Error registerHostMemory(void *, size_t) override { - return se::Error::success(); - } - - se::Error unregisterHostMemory(void *) override { - return se::Error::success(); - } - - se::Error synchronousCopyD2H(const se::GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, void *HostDst, - size_t DstByteOffset, - size_t ByteCount) override { - std::memcpy(static_cast(HostDst) + DstByteOffset, - static_cast(DeviceSrc.getHandle()) + - SrcByteOffset, - ByteCount); - return se::Error::success(); - } - - se::Error synchronousCopyH2D(const void *HostSrc, size_t SrcByteOffset, - se::GlobalDeviceMemoryBase DeviceDst, - size_t DstByteOffset, - size_t ByteCount) override { - std::memcpy(static_cast(const_cast(DeviceDst.getHandle())) + - DstByteOffset, - static_cast(HostSrc) + SrcByteOffset, ByteCount); - return se::Error::success(); - } - - se::Error synchronousCopyD2D(se::GlobalDeviceMemoryBase DeviceDst, 
size_t DstByteOffset, - const se::GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, - size_t ByteCount) override { - std::memcpy(static_cast(const_cast(DeviceDst.getHandle())) + - DstByteOffset, - static_cast(DeviceSrc.getHandle()) + - SrcByteOffset, - ByteCount); - return se::Error::success(); - } -}; - /// Test fixture to hold objects used by tests. class DeviceTest : public ::testing::Test { public: @@ -124,7 +52,7 @@ int Host5[5]; int Host7[7]; - MockPlatformDevice PDevice; + SimpleHostPlatformDevice PDevice; se::Device Device; }; Index: streamexecutor/lib/unittests/SimpleHostPlatformDevice.h =================================================================== --- /dev/null +++ streamexecutor/lib/unittests/SimpleHostPlatformDevice.h @@ -0,0 +1,135 @@ +//===-- SimpleHostPlatformDevice.h - Host device for testing ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// The SimpleHostPlatformDevice class is a streamexecutor::PlatformDevice that +/// is really just the host processor and memory. It is useful for testing +/// because no extra device platform is required. +/// +//===----------------------------------------------------------------------===// + +#ifndef STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H +#define STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H + +#include +#include + +#include "streamexecutor/PlatformInterfaces.h" + +/// A streamexecutor::PlatformDevice that simply forwards all operations to the +/// host platform. +/// +/// Allocation wraps std::malloc/std::free; every copy, including the stream +/// variants (whose stream argument is ignored), is a blocking std::memcpy.
+class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice { +public: + std::string getName() const override { return "SimpleHostPlatformDevice"; } + + streamexecutor::Expected< + std::unique_ptr> + createStream() override { + return nullptr; + } + + streamexecutor::Expected + allocateDeviceMemory(size_t ByteCount) override { + return std::malloc(ByteCount); + } + + streamexecutor::Error freeDeviceMemory(const void *Handle) override { + std::free(const_cast(Handle)); + return streamexecutor::Error::success(); + } + + streamexecutor::Expected + allocateHostMemory(size_t ByteCount) override { + return std::malloc(ByteCount); + } + + streamexecutor::Error freeHostMemory(void *Memory) override { + std::free(Memory); + return streamexecutor::Error::success(); + } + + streamexecutor::Error registerHostMemory(void *, size_t) override { + return streamexecutor::Error::success(); + } + + streamexecutor::Error unregisterHostMemory(void *) override { + return streamexecutor::Error::success(); + } + + streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S, + const void *DeviceSrcHandle, + size_t SrcByteOffset, void *HostDst, + size_t DstByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(HostDst) + DstByteOffset, + static_cast(DeviceSrcHandle) + SrcByteOffset, + ByteCount); + return streamexecutor::Error::success(); + } + + streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S, + const void *HostSrc, size_t SrcByteOffset, + const void *DeviceDstHandle, + size_t DstByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(const_cast(DeviceDstHandle)) + + DstByteOffset, + static_cast(HostSrc) + SrcByteOffset, ByteCount); + return streamexecutor::Error::success(); + } + + streamexecutor::Error + copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceSrcHandle, + size_t SrcByteOffset, const void *DeviceDstHandle, + size_t DstByteOffset, size_t ByteCount) override {
std::memcpy(static_cast(const_cast(DeviceDstHandle)) + + DstByteOffset, + static_cast(DeviceSrcHandle) + SrcByteOffset, + ByteCount); + return streamexecutor::Error::success(); + } + + streamexecutor::Error synchronousCopyD2H(const void *DeviceSrcHandle, + size_t SrcByteOffset, void *HostDst, + size_t DstByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(HostDst) + DstByteOffset, + static_cast(DeviceSrcHandle) + SrcByteOffset, + ByteCount); + return streamexecutor::Error::success(); + } + + streamexecutor::Error synchronousCopyH2D(const void *HostSrc, + size_t SrcByteOffset, + const void *DeviceDstHandle, + size_t DstByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(const_cast(DeviceDstHandle)) + + DstByteOffset, + static_cast(HostSrc) + SrcByteOffset, ByteCount); + return streamexecutor::Error::success(); + } + + streamexecutor::Error synchronousCopyD2D(const void *DeviceSrcHandle, + size_t SrcByteOffset, + const void *DeviceDstHandle, + size_t DstByteOffset, + size_t ByteCount) override { + std::memcpy(static_cast(const_cast(DeviceDstHandle)) + + DstByteOffset, + static_cast(DeviceSrcHandle) + SrcByteOffset, + ByteCount); + return streamexecutor::Error::success(); + } +}; + +#endif // STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H Index: streamexecutor/lib/unittests/StreamTest.cpp =================================================================== --- streamexecutor/lib/unittests/StreamTest.cpp +++ streamexecutor/lib/unittests/StreamTest.cpp @@ -14,6 +14,7 @@ #include +#include "SimpleHostPlatformDevice.h" #include "streamexecutor/Device.h" #include "streamexecutor/Kernel.h" #include "streamexecutor/KernelSpec.h" @@ -26,52 +27,6 @@ namespace se = ::streamexecutor; -/// Mock PlatformDevice that performs asynchronous memcpy operations by -/// ignoring the stream argument and calling std::memcpy on device memory -/// handles.
-class MockPlatformDevice : public se::PlatformDevice { -public: - ~MockPlatformDevice() override {} - - std::string getName() const override { return "MockPlatformDevice"; } - - se::Expected> - createStream() override { - return nullptr; - } - - se::Error copyD2H(se::PlatformStreamHandle *S, - const se::GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, - size_t ByteCount) override { - std::memcpy(HostDst, static_cast(DeviceSrc.getHandle()) + - SrcByteOffset, - ByteCount); - return se::Error::success(); - } - - se::Error copyH2D(se::PlatformStreamHandle *S, const void *HostSrc, - size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst, - size_t DstByteOffset, size_t ByteCount) override { - std::memcpy(static_cast(const_cast(DeviceDst.getHandle())) + - DstByteOffset, - HostSrc, ByteCount); - return se::Error::success(); - } - - se::Error copyD2D(se::PlatformStreamHandle *S, - const se::GlobalDeviceMemoryBase &DeviceSrc, - size_t SrcByteOffset, se::GlobalDeviceMemoryBase DeviceDst, - size_t DstByteOffset, size_t ByteCount) override { - std::memcpy(static_cast(const_cast(DeviceDst.getHandle())) + - DstByteOffset, - static_cast(DeviceSrc.getHandle()) + - SrcByteOffset, - ByteCount); - return se::Error::success(); - } -}; - /// Test fixture to hold objects used by tests. class StreamTest : public ::testing::Test { public: @@ -100,7 +55,7 @@ int Host5[5]; int Host7[7]; - MockPlatformDevice PDevice; + SimpleHostPlatformDevice PDevice; se::Stream Stream; };