Index: streamexecutor/examples/CUDASaxpy.cpp =================================================================== --- streamexecutor/examples/CUDASaxpy.cpp +++ streamexecutor/examples/CUDASaxpy.cpp @@ -115,6 +115,11 @@ cg::SaxpyKernel Kernel = getOrDie(Device->createKernel(cg::SaxpyLoaderSpec)); + se::RegisteredHostMemory RegisteredX = + getOrDie(Device->registerHostMemory(HostX)); + se::RegisteredHostMemory RegisteredY = + getOrDie(Device->registerHostMemory(HostY)); + // Allocate memory on the device. se::GlobalDeviceMemory X = getOrDie(Device->allocateDeviceMemory(ArraySize)); @@ -123,10 +128,10 @@ // Run operations on a stream. se::Stream Stream = getOrDie(Device->createStream()); - Stream.thenCopyH2D(HostX, X) - .thenCopyH2D(HostY, Y) + Stream.thenCopyH2D(RegisteredX, X) + .thenCopyH2D(RegisteredY, Y) .thenLaunch(ArraySize, 1, Kernel, A, X, Y) - .thenCopyD2H(X, HostX); + .thenCopyD2H(X, RegisteredX); // Wait for the stream to complete. se::dieIfError(Stream.blockHostUntilDone()); Index: streamexecutor/include/streamexecutor/Device.h =================================================================== --- streamexecutor/include/streamexecutor/Device.h +++ streamexecutor/include/streamexecutor/Device.h @@ -17,6 +17,7 @@ #include +#include "streamexecutor/HostMemory.h" #include "streamexecutor/KernelSpec.h" #include "streamexecutor/PlatformDevice.h" #include "streamexecutor/Utils/Error.h" @@ -58,36 +59,19 @@ return GlobalDeviceMemory(this, *MaybeMemory, ElementCount); } - /// Allocates an array of ElementCount entries of type T in host memory. - /// - /// Host memory allocated by this function can be used for asynchronous memory - /// copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D. 
- template Expected allocateHostMemory(size_t ElementCount) { - Expected MaybeMemory = - PDevice->allocateHostMemory(ElementCount * sizeof(T)); - if (!MaybeMemory) - return MaybeMemory.takeError(); - return static_cast(*MaybeMemory); - } - - /// Frees memory previously allocated with allocateHostMemory. - template Error freeHostMemory(T *Memory) { - return PDevice->freeHostMemory(Memory); - } - /// Registers a previously allocated host array of type T for asynchronous /// memory operations. /// /// Host memory registered by this function can be used for asynchronous /// memory copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D. template - Error registerHostMemory(T *Memory, size_t ElementCount) { - return PDevice->registerHostMemory(Memory, ElementCount * sizeof(T)); - } - - /// Unregisters host memory previously registered by registerHostMemory. - template Error unregisterHostMemory(T *Memory) { - return PDevice->unregisterHostMemory(Memory); + Expected> + registerHostMemory(llvm::MutableArrayRef Memory) { + if (Error E = PDevice->registerHostMemory(Memory.data(), + Memory.size() * sizeof(T))) { + return std::move(E); + } + return RegisteredHostMemory(this, Memory.data(), Memory.size()); } /// \anchor DeviceHostSyncCopyGroup @@ -98,9 +82,8 @@ /// device calls. /// /// There are no restrictions on the host memory that is used as a source or - /// destination in these copy methods, so there is no need to allocate that - /// host memory using allocateHostMemory or register it with - /// registerHostMemory. + /// destination in these copy methods, so there is no need to register that + /// host memory with registerHostMemory. /// /// Each of these methods has a single template parameter, T, that specifies /// the type of data being copied. The ElementCount arguments specify the @@ -303,6 +286,13 @@ return PDevice->freeDeviceMemory(Memory.getHandle()); } + // Only a RegisteredHostMemoryBase may unregister host memory. 
+ friend RegisteredHostMemoryBase; + + Error unregisterHostMemory(const RegisteredHostMemoryBase &Memory) { + return PDevice->unregisterHostMemory(Memory.getUntypedPointer()); + } + PlatformDevice *PDevice; }; Index: streamexecutor/include/streamexecutor/HostMemory.h =================================================================== --- /dev/null +++ streamexecutor/include/streamexecutor/HostMemory.h @@ -0,0 +1,204 @@ +//===-- HostMemory.h - Types for registered host memory ---------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// +/// This file defines types that represent registered host memory buffers. Host +/// memory must be registered to participate in asynchronous copies to or from +/// device memory. +/// +//===----------------------------------------------------------------------===// + +#ifndef STREAMEXECUTOR_HOSTMEMORY_H +#define STREAMEXECUTOR_HOSTMEMORY_H + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" + +namespace streamexecutor { + +class Device; +template class RegisteredHostMemory; + +/// An immutable slice of registered host memory. +/// +/// Holds a reference to an underlying registered host memory buffer. Must not +/// be used after the underlying buffer is freed or unregistered. +template class RegisteredHostMemorySlice { +public: + const ElemT *getPointer() const { return ArrayRef.data(); } + size_t getElementCount() const { return ArrayRef.size(); } + + /// Chops off the first N elements of the slice. + RegisteredHostMemorySlice slice(size_t N) const { + return RegisteredHostMemorySlice(ArrayRef.slice(N)); + } + + /// Chops off the first N elements of the slice and keeps the next M elements. 
+ RegisteredHostMemorySlice slice(size_t N, size_t M) const { + return RegisteredHostMemorySlice(ArrayRef.slice(N, M)); + } + + /// Chops off the last N elements of the slice. + RegisteredHostMemorySlice drop_back(size_t N) const { + return RegisteredHostMemorySlice(ArrayRef.drop_back(N)); + } + +private: + friend RegisteredHostMemory; + + RegisteredHostMemorySlice(llvm::ArrayRef ArrayRef) + : ArrayRef(ArrayRef) {} + + llvm::ArrayRef ArrayRef; +}; + +/// A mutable slice of registered host memory. +/// +/// Holds a reference to an underlying registered host memory buffer. Must not +/// be used after the underlying buffer is freed or unregistered. +template class MutableRegisteredHostMemorySlice { +public: + ElemT *getPointer() const { return MutableArrayRef.data(); } + size_t getElementCount() const { return MutableArrayRef.size(); } + + /// Chops off the first N elements of the slice. + MutableRegisteredHostMemorySlice slice(size_t N) const { + return MutableRegisteredHostMemorySlice(MutableArrayRef.slice(N)); + } + + /// Chops off the first N elements of the slice and keeps the next M elements. + MutableRegisteredHostMemorySlice slice(size_t N, size_t M) const { + return MutableRegisteredHostMemorySlice(MutableArrayRef.slice(N, M)); + } + + /// Chops off the last N elements of the slice. + MutableRegisteredHostMemorySlice drop_back(size_t N) const { + return MutableRegisteredHostMemorySlice(MutableArrayRef.drop_back(N)); + } + +private: + friend RegisteredHostMemory; + + MutableRegisteredHostMemorySlice(llvm::MutableArrayRef MutableArrayRef) + : MutableArrayRef(MutableArrayRef) {} + + llvm::MutableArrayRef MutableArrayRef; +}; + +/// Base class for registered host memory that knows how to unregister itself +/// upon destruction. +/// +/// This class does not keep track of the data types of the underlying memory. +/// See RegisteredHostMemory for the subclass that does keep track of types. 
+///
+/// Can be created by registering previously allocated host memory and will
+/// unregister that memory upon destruction.
+class RegisteredHostMemoryBase {
+public:
+  RegisteredHostMemoryBase(const RegisteredHostMemoryBase &) = delete;
+  RegisteredHostMemoryBase &
+  operator=(const RegisteredHostMemoryBase &) = delete;
+
+  /// Takes over Other's registration. Other is left empty (null device and
+  /// pointer), so it does not unregister anything when it is destroyed.
+  RegisteredHostMemoryBase(RegisteredHostMemoryBase &&Other) noexcept
+      : TheDevice(Other.TheDevice), Pointer(Other.Pointer),
+        ByteCount(Other.ByteCount) {
+    Other.TheDevice = nullptr;
+    Other.Pointer = nullptr;
+    Other.ByteCount = 0;
+  }
+
+  /// Exchanges registrations with Other.
+  ///
+  /// Any registration previously held by this object is handed to Other and
+  /// is unregistered when Other is destroyed, instead of being overwritten
+  /// and leaked. Swapping is used rather than unregistering directly because
+  /// Device is only forward-declared in this header. Self-assignment leaves
+  /// the object unchanged.
+  RegisteredHostMemoryBase &
+  operator=(RegisteredHostMemoryBase &&Other) noexcept {
+    Device *OldDevice = TheDevice;
+    void *OldPointer = Pointer;
+    size_t OldByteCount = ByteCount;
+    TheDevice = Other.TheDevice;
+    Pointer = Other.Pointer;
+    ByteCount = Other.ByteCount;
+    Other.TheDevice = OldDevice;
+    Other.Pointer = OldPointer;
+    Other.ByteCount = OldByteCount;
+    return *this;
+  }
+
+  /// Returns the untyped start address of the registered buffer.
+  const void *getUntypedPointer() const { return Pointer; }
+
+  /// Returns the size of the registered buffer in bytes.
+  size_t getByteCount() const { return ByteCount; }
+
+protected:
+  RegisteredHostMemoryBase(Device *TheDevice, void *Pointer, size_t ByteCount)
+      : TheDevice(TheDevice), Pointer(Pointer), ByteCount(ByteCount) {
+    assert(TheDevice != nullptr && "cannot construct a "
+                                   "RegisteredHostMemoryBase with a null "
+                                   "platform device");
+  }
+
+  /// Unregisters the memory if this object still holds a registration.
+  /// Defined out of line, where the full Device definition is available.
+  ~RegisteredHostMemoryBase();
+
+  Device *TheDevice;
+  void *Pointer;
+  size_t ByteCount;
+};
+
+/// A subclass of RegisteredHostMemoryBase that keeps track of the types of
+/// elements stored in the host buffer.
+///
+/// ElemT is the type of element stored in the host buffer.
+template +class RegisteredHostMemory : RegisteredHostMemoryBase { +public: + RegisteredHostMemory(Device *TheDevice, ElemT *Pointer, size_t ElementCount) + : RegisteredHostMemoryBase(TheDevice, Pointer, + ElementCount * sizeof(ElemT)) {} + + RegisteredHostMemory(RegisteredHostMemory &&) = default; + RegisteredHostMemory &operator=(RegisteredHostMemory &&) = default; + + ElemT *getPointer() { return static_cast(Pointer); } + size_t getElementCount() const { return ByteCount / sizeof(ElemT); } + + /// Creates an immutable slice for the entire contents of this memory. + RegisteredHostMemorySlice asSlice() const { + return RegisteredHostMemorySlice(llvm::ArrayRef( + static_cast(Pointer), getElementCount())); + } + + /// Creates a mutable slice for the entire contents of this memory. + MutableRegisteredHostMemorySlice asSlice() { + return MutableRegisteredHostMemorySlice(llvm::MutableArrayRef( + static_cast(Pointer), getElementCount())); + } +}; + +/// \name Type trait unifying RegisteredHostMemorySlice and +/// MutableRegisteredHostMemorySlice. +/// +/// In many cases, it doesn't matter if a slice is mutable or not, so this type +/// trait makes it easy to write templated functions that take either kind of +/// slice. +/// +/// is_registered_host_memory_slice is std::true_type if T is either +/// RegisteredHostMemorySlice or MutableRegisteredHostMemorySlice. +/// Otherwise, it is std::false_type. 
+/// +/// @{ + +template +struct is_registered_host_memory_slice : public std::false_type {}; + +template +struct is_registered_host_memory_slice, ElemT> + : public std::true_type {}; + +template +struct is_registered_host_memory_slice, + ElemT> : public std::true_type {}; + +/// @} + +} // namespace streamexecutor + +#endif // STREAMEXECUTOR_HOSTMEMORY_H Index: streamexecutor/include/streamexecutor/PlatformDevice.h =================================================================== --- streamexecutor/include/streamexecutor/PlatformDevice.h +++ streamexecutor/include/streamexecutor/PlatformDevice.h @@ -68,8 +68,7 @@ /// Copies data from the device to the host. /// - /// HostDst should have been allocated by allocateHostMemory or registered - /// with registerHostMemory. + /// HostDst should have been registered with registerHostMemory. virtual Error copyD2H(const void *PlatformStreamHandle, const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { @@ -78,8 +77,7 @@ /// Copies data from the host to the device. /// - /// HostSrc should have been allocated by allocateHostMemory or registered - /// with registerHostMemory. + /// HostSrc should have been registered with registerHostMemory. virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc, size_t SrcByteOffset, const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { @@ -113,21 +111,6 @@ getName()); } - /// Allocates untyped host memory of a given size in bytes. - /// - /// Host memory allocated via this method is suitable for use with copyH2D and - /// copyD2H. - virtual Expected allocateHostMemory(size_t ByteCount) { - return make_error("allocateHostMemory not implemented for platform " + - getName()); - } - - /// Frees host memory allocated by allocateHostMemory. 
- virtual Error freeHostMemory(void *Memory) { - return make_error("freeHostMemory not implemented for platform " + - getName()); - } - /// Registers previously allocated host memory so it can be used with copyH2D /// and copyD2H. virtual Error registerHostMemory(void *Memory, size_t ByteCount) { @@ -136,7 +119,7 @@ } /// Unregisters host memory previously registered with registerHostMemory. - virtual Error unregisterHostMemory(void *Memory) { + virtual Error unregisterHostMemory(const void *Memory) { return make_error("unregisterHostMemory not implemented for platform " + getName()); } @@ -144,8 +127,8 @@ /// Copies the given number of bytes from device memory to host memory. /// /// Blocks the calling host thread until the copy is completed. Can operate on - /// any host memory, not just registered host memory or host memory allocated - /// by allocateHostMemory. Does not block any ongoing device calls. + /// any host memory, not just registered host memory. Does not block any + /// ongoing device calls. virtual Error synchronousCopyD2H(const void *DeviceSrcHandle, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, size_t ByteCount) { Index: streamexecutor/include/streamexecutor/Stream.h =================================================================== --- streamexecutor/include/streamexecutor/Stream.h +++ streamexecutor/include/streamexecutor/Stream.h @@ -33,8 +33,10 @@ #include #include #include +#include #include "streamexecutor/DeviceMemory.h" +#include "streamexecutor/HostMemory.h" #include "streamexecutor/Kernel.h" #include "streamexecutor/LaunchDimensions.h" #include "streamexecutor/PackedKernelArgumentArray.h" @@ -118,11 +120,6 @@ /// These methods enqueue a device memory copy operation on the stream and /// return without waiting for the operation to complete. 
/// - /// Any host memory used as a source or destination for one of these - /// operations must be allocated with Device::allocateHostMemory or registered - /// with Device::registerHostMemory. Otherwise, the enqueuing operation may - /// block until the copy operation is fully complete. - /// /// The arguments and bounds checking for these methods match the API of the /// \ref DeviceHostSyncCopyGroup /// "host-synchronous device memory copying functions" of Device. @@ -130,86 +127,108 @@ template Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, - llvm::MutableArrayRef Dst, size_t ElementCount) { + MutableRegisteredHostMemorySlice Dst, + size_t ElementCount) { if (ElementCount > Src.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", from a device array of element count " + llvm::Twine(Src.getElementCount())); - else if (ElementCount > Dst.size()) + else if (ElementCount > Dst.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + - ", to a host array of element count " + llvm::Twine(Dst.size())); + ", to a host array of element count " + + llvm::Twine(Dst.getElementCount())); else setError(PDevice->copyD2H(PlatformStreamHandle, Src.getBaseMemory().getHandle(), - Src.getElementOffset() * sizeof(T), Dst.data(), - 0, ElementCount * sizeof(T))); + Src.getElementOffset() * sizeof(T), + Dst.getPointer(), 0, ElementCount * sizeof(T))); return *this; } template Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, - llvm::MutableArrayRef Dst) { - if (Src.getElementCount() != Dst.size()) + RegisteredHostMemory &Dst, size_t ElementCount) { + return thenCopyD2H(Src, Dst.asSlice(), ElementCount); + } + + template + Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, + MutableRegisteredHostMemorySlice Dst) { + if (Src.getElementCount() != Dst.getElementCount()) setError("array size mismatch for D2H, device source has element count " + llvm::Twine(Src.getElementCount()) + " but host destination has element count " 
+ - llvm::Twine(Dst.size())); + llvm::Twine(Dst.getElementCount())); else thenCopyD2H(Src, Dst, Src.getElementCount()); return *this; } template - Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, T *Dst, - size_t ElementCount) { - thenCopyD2H(Src, llvm::MutableArrayRef(Dst, ElementCount), ElementCount); - return *this; + Stream &thenCopyD2H(GlobalDeviceMemorySlice Src, + RegisteredHostMemory &Dst) { + return thenCopyD2H(Src, Dst.asSlice()); } template Stream &thenCopyD2H(const GlobalDeviceMemory &Src, - llvm::MutableArrayRef Dst, size_t ElementCount) { - thenCopyD2H(Src.asSlice(), Dst, ElementCount); - return *this; + MutableRegisteredHostMemorySlice Dst, + size_t ElementCount) { + return thenCopyD2H(Src.asSlice(), Dst, ElementCount); } template Stream &thenCopyD2H(const GlobalDeviceMemory &Src, - llvm::MutableArrayRef Dst) { - thenCopyD2H(Src.asSlice(), Dst); - return *this; + RegisteredHostMemory &Dst, size_t ElementCount) { + return thenCopyD2H(Src.asSlice(), Dst.asSlice(), ElementCount); } template - Stream &thenCopyD2H(const GlobalDeviceMemory &Src, T *Dst, - size_t ElementCount) { - thenCopyD2H(Src.asSlice(), Dst, ElementCount); - return *this; + Stream &thenCopyD2H(const GlobalDeviceMemory &Src, + MutableRegisteredHostMemorySlice Dst) { + return thenCopyD2H(Src.asSlice(), Dst); } template - Stream &thenCopyH2D(llvm::ArrayRef Src, GlobalDeviceMemorySlice Dst, - size_t ElementCount) { - if (ElementCount > Src.size()) + Stream &thenCopyD2H(const GlobalDeviceMemory &Src, + RegisteredHostMemory &Dst) { + return thenCopyD2H(Src.asSlice(), Dst.asSlice()); + } + + template + typename std::enable_if::value, + Stream>::type & + thenCopyH2D(SrcSliceT Src, GlobalDeviceMemorySlice Dst, + size_t ElementCount) { + if (ElementCount > Src.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", from a host array of element count " + - llvm::Twine(Src.size())); + llvm::Twine(Src.getElementCount())); else if (ElementCount > 
Dst.getElementCount()) setError("copying too many elements, " + llvm::Twine(ElementCount) + ", to a device array of element count " + llvm::Twine(Dst.getElementCount())); else - setError(PDevice->copyH2D( - PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(), - Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); + setError(PDevice->copyH2D(PlatformStreamHandle, Src.getPointer(), 0, + Dst.getBaseMemory().getHandle(), + Dst.getElementOffset() * sizeof(T), + ElementCount * sizeof(T))); return *this; } template - Stream &thenCopyH2D(llvm::ArrayRef Src, GlobalDeviceMemorySlice Dst) { - if (Src.size() != Dst.getElementCount()) + Stream &thenCopyH2D(const RegisteredHostMemory &Src, + GlobalDeviceMemorySlice Dst, size_t ElementCount) { + return thenCopyH2D(Src.asSlice(), Dst, ElementCount); + } + + template + typename std::enable_if::value, + Stream>::type & + thenCopyH2D(SrcSliceT Src, GlobalDeviceMemorySlice Dst) { + if (Src.getElementCount() != Dst.getElementCount()) setError("array size mismatch for H2D, host source has element count " + - llvm::Twine(Src.size()) + + llvm::Twine(Src.getElementCount()) + " but device destination has element count " + llvm::Twine(Dst.getElementCount())); else @@ -218,29 +237,35 @@ } template - Stream &thenCopyH2D(T *Src, GlobalDeviceMemorySlice Dst, - size_t ElementCount) { - thenCopyH2D(llvm::ArrayRef(Src, ElementCount), Dst, ElementCount); - return *this; + Stream &thenCopyH2D(const RegisteredHostMemory &Src, + GlobalDeviceMemorySlice Dst) { + return thenCopyH2D(Src.asSlice(), Dst); } - template - Stream &thenCopyH2D(llvm::ArrayRef Src, GlobalDeviceMemory &Dst, - size_t ElementCount) { - thenCopyH2D(Src, Dst.asSlice(), ElementCount); - return *this; + template + typename std::enable_if::value, + Stream>::type & + thenCopyH2D(SrcSliceT Src, GlobalDeviceMemory &Dst, size_t ElementCount) { + return thenCopyH2D(Src, Dst.asSlice(), ElementCount); } template - Stream &thenCopyH2D(llvm::ArrayRef Src, 
GlobalDeviceMemory &Dst) { - thenCopyH2D(Src, Dst.asSlice()); - return *this; + Stream &thenCopyH2D(const RegisteredHostMemory &Src, + GlobalDeviceMemory &Dst, size_t ElementCount) { + return thenCopyH2D(Src.asSlice(), Dst, ElementCount); + } + + template + typename std::enable_if::value, + Stream>::type & + thenCopyH2D(SrcSliceT Src, GlobalDeviceMemory &Dst) { + return thenCopyH2D(Src, Dst.asSlice()); } template - Stream &thenCopyH2D(T *Src, GlobalDeviceMemory &Dst, size_t ElementCount) { - thenCopyH2D(Src, Dst.asSlice(), ElementCount); - return *this; + Stream &thenCopyH2D(const RegisteredHostMemory &Src, + GlobalDeviceMemory &Dst) { + return thenCopyH2D(Src.asSlice(), Dst); } template Index: streamexecutor/lib/CMakeLists.txt =================================================================== --- streamexecutor/lib/CMakeLists.txt +++ streamexecutor/lib/CMakeLists.txt @@ -8,6 +8,7 @@ $ Device.cpp DeviceMemory.cpp + HostMemory.cpp Kernel.cpp KernelSpec.cpp PackedKernelArgumentArray.cpp Index: streamexecutor/lib/HostMemory.cpp =================================================================== --- /dev/null +++ streamexecutor/lib/HostMemory.cpp @@ -0,0 +1,27 @@ +//===-- HostMemory.cpp - HostMemory implementation ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Implementation of HostMemory class internals. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/HostMemory.h" +#include "streamexecutor/Device.h" + +namespace streamexecutor { + +RegisteredHostMemoryBase::~RegisteredHostMemoryBase() { + // TODO(jhen): How to handle errors here? 
+ if (Pointer) { + consumeError(TheDevice->unregisterHostMemory(*this)); + } +} + +} // namespace streamexecutor Index: streamexecutor/lib/unittests/DeviceTest.cpp =================================================================== --- streamexecutor/lib/unittests/DeviceTest.cpp +++ streamexecutor/lib/unittests/DeviceTest.cpp @@ -84,16 +84,11 @@ EXPECT_TRUE(static_cast(MaybeMemory)); } -TEST_F(DeviceTest, AllocateAndFreeHostMemory) { - se::Expected MaybeMemory = Device.allocateHostMemory(10); - EXPECT_TRUE(static_cast(MaybeMemory)); - EXPECT_NO_ERROR(Device.freeHostMemory(*MaybeMemory)); -} - TEST_F(DeviceTest, RegisterAndUnregisterHostMemory) { std::vector Data(10); - EXPECT_NO_ERROR(Device.registerHostMemory(Data.data(), 10)); - EXPECT_NO_ERROR(Device.unregisterHostMemory(Data.data())); + se::Expected> MaybeMemory = + Device.registerHostMemory(Data); + EXPECT_TRUE(static_cast(MaybeMemory)); } // D2H tests Index: streamexecutor/lib/unittests/SimpleHostPlatformDevice.h =================================================================== --- streamexecutor/lib/unittests/SimpleHostPlatformDevice.h +++ streamexecutor/lib/unittests/SimpleHostPlatformDevice.h @@ -48,22 +48,12 @@ return streamexecutor::Error::success(); } - streamexecutor::Expected - allocateHostMemory(size_t ByteCount) override { - return std::malloc(ByteCount); - } - - streamexecutor::Error freeHostMemory(void *Memory) override { - std::free(const_cast(Memory)); - return streamexecutor::Error::success(); - } - streamexecutor::Error registerHostMemory(void *Memory, size_t ByteCount) override { return streamexecutor::Error::success(); } - streamexecutor::Error unregisterHostMemory(void *Memory) override { + streamexecutor::Error unregisterHostMemory(const void *Memory) override { return streamexecutor::Error::success(); } Index: streamexecutor/lib/unittests/StreamTest.cpp =================================================================== --- streamexecutor/lib/unittests/StreamTest.cpp +++ 
streamexecutor/lib/unittests/StreamTest.cpp @@ -39,6 +39,10 @@ HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, + RegisteredHost5(getOrDie( + Device.registerHostMemory(llvm::MutableArrayRef(Host5)))), + RegisteredHost7(getOrDie( + Device.registerHostMemory(llvm::MutableArrayRef(Host7)))), DeviceA5(getOrDie(Device.allocateDeviceMemory(5))), DeviceB5(getOrDie(Device.allocateDeviceMemory(5))), DeviceA7(getOrDie(Device.allocateDeviceMemory(7))), @@ -66,6 +70,9 @@ int Host5[5]; int Host7[7]; + se::RegisteredHostMemory RegisteredHost5; + se::RegisteredHostMemory RegisteredHost7; + // Device memory. se::GlobalDeviceMemory DeviceA5; se::GlobalDeviceMemory DeviceB5; @@ -78,166 +85,119 @@ // D2H tests -TEST_F(StreamTest, CopyD2HToMutableArrayRefByCount) { - Stream.thenCopyD2H(DeviceA5, MutableArrayRef(Host5), 5); +TEST_F(StreamTest, CopyD2HToRegisteredRefByCount) { + Stream.thenCopyD2H(DeviceA5, RegisteredHost5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceB5, MutableArrayRef(Host5), 2); + Stream.thenCopyD2H(DeviceB5, RegisteredHost5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(HostB5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA7, MutableArrayRef(Host5), 7); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyD2HToMutableArrayRef) { - Stream.thenCopyD2H(DeviceA5, MutableArrayRef(Host5)); - EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); - } - - Stream.thenCopyD2H(DeviceA5, MutableArrayRef(Host7)); + Stream.thenCopyD2H(DeviceA7, RegisteredHost5, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HToPointer) { - Stream.thenCopyD2H(DeviceA5, Host5, 5); +TEST_F(StreamTest, CopyD2HToRegistered) { + Stream.thenCopyD2H(DeviceA5, RegisteredHost5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { 
EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA5, Host7, 7); + Stream.thenCopyD2H(DeviceA5, RegisteredHost7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HSliceToMutableArrayRefByCount) { +TEST_F(StreamTest, CopyD2HSliceToRegisteredSliceByCount) { Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1), - MutableArrayRef(Host5 + 1, 4), 4); + RegisteredHost5.asSlice().slice(1, 4), 4); EXPECT_TRUE(Stream.isOK()); for (int I = 1; I < 5; ++I) { EXPECT_EQ(HostA5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceB5.asSlice().drop_back(1), - MutableArrayRef(Host5), 2); + Stream.thenCopyD2H(DeviceB5.asSlice().drop_back(1), RegisteredHost5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(HostB5[I], Host5[I]); } - Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef(Host7), 7); + Stream.thenCopyD2H(DeviceA5.asSlice(), RegisteredHost7, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyD2HSliceToMutableArrayRef) { - Stream.thenCopyD2H(DeviceA7.asSlice().slice(1, 5), - MutableArrayRef(Host5)); +TEST_F(StreamTest, CopyD2HSliceToRegistered) { + Stream.thenCopyD2H(DeviceA7.asSlice().slice(1, 5), RegisteredHost5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(HostA7[I + 1], Host5[I]); } - Stream.thenCopyD2H(DeviceA5.asSlice(), MutableArrayRef(Host7)); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyD2HSliceToPointer) { - Stream.thenCopyD2H(DeviceA5.asSlice().drop_front(1), Host5 + 1, 4); - EXPECT_TRUE(Stream.isOK()); - for (int I = 1; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); - } - - Stream.thenCopyD2H(DeviceA5.asSlice(), Host7, 7); + Stream.thenCopyD2H(DeviceA5.asSlice(), RegisteredHost7); EXPECT_FALSE(Stream.isOK()); } // H2D tests -TEST_F(StreamTest, CopyH2DToArrayRefByCount) { - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5, 5); +TEST_F(StreamTest, CopyH2DFromRegisteredByCount) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) {
EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host5), DeviceB5, 2); + Stream.thenCopyH2D(RegisteredHost5, DeviceB5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5, 7); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyH2DToArrayRef) { - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5); - EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); - } - - Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5, 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DToPointer) { - Stream.thenCopyH2D(Host5, DeviceA5, 5); +TEST_F(StreamTest, CopyH2DFromRegistered) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(Host7, DeviceA5, 7); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DSliceToArrayRefByCount) { - Stream.thenCopyH2D(ArrayRef(Host5 + 1, 4), +TEST_F(StreamTest, CopyH2DFromRegisteredSliceToSlice) { + Stream.thenCopyH2D(RegisteredHost5.asSlice().slice(1, 4), DeviceA5.asSlice().drop_front(1), 4); EXPECT_TRUE(Stream.isOK()); for (int I = 1; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host5), DeviceB5.asSlice().drop_back(1), 2); + Stream.thenCopyH2D(RegisteredHost5, DeviceB5.asSlice().drop_back(1), 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5.asSlice(), 7); + Stream.thenCopyH2D(RegisteredHost5, DeviceA5.asSlice(), 7); EXPECT_FALSE(Stream.isOK()); } -TEST_F(StreamTest, CopyH2DSliceToArrayRef) { - - Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5.asSlice()); 
+TEST_F(StreamTest, CopyH2DRegisteredToSlice) { + Stream.thenCopyH2D(RegisteredHost5, DeviceA5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } - Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5.asSlice()); - EXPECT_FALSE(Stream.isOK()); -} - -TEST_F(StreamTest, CopyH2DSliceToPointer) { - Stream.thenCopyH2D(Host5, DeviceA5.asSlice(), 5); - EXPECT_TRUE(Stream.isOK()); - for (int I = 0; I < 5; ++I) { - EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); - } - - Stream.thenCopyH2D(Host7, DeviceA5.asSlice(), 7); + Stream.thenCopyH2D(RegisteredHost7, DeviceA5.asSlice()); EXPECT_FALSE(Stream.isOK()); } @@ -289,7 +249,6 @@ } TEST_F(StreamTest, CopySliceD2D) { - Stream.thenCopyD2D(DeviceA7.asSlice().drop_back(2), DeviceB5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -318,7 +277,6 @@ } TEST_F(StreamTest, CopyD2DSlice) { - Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice().drop_back(2)); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -330,7 +288,6 @@ } TEST_F(StreamTest, CopySliceD2DSliceByCount) { - Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { @@ -348,7 +305,6 @@ } TEST_F(StreamTest, CopySliceD2DSlice) { - Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) {