Index: streamexecutor/examples/CUDASaxpy.cpp =================================================================== --- streamexecutor/examples/CUDASaxpy.cpp +++ streamexecutor/examples/CUDASaxpy.cpp @@ -115,6 +115,11 @@ cg::SaxpyKernel Kernel = getOrDie(Device->createKernel(cg::SaxpyLoaderSpec)); + se::RegisteredHostMemory RegisteredX = + getOrDie(Device->registerHostMemory(HostX)); + se::RegisteredHostMemory RegisteredY = + getOrDie(Device->registerHostMemory(HostY)); + // Allocate memory on the device. se::GlobalDeviceMemory X = getOrDie(Device->allocateDeviceMemory(ArraySize)); @@ -123,10 +128,10 @@ // Run operations on a stream. se::Stream Stream = getOrDie(Device->createStream()); - Stream.thenCopyH2D(HostX, X) - .thenCopyH2D(HostY, Y) + Stream.thenCopyH2D(RegisteredX, X) + .thenCopyH2D(RegisteredY, Y) .thenLaunch(ArraySize, 1, Kernel, A, X, Y) - .thenCopyD2H(X, HostX); + .thenCopyD2H(X, RegisteredX); // Wait for the stream to complete. se::dieIfError(Stream.blockHostUntilDone()); Index: streamexecutor/include/streamexecutor/Device.h =================================================================== --- streamexecutor/include/streamexecutor/Device.h +++ streamexecutor/include/streamexecutor/Device.h @@ -17,6 +17,7 @@ #include +#include "streamexecutor/HostMemory.h" #include "streamexecutor/KernelSpec.h" #include "streamexecutor/PlatformDevice.h" #include "streamexecutor/Utils/Error.h" @@ -58,36 +59,19 @@ return GlobalDeviceMemory(this, *MaybeMemory, ElementCount); } - /// Allocates an array of ElementCount entries of type T in host memory. - /// - /// Host memory allocated by this function can be used for asynchronous memory - /// copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D. - template Expected allocateHostMemory(size_t ElementCount) { - Expected MaybeMemory = - PDevice->allocateHostMemory(ElementCount * sizeof(T)); - if (!MaybeMemory) - return MaybeMemory.takeError(); - return static_cast(*MaybeMemory); - } - - /// Frees memory previously allocated with allocateHostMemory. - template Error freeHostMemory(T *Memory) { - return PDevice->freeHostMemory(Memory); - } - /// Registers a previously allocated host array of type T for asynchronous /// memory operations. /// /// Host memory registered by this function can be used for asynchronous /// memory copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D. template - Error registerHostMemory(T *Memory, size_t ElementCount) { - return PDevice->registerHostMemory(Memory, ElementCount * sizeof(T)); - } - - /// Unregisters host memory previously registered by registerHostMemory. - template Error unregisterHostMemory(T *Memory) { - return PDevice->unregisterHostMemory(Memory); + Expected> + registerHostMemory(llvm::MutableArrayRef Memory) { + if (Error E = PDevice->registerHostMemory(Memory.data(), + Memory.size() * sizeof(T))) { + return std::move(E); + } + return RegisteredHostMemory(this, Memory.data(), Memory.size()); } /// \anchor DeviceHostSyncCopyGroup @@ -98,9 +82,8 @@ /// device calls. /// /// There are no restrictions on the host memory that is used as a source or - /// destination in these copy methods, so there is no need to allocate that - /// host memory using allocateHostMemory or register it with - /// registerHostMemory. + /// destination in these copy methods, so there is no need to register that + /// host memory with registerHostMemory. /// /// Each of these methods has a single template parameter, T, that specifies /// the type of data being copied. The ElementCount arguments specify the @@ -303,6 +286,12 @@ return PDevice->freeDeviceMemory(Memory.getHandle()); } + // Only destroyRegisteredHostMemoryInternals may unregister host memory. + friend void internal::destroyRegisteredHostMemoryInternals(Device *, void *); + Error unregisterHostMemory(const void *Pointer) { + return PDevice->unregisterHostMemory(Pointer); + } + PlatformDevice *PDevice; }; Index: streamexecutor/include/streamexecutor/DeviceMemory.h =================================================================== --- streamexecutor/include/streamexecutor/DeviceMemory.h +++ streamexecutor/include/streamexecutor/DeviceMemory.h @@ -46,6 +46,8 @@ /// memory, and an element count for the size of the slice. template class GlobalDeviceMemorySlice { public: + using ElementTy = ElemT; + /// Intentionally implicit so GlobalDeviceMemory can be passed to functions /// expecting GlobalDeviceMemorySlice arguments. GlobalDeviceMemorySlice(const GlobalDeviceMemory &Memory) @@ -171,6 +173,8 @@ template class GlobalDeviceMemory : public GlobalDeviceMemoryBase { public: + using ElementTy = ElemT; + GlobalDeviceMemory(GlobalDeviceMemory &&Other) = default; GlobalDeviceMemory &operator=(GlobalDeviceMemory &&Other) = default; Index: streamexecutor/include/streamexecutor/HostMemory.h =================================================================== --- /dev/null +++ streamexecutor/include/streamexecutor/HostMemory.h @@ -0,0 +1,146 @@ +//===-- HostMemory.h - Types for registered host memory ---------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// +/// This file defines types that represent registered host memory buffers. Host +/// memory must be registered to participate in asynchronous copies to or from +/// device memory. +/// +//===----------------------------------------------------------------------===// + +#ifndef STREAMEXECUTOR_HOSTMEMORY_H +#define STREAMEXECUTOR_HOSTMEMORY_H + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" + +namespace streamexecutor { + +class Device; +template class RegisteredHostMemory; + +/// A slice of registered host memory. +/// +/// The memory is registered in the sense of +/// streamexecutor::Device::registerHostMemory. +/// +/// Holds a reference to an underlying registered host memory buffer. Must not +/// be used after the underlying buffer is freed or unregistered. +template class RegisteredHostMemorySlice { +public: + using ElementTy = ElemT; + + RegisteredHostMemorySlice(const RegisteredHostMemory &Registered) + : MutableArrayRef(const_cast(Registered.getPointer()), + Registered.getElementCount()) {} + + ElemT *getPointer() { return MutableArrayRef.data(); } + const ElemT *getPointer() const { return MutableArrayRef.data(); } + size_t getElementCount() const { return MutableArrayRef.size(); } + + /// Chops off the first N elements of the slice. + RegisteredHostMemorySlice slice(size_t N) const { + return RegisteredHostMemorySlice(MutableArrayRef.slice(N)); + } + + /// Chops off the first N elements of the slice and keeps the next M elements. + RegisteredHostMemorySlice slice(size_t N, size_t M) const { + return RegisteredHostMemorySlice(MutableArrayRef.slice(N, M)); + } + + /// Chops off the last N elements of the slice. + RegisteredHostMemorySlice drop_back(size_t N) const { + return RegisteredHostMemorySlice(MutableArrayRef.drop_back(N)); + } + +private: + RegisteredHostMemorySlice(llvm::MutableArrayRef MutableArrayRef) + : MutableArrayRef(MutableArrayRef) {} + + llvm::MutableArrayRef MutableArrayRef; +}; + +namespace internal { + +/// Helper function to unregister host memory. +/// +/// This is a thin wrapper around streamexecutor::Device::unregisterHostMemory. +/// It is defined so this operation can be performed from the destructor of the +/// template class RegisteredHostMemory without including Device.h in this +/// header and creating a header inclusion cycle. +void destroyRegisteredHostMemoryInternals(Device *TheDevice, void *Pointer); + +} // namespace internal + +/// Registered host memory that knows how to unregister itself upon destruction. +/// +/// The memory is registered in the sense of +/// streamexecutor::Device::registerHostMemory. +/// +/// ElemT is the type of element stored in the host buffer. +template class RegisteredHostMemory { +public: + using ElementTy = ElemT; + + RegisteredHostMemory(Device *TheDevice, ElemT *Pointer, size_t ElementCount) + : TheDevice(TheDevice), Pointer(Pointer), ElementCount(ElementCount) { + assert(TheDevice != nullptr && "cannot construct a " + "RegisteredHostMemoryBase with a null " + "platform device"); + } + + RegisteredHostMemory(const RegisteredHostMemory &) = delete; + RegisteredHostMemory &operator=(const RegisteredHostMemory &) = delete; + + RegisteredHostMemory(RegisteredHostMemory &&Other) + : TheDevice(Other.TheDevice), Pointer(Other.Pointer), + ElementCount(Other.ElementCount) { + Other.TheDevice = nullptr; + Other.Pointer = nullptr; + } + + RegisteredHostMemory &operator=(RegisteredHostMemory &&Other) { + TheDevice = Other.TheDevice; + Pointer = Other.Pointer; + ElementCount = Other.ElementCount; + Other.TheDevice = nullptr; + Other.Pointer = nullptr; + } + + ~RegisteredHostMemory() { + internal::destroyRegisteredHostMemoryInternals(TheDevice, Pointer); + } + + ElemT *getPointer() { return static_cast(Pointer); } + const ElemT *getPointer() const { return static_cast(Pointer); } + size_t getElementCount() const { return ElementCount; } + + /// Creates an immutable slice for the entire contents of this memory. + RegisteredHostMemorySlice asSlice() const { + return RegisteredHostMemorySlice(*this); + } + + /// Creates a mutable slice for the entire contents of this memory. + RegisteredHostMemorySlice asSlice() { + return RegisteredHostMemorySlice(*this); + } + +private: + Device *TheDevice; + void *Pointer; + size_t ElementCount; +}; + +} // namespace streamexecutor + +#endif // STREAMEXECUTOR_HOSTMEMORY_H Index: streamexecutor/include/streamexecutor/PlatformDevice.h =================================================================== --- streamexecutor/include/streamexecutor/PlatformDevice.h +++ streamexecutor/include/streamexecutor/PlatformDevice.h @@ -68,8 +68,7 @@ /// Copies data from the device to the host. /// - /// HostDst should have been allocated by allocateHostMemory or registered - /// with registerHostMemory. + /// HostDst should have been registered with registerHostMemory. virtual Error copyD2H(const void *PlatformStreamHandle, const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { @@ -78,8 +77,7 @@ /// Copies data from the host to the device. /// - /// HostSrc should have been allocated by allocateHostMemory or registered - /// with registerHostMemory. + /// HostSrc should have been registered with registerHostMemory. virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc, size_t SrcByteOffset, const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { @@ -113,21 +111,6 @@ getName()); } - /// Allocates untyped host memory of a given size in bytes. - /// - /// Host memory allocated via this method is suitable for use with copyH2D and - /// copyD2H. - virtual Expected allocateHostMemory(size_t ByteCount) { - return make_error("allocateHostMemory not implemented for platform " + - getName()); - } - - /// Frees host memory allocated by allocateHostMemory. - virtual Error freeHostMemory(void *Memory) { - return make_error("freeHostMemory not implemented for platform " + - getName()); - } - /// Registers previously allocated host memory so it can be used with copyH2D /// and copyD2H. virtual Error registerHostMemory(void *Memory, size_t ByteCount) { @@ -136,7 +119,7 @@ } /// Unregisters host memory previously registered with registerHostMemory. - virtual Error unregisterHostMemory(void *Memory) { + virtual Error unregisterHostMemory(const void *Memory) { return make_error("unregisterHostMemory not implemented for platform " + getName()); } @@ -144,8 +127,8 @@ /// Copies the given number of bytes from device memory to host memory. /// /// Blocks the calling host thread until the copy is completed. Can operate on - /// any host memory, not just registered host memory or host memory allocated - /// by allocateHostMemory. Does not block any ongoing device calls. + /// any host memory, not just registered host memory. Does not block any + /// ongoing device calls. virtual Error synchronousCopyD2H(const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { Index: streamexecutor/include/streamexecutor/Stream.h =================================================================== --- streamexecutor/include/streamexecutor/Stream.h +++ streamexecutor/include/streamexecutor/Stream.h @@ -33,8 +33,10 @@ #include #include #include +#include #include "streamexecutor/DeviceMemory.h" +#include "streamexecutor/HostMemory.h" #include "streamexecutor/Kernel.h" #include "streamexecutor/LaunchDimensions.h" #include "streamexecutor/PackedKernelArgumentArray.h" @@ -118,98 +120,103 @@ /// These methods enqueue a device memory copy operation on the stream and /// return without waiting for the operation to complete. /// - /// Any host memory used as a source or destination for one of these - /// operations must be allocated with Device::allocateHostMemory or registered - /// with Device::registerHostMemory. Otherwise, the enqueuing operation may - /// block until the copy operation is fully complete. - /// /// The arguments and bounds checking for these methods match the API of the /// \ref DeviceHostSyncCopyGroup /// "host-synchronous device memory copying functions" of Device. + /// + /// The template types SrcTy and DstTy must match the following constraints: + /// * Must define typename ElementTy (the type of element stored in the + /// memory); + /// * ElementTy for the source argument must be the same as ElementTy for + /// the destination argument; + /// * Must be convertible to the correct slice type: + /// * GlobalDeviceMemorySlice for device memory arguments, + /// * RegisteredHostMemorySlice for host memory arguments. ///@{ + // D2H + template Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, - llvm::MutableArrayRef Dst, size_t ElementCount) { + RegisteredHostMemorySlice Dst, size_t ElementCount) { if (ElementCount > Src.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", from a device array of element count " + llvm::Twine(Src.getElementCount())); - else if (ElementCount > Dst.size()) + else if (ElementCount > Dst.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", to a host array of element count " + llvm::Twine(Dst.size())); + ", to a host array of element count " + + llvm::Twine(Dst.getElementCount())); else setError(PDevice->copyD2H(PlatformStreamHandle, Src.getBaseMemory().getHandle(), - Src.getElementOffset() * sizeof(T), Dst.data(), - 0, ElementCount * sizeof(T))); + Src.getElementOffset() * sizeof(T), + Dst.getPointer(), 0, ElementCount * sizeof(T))); return *this; } template Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, - llvm::MutableArrayRef Dst) { - if (Src.getElementCount() != Dst.size()) + RegisteredHostMemorySlice Dst) { + if (Src.getElementCount() != Dst.getElementCount()) setError("array size mismatch for D2H, device source has element count " + llvm::Twine(Src.getElementCount()) + " but host destination has element count " + - llvm::Twine(Dst.size())); + llvm::Twine(Dst.getElementCount())); else thenCopyD2H(Src, Dst, Src.getElementCount()); return *this; } - template - Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, T *Dst, - size_t ElementCount) { - thenCopyD2H(Src, llvm::MutableArrayRef(Dst, ElementCount), ElementCount); - return *this; + template + Stream &thenCopyD2H(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) { + using SrcElemTy = typename std::remove_reference::type::ElementTy; + using DstElemTy = typename std::remove_reference::type::ElementTy; + static_assert(std::is_same::value, + "src/dst element type mismatch for thenCopyD2H"); + GlobalDeviceMemorySlice SrcSlice(Src); + RegisteredHostMemorySlice DstSlice(Dst); + return thenCopyD2H(SrcSlice, DstSlice, ElementCount); } - template - Stream &thenCopyD2H(const GlobalDeviceMemory &Src, - llvm::MutableArrayRef Dst, size_t ElementCount) { - thenCopyD2H(Src.asSlice(), Dst, ElementCount); - return *this; - } - - template - Stream &thenCopyD2H(const GlobalDeviceMemory &Src, - llvm::MutableArrayRef Dst) { - thenCopyD2H(Src.asSlice(), Dst); - return *this; + template + Stream &thenCopyD2H(SrcTy &&Src, DstTy &&Dst) { + using SrcElemTy = typename std::remove_reference::type::ElementTy; + using DstElemTy = typename std::remove_reference::type::ElementTy; + static_assert(std::is_same::value, + "src/dst element type mismatch for thenCopyD2H"); + GlobalDeviceMemorySlice SrcSlice(Src); + RegisteredHostMemorySlice DstSlice(Dst); + return thenCopyD2H(SrcSlice, DstSlice); } - template - Stream &thenCopyD2H(const GlobalDeviceMemory &Src, T *Dst, - size_t ElementCount) { - thenCopyD2H(Src.asSlice(), Dst, ElementCount); - return *this; - } + // H2D template - Stream &thenCopyH2D(llvm::ArrayRef Src, GlobalDeviceMemorySlice Dst, - size_t ElementCount) { - if (ElementCount > Src.size()) + Stream &thenCopyH2D(RegisteredHostMemorySlice Src, + GlobalDeviceMemorySlice Dst, size_t ElementCount) { + if (ElementCount > Src.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", from a host array of element count " + - llvm::Twine(Src.size())); + llvm::Twine(Src.getElementCount())); else if (ElementCount > Dst.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", to a device array of element count " + llvm::Twine(Dst.getElementCount())); else - setError(PDevice->copyH2D( - PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(), - Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); + setError(PDevice->copyH2D(PlatformStreamHandle, Src.getPointer(), 0, + Dst.getBaseMemory().getHandle(), + Dst.getElementOffset() * sizeof(T), + ElementCount * sizeof(T))); return *this; } template - Stream &thenCopyH2D(llvm::ArrayRef Src, GlobalDeviceMemorySlice Dst) { - if (Src.size() != Dst.getElementCount()) + Stream &thenCopyH2D(RegisteredHostMemorySlice Src, + GlobalDeviceMemorySlice Dst) { + if (Src.getElementCount() != Dst.getElementCount()) setError("array size mismatch for H2D, host source has element count " + - llvm::Twine(Src.size()) + + llvm::Twine(Src.getElementCount()) + " but device destination has element count " + llvm::Twine(Dst.getElementCount())); else @@ -217,31 +224,29 @@ return *this; } - template - Stream &thenCopyH2D(T *Src, GlobalDeviceMemorySlice Dst, - size_t ElementCount) { - thenCopyH2D(llvm::ArrayRef(Src, ElementCount), Dst, ElementCount); - return *this; - } - - template - Stream &thenCopyH2D(llvm::ArrayRef Src, GlobalDeviceMemory &Dst, - size_t ElementCount) { - thenCopyH2D(Src, Dst.asSlice(), ElementCount); - return *this; + template + Stream &thenCopyH2D(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) { + using SrcElemTy = typename std::remove_reference::type::ElementTy; + using DstElemTy = typename std::remove_reference::type::ElementTy; + static_assert(std::is_same::value, + "src/dst element type mismatch for thenCopyH2D"); + RegisteredHostMemorySlice SrcSlice(Src); + GlobalDeviceMemorySlice DstSlice(Dst); + return thenCopyH2D(SrcSlice, DstSlice, ElementCount); } - template - Stream &thenCopyH2D(llvm::ArrayRef Src, GlobalDeviceMemory &Dst) { - thenCopyH2D(Src, Dst.asSlice()); - return *this; + template + Stream &thenCopyH2D(SrcTy &&Src, DstTy &&Dst) { + using SrcElemTy = typename std::remove_reference::type::ElementTy; + using DstElemTy = typename std::remove_reference::type::ElementTy; + static_assert(std::is_same::value, + "src/dst element type mismatch for thenCopyH2D"); + RegisteredHostMemorySlice SrcSlice(Src); + GlobalDeviceMemorySlice DstSlice(Dst); + return thenCopyH2D(SrcSlice, DstSlice); } - template - Stream &thenCopyH2D(T *Src, GlobalDeviceMemory &Dst, size_t ElementCount) { - thenCopyH2D(Src, Dst.asSlice(), ElementCount); - return *this; - } + // D2D template Stream &thenCopyD2D(GlobalDeviceMemorySlice Src, @@ -275,46 +280,26 @@ return *this; } - template - Stream &thenCopyD2D(const GlobalDeviceMemory &Src, - GlobalDeviceMemorySlice Dst, size_t ElementCount) { - thenCopyD2D(Src.asSlice(), Dst, ElementCount); - return *this; - } - - template - Stream &thenCopyD2D(const GlobalDeviceMemory &Src, - GlobalDeviceMemorySlice Dst) { - thenCopyD2D(Src.asSlice(), Dst); - return *this; - } - - template - Stream &thenCopyD2D(GlobalDeviceMemorySlice Src, - GlobalDeviceMemory &Dst, size_t ElementCount) { - thenCopyD2D(Src, Dst.asSlice(), ElementCount); - return *this; + template + Stream &thenCopyD2D(SrcTy &&Src, DstTy &&Dst, size_t ElementCount) { + using SrcElemTy = typename std::remove_reference::type::ElementTy; + using DstElemTy = typename std::remove_reference::type::ElementTy; + static_assert(std::is_same::value, + "src/dst element type mismatch for thenCopyD2D"); + GlobalDeviceMemorySlice SrcSlice(Src); + GlobalDeviceMemorySlice DstSlice(Dst); + return thenCopyD2D(SrcSlice, DstSlice, ElementCount); } - template - Stream &thenCopyD2D(GlobalDeviceMemorySlice Src, - GlobalDeviceMemory &Dst) { - thenCopyD2D(Src, Dst.asSlice()); - return *this; - } - - template - Stream &thenCopyD2D(const GlobalDeviceMemory &Src, - GlobalDeviceMemory &Dst, size_t ElementCount) { - thenCopyD2D(Src.asSlice(), Dst.asSlice(), ElementCount); - return *this; - } - - template - Stream &thenCopyD2D(const GlobalDeviceMemory &Src, - GlobalDeviceMemory &Dst) { - thenCopyD2D(Src.asSlice(), Dst.asSlice()); - return *this; + template + Stream &thenCopyD2D(SrcTy &&Src, DstTy &&Dst) { + using SrcElemTy = typename std::remove_reference::type::ElementTy; + using DstElemTy = typename std::remove_reference::type::ElementTy; + static_assert(std::is_same::value, + "src/dst element type mismatch for thenCopyD2D"); + GlobalDeviceMemorySlice SrcSlice(Src); + GlobalDeviceMemorySlice DstSlice(Dst); + return thenCopyD2D(SrcSlice, DstSlice); } ///@} End device memory copying functions Index: streamexecutor/lib/CMakeLists.txt =================================================================== --- streamexecutor/lib/CMakeLists.txt +++ streamexecutor/lib/CMakeLists.txt @@ -8,6 +8,7 @@ $ Device.cpp DeviceMemory.cpp + HostMemory.cpp Kernel.cpp KernelSpec.cpp PackedKernelArgumentArray.cpp Index: streamexecutor/lib/HostMemory.cpp =================================================================== --- /dev/null +++ streamexecutor/lib/HostMemory.cpp @@ -0,0 +1,29 @@ +//===-- HostMemory.cpp - HostMemory implementation ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Implementation of HostMemory internals. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/HostMemory.h" +#include "streamexecutor/Device.h" + +namespace streamexecutor { +namespace internal { + +void destroyRegisteredHostMemoryInternals(Device *TheDevice, void *Pointer) { + // TODO(jhen): How to handle errors here? + if (Pointer) { + consumeError(TheDevice->unregisterHostMemory(Pointer)); + } +} + +} // namespace internal +} // namespace streamexecutor Index: streamexecutor/lib/unittests/DeviceTest.cpp =================================================================== --- streamexecutor/lib/unittests/DeviceTest.cpp +++ streamexecutor/lib/unittests/DeviceTest.cpp @@ -84,16 +84,11 @@ EXPECT_TRUE(static_cast(MaybeMemory)); } -TEST_F(DeviceTest, AllocateAndFreeHostMemory) { - se::Expected MaybeMemory = Device.allocateHostMemory(10); - EXPECT_TRUE(static_cast(MaybeMemory)); - EXPECT_NO_ERROR(Device.freeHostMemory(*MaybeMemory)); -} - TEST_F(DeviceTest, RegisterAndUnregisterHostMemory) { std::vector Data(10); - EXPECT_NO_ERROR(Device.registerHostMemory(Data.data(), 10)); - EXPECT_NO_ERROR(Device.unregisterHostMemory(Data.data())); + se::Expected> MaybeMemory = + Device.registerHostMemory(Data); + EXPECT_TRUE(static_cast(MaybeMemory)); } // D2H tests Index: streamexecutor/lib/unittests/SimpleHostPlatformDevice.h =================================================================== --- streamexecutor/lib/unittests/SimpleHostPlatformDevice.h +++ streamexecutor/lib/unittests/SimpleHostPlatformDevice.h @@ -48,22 +48,12 @@ return streamexecutor::Error::success(); } - streamexecutor::Expected - allocateHostMemory(size_t ByteCount) override { - return std::malloc(ByteCount); - } - - streamexecutor::Error freeHostMemory(void *Memory) override { - std::free(const_cast(Memory)); - return streamexecutor::Error::success(); - } - streamexecutor::Error registerHostMemory(void *Memory, size_t ByteCount) override { return streamexecutor::Error::success(); } - streamexecutor::Error unregisterHostMemory(void *Memory) override { + streamexecutor::Error unregisterHostMemory(const void *Memory) override { return streamexecutor::Error::success(); } Index: streamexecutor/lib/unittests/StreamTest.cpp =================================================================== --- streamexecutor/lib/unittests/StreamTest.cpp +++ streamexecutor/lib/unittests/StreamTest.cpp @@ -39,6 +39,10 @@ HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, + RegisteredHost5(getOrDie( + Device.registerHostMemory(llvm::MutableArrayRef(Host5)))), + RegisteredHost7(getOrDie( + Device.registerHostMemory(llvm::MutableArrayRef(Host7)))), DeviceA5(getOrDie(Device.allocateDeviceMemory(5))), DeviceB5(getOrDie(Device.allocateDeviceMemory(5))), DeviceA7(getOrDie(Device.allocateDeviceMemory(7))), @@ -66,6 +70,9 @@ int Host5[5]; int Host7[7]; + se::RegisteredHostMemory RegisteredHost5; + se::RegisteredHostMemory RegisteredHost7; + // Device memory. se::GlobalDeviceMemory DeviceA5; se::GlobalDeviceMemory DeviceB5; @@ -78,166 +85,119 @@ // D2H tests -TEST_F(StreamTest, CopyD2HToMutableArrayRefByCount) { - Stream.thenCopyD2H(DeviceA5, MutableArrayRef(Host5), 5); +TEST_F(StreamTest, CopyD2HToRegisteredRefByCount) { + Stream.thenCopyD2H(DeviceA5, RegisteredHost5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceB5, MutableArrayRef(Host5), 2); + Stream.thenCopyD2H(DeviceB5, RegisteredHost5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(HostB5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA7, MutableArrayRef(Host5), 7); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyD2HToMutableArrayRef) { - Stream.thenCopyD2H(DeviceA5, MutableArrayRef(Host5)); - EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); - } - - Stream.thenCopyD2H(DeviceA5, MutableArrayRef(Host7)); + Stream.thenCopyD2H(DeviceA7, RegisteredHost5, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HToPointer) { - Stream.thenCopyD2H(DeviceA5, Host5, 5); +TEST_F(StreamTest, CopyD2HToRegistered) { + Stream.thenCopyD2H(DeviceA5, RegisteredHost5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA5, Host7, 7); + Stream.thenCopyD2H(DeviceA5, RegisteredHost7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HSliceToMutableArrayRefByCount) { +TEST_F(StreamTest, CopyD2HSliceToRegiseredSliceByCount) { Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1), - MutableArrayRef(Host5 + 1, 4), 4); + RegisteredHost5.asSlice().slice(1, 4), 4); EXPECT_TRUE(Stream.isOK()); for (int I = 1; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceB5.asSlice().drop_back(1), - MutableArrayRef(Host5), 2); + Stream.thenCopyD2H(DeviceB5.asSlice().drop_back(1), RegisteredHost5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(HostB5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef(Host7), 7); + Stream.thenCopyD2H(DeviceA5.asSlice(), RegisteredHost7, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HSliceToMutableArrayRef) { - Stream.thenCopyD2H(DeviceA7.asSlice().slice(1, 5), - MutableArrayRef(Host5)); +TEST_F(StreamTest, CopyD2HSliceToRegistered) { + Stream.thenCopyD2H(DeviceA7.asSlice().slice(1, 5), RegisteredHost5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA7[I + 1], Host5[I]); } - Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef(Host7)); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyD2HSliceToPointer) { - Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1), Host5 + 1, 4); - EXPECT_TRUE(Stream.isOK()); - for (int I = 1; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); - } - - Stream.thenCopyD2H(DeviceA5.asSlice(), Host7, 7); + Stream.thenCopyD2H(DeviceA5.asSlice(), RegisteredHost7); EXPECT_FALSE(Stream.isOK()); } // H2D tests -TEST_F(StreamTest, CopyH2DToArrayRefByCount) { - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5, 5); +TEST_F(StreamTest, CopyH2DFromRegisterdByCount) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host5), DeviceB5, 2); + Stream.thenCopyH2D(RegisteredHost5, DeviceB5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5, 7); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyH2DToArrayRef) { - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5); - EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); - } - - Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DToPointer) { - Stream.thenCopyH2D(Host5, DeviceA5, 5); +TEST_F(StreamTest, CopyH2DFromRegistered) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(Host7, DeviceA5, 7); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DSliceToArrayRefByCount) { - Stream.thenCopyH2D(ArrayRef(Host5 + 1, 4), +TEST_F(StreamTest, CopyH2DFromRegisteredSliceToSlice) { + Stream.thenCopyH2D(RegisteredHost5.asSlice().slice(1, 4), DeviceA5.asSlice().drop_front(1), 4); EXPECT_TRUE(Stream.isOK()); for (int I = 1; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host5), DeviceB5.asSlice().drop_back(1), 2); + Stream.thenCopyH2D(RegisteredHost5, DeviceB5.asSlice().drop_back(1), 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5.asSlice(), 7); + Stream.thenCopyH2D(RegisteredHost5, DeviceA5.asSlice(), 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DSliceToArrayRef) { - - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5.asSlice()); +TEST_F(StreamTest, CopyH2DRegisteredToSlice) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5.asSlice()); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyH2DSliceToPointer) { - Stream.thenCopyH2D(Host5, DeviceA5.asSlice(), 5); - EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); - } - - Stream.thenCopyH2D(Host7, DeviceA5.asSlice(), 7); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5.asSlice()); EXPECT_FALSE(Stream.isOK()); } @@ -289,7 +249,6 @@ } TEST_F(StreamTest, CopySliceD2D) { - Stream.thenCopyD2D(DeviceA7.asSlice().drop_back(2), DeviceB5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -318,7 +277,6 @@ } TEST_F(StreamTest, CopyD2DSlice) { - Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice().drop_back(2)); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -330,7 +288,6 @@ } TEST_F(StreamTest, CopySliceD2DSliceByCount) { - Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -348,7 +305,6 @@ } TEST_F(StreamTest, CopySliceD2DSlice) { - Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) {