Index: streamexecutor/include/streamexecutor/platforms/host/HostPlatformDevice.h
===================================================================
--- streamexecutor/include/streamexecutor/platforms/host/HostPlatformDevice.h
+++ streamexecutor/include/streamexecutor/platforms/host/HostPlatformDevice.h
@@ -139,6 +139,14 @@
     return Error::success();
   }
 
+  /// Gets the value at the given index from a GlobalDeviceMemory<T> instance
+  /// created by this class.
+  template <typename T>
+  static T getDeviceValue(const streamexecutor::GlobalDeviceMemory<T> &Memory,
+                          size_t Index) {
+    return static_cast<const T *>(Memory.getHandle())[Index];
+  }
+
 private:
   static void *offset(const void *Base, size_t Offset) {
     return const_cast<char *>(static_cast<const char *>(Base) + Offset);
Index: streamexecutor/include/streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h
===================================================================
--- streamexecutor/include/streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h
+++ /dev/null
@@ -1,138 +0,0 @@
-//===-- SimpleHostPlatformDevice.h - Host device for testing ----*- C++ -*-===//
-//
-//                     The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// The SimpleHostPlatformDevice class is a streamexecutor::PlatformDevice that
-/// is really just the host processor and memory. It is useful for testing
-/// because no extra device platform is required.
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef STREAMEXECUTOR_UNITTESTS_CORETESTS_SIMPLEHOSTPLATFORMDEVICE_H
-#define STREAMEXECUTOR_UNITTESTS_CORETESTS_SIMPLEHOSTPLATFORMDEVICE_H
-
-#include <cstdlib>
-#include <cstring>
-
-#include "streamexecutor/PlatformDevice.h"
-
-namespace streamexecutor {
-namespace test {
-
-/// A streamexecutor::PlatformDevice that simply forwards all operations to the
-/// host platform.
-///
-/// The allocate and copy methods are simple wrappers for std::malloc and
-/// std::memcpy.
-class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice {
-public:
-  std::string getName() const override { return "SimpleHostPlatformDevice"; }
-
-  streamexecutor::Expected<const void *> createStream() override {
-    return nullptr;
-  }
-
-  streamexecutor::Expected<void *>
-  allocateDeviceMemory(size_t ByteCount) override {
-    return std::malloc(ByteCount);
-  }
-
-  streamexecutor::Error freeDeviceMemory(const void *Handle) override {
-    std::free(const_cast<void *>(Handle));
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error registerHostMemory(void *Memory,
-                                           size_t ByteCount) override {
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error unregisterHostMemory(const void *Memory) override {
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error copyD2H(const void *StreamHandle,
-                                const void *DeviceHandleSrc,
-                                size_t SrcByteOffset, void *HostDst,
-                                size_t DstByteOffset,
-                                size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(HostDst) + DstByteOffset,
-                static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
-                ByteCount);
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc,
-                                size_t SrcByteOffset,
-                                const void *DeviceHandleDst,
-                                size_t DstByteOffset,
-                                size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
-                    DstByteOffset,
-                static_cast<const char *>(HostSrc) + SrcByteOffset, ByteCount);
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error
-  copyD2D(const void *StreamHandle, const void *DeviceHandleSrc,
-          size_t SrcByteOffset, const void *DeviceHandleDst,
-          size_t DstByteOffset, size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
-                    DstByteOffset,
-                static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
-                ByteCount);
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error synchronousCopyD2H(const void *DeviceHandleSrc,
-                                           size_t SrcByteOffset, void *HostDst,
-                                           size_t DstByteOffset,
-                                           size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(HostDst) + DstByteOffset,
-                static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
-                ByteCount);
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error synchronousCopyH2D(const void *HostSrc,
-                                           size_t SrcByteOffset,
-                                           const void *DeviceHandleDst,
-                                           size_t DstByteOffset,
-                                           size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
-                    DstByteOffset,
-                static_cast<const char *>(HostSrc) + SrcByteOffset, ByteCount);
-    return streamexecutor::Error::success();
-  }
-
-  streamexecutor::Error synchronousCopyD2D(const void *DeviceHandleSrc,
-                                           size_t SrcByteOffset,
-                                           const void *DeviceHandleDst,
-                                           size_t DstByteOffset,
-                                           size_t ByteCount) override {
-    std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
-                    DstByteOffset,
-                static_cast<const char *>(DeviceHandleSrc) + SrcByteOffset,
-                ByteCount);
-    return streamexecutor::Error::success();
-  }
-
-  /// Gets the value at the given index from a GlobalDeviceMemory<T> instance
-  /// created by this class.
-  template <typename T>
-  static T getDeviceValue(const streamexecutor::GlobalDeviceMemory<T> &Memory,
-                          size_t Index) {
-    return static_cast<const T *>(Memory.getHandle())[Index];
-  }
-};
-
-} // namespace test
-} // namespace streamexecutor
-
-#endif // STREAMEXECUTOR_UNITTESTS_CORETESTS_SIMPLEHOSTPLATFORMDEVICE_H
Index: streamexecutor/unittests/CoreTests/DeviceTest.cpp
===================================================================
--- streamexecutor/unittests/CoreTests/DeviceTest.cpp
+++ streamexecutor/unittests/CoreTests/DeviceTest.cpp
@@ -17,7 +17,7 @@
 
 #include "streamexecutor/Device.h"
 #include "streamexecutor/PlatformDevice.h"
-#include "streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h"
+#include "streamexecutor/platforms/host/HostPlatformDevice.h"
 
 #include "gtest/gtest.h"
 
@@ -25,8 +25,7 @@
 
 namespace se = ::streamexecutor;
 
-const auto &getDeviceValue =
-    se::test::SimpleHostPlatformDevice::getDeviceValue<int>;
+const auto &getDeviceValue = se::host::HostPlatformDevice::getDeviceValue<int>;
 
 /// Test fixture to hold objects used by tests.
 class DeviceTest : public ::testing::Test {
@@ -45,7 +44,7 @@
     se::dieIfError(Device.synchronousCopyH2D(HostB7, DeviceB7));
   }
 
-  se::test::SimpleHostPlatformDevice PDevice;
+  se::host::HostPlatformDevice PDevice;
   se::Device Device;
 
   // Device memory is backed by host arrays.
@@ -74,9 +73,7 @@
 using llvm::ArrayRef;
 using llvm::MutableArrayRef;
 
-TEST_F(DeviceTest, GetName) {
-  EXPECT_EQ(Device.getName(), "SimpleHostPlatformDevice");
-}
+TEST_F(DeviceTest, GetName) { EXPECT_EQ(Device.getName(), "host"); }
 
 TEST_F(DeviceTest, AllocateAndFreeDeviceMemory) {
   se::Expected<se::GlobalDeviceMemory<int>> MaybeMemory =
Index: streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp
===================================================================
--- streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp
+++ streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp
@@ -16,7 +16,7 @@
 #include "streamexecutor/DeviceMemory.h"
 #include "streamexecutor/PackedKernelArgumentArray.h"
 #include "streamexecutor/PlatformDevice.h"
-#include "streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h"
+#include "streamexecutor/platforms/host/HostPlatformDevice.h"
 
 #include "llvm/ADT/Twine.h"
 
@@ -41,7 +41,7 @@
         TypedShared(
             se::SharedDeviceMemory<int>::makeFromElementCount(ElementCount)) {}
 
-  se::test::SimpleHostPlatformDevice PDevice;
+  se::host::HostPlatformDevice PDevice;
   se::Device Device;
   int Value;
   void *Handle;
Index: streamexecutor/unittests/CoreTests/StreamTest.cpp
===================================================================
--- streamexecutor/unittests/CoreTests/StreamTest.cpp
+++ streamexecutor/unittests/CoreTests/StreamTest.cpp
@@ -19,7 +19,7 @@
 #include "streamexecutor/KernelSpec.h"
 #include "streamexecutor/PlatformDevice.h"
 #include "streamexecutor/Stream.h"
-#include "streamexecutor/unittests/CoreTests/SimpleHostPlatformDevice.h"
+#include "streamexecutor/platforms/host/HostPlatformDevice.h"
 
 #include "gtest/gtest.h"
 
@@ -27,8 +27,7 @@
 
 namespace se = ::streamexecutor;
 
-const auto &getDeviceValue =
-    se::test::SimpleHostPlatformDevice::getDeviceValue<int>;
+const auto &getDeviceValue = se::host::HostPlatformDevice::getDeviceValue<int>;
 
 /// Test fixture to hold objects used by tests.
 class StreamTest : public ::testing::Test {
@@ -56,7 +55,7 @@
 protected:
   int DummyPlatformStream; // Mimicking a platform where the platform stream
                            // handle is just a stream number.
-  se::test::SimpleHostPlatformDevice PDevice;
+  se::host::HostPlatformDevice PDevice;
   se::Device Device;
   se::Stream Stream;
 