Index: streamexecutor/include/streamexecutor/Executor.h
===================================================================
--- streamexecutor/include/streamexecutor/Executor.h
+++ streamexecutor/include/streamexecutor/Executor.h
@@ -41,7 +41,11 @@
   /// Allocates an array of ElementCount entries of type T in device memory.
   template <typename T>
   Expected<GlobalDeviceMemory<T>> allocateDeviceMemory(size_t ElementCount) {
-    return PExecutor->allocateDeviceMemory(ElementCount * sizeof(T));
+    Expected<GlobalDeviceMemoryBase> MaybeBase =
+        PExecutor->allocateDeviceMemory(ElementCount * sizeof(T));
+    if (!MaybeBase)
+      return MaybeBase.takeError();
+    return GlobalDeviceMemory<T>(*MaybeBase);
   }
 
   /// Frees memory previously allocated with allocateDeviceMemory.
@@ -54,7 +58,11 @@
   /// Host memory allocated by this function can be used for asynchronous memory
   /// copies on streams. See Stream::thenCopyD2H and Stream::thenCopyH2D.
   template <typename T> Expected<T *> allocateHostMemory(size_t ElementCount) {
-    return PExecutor->allocateHostMemory(ElementCount * sizeof(T));
+    Expected<void *> MaybeMemory =
+        PExecutor->allocateHostMemory(ElementCount * sizeof(T));
+    if (!MaybeMemory)
+      return MaybeMemory.takeError();
+    return static_cast<T *>(*MaybeMemory);
   }
 
   /// Frees memory previously allocated with allocateHostMemory.
Index: streamexecutor/lib/unittests/ExecutorTest.cpp
===================================================================
--- streamexecutor/lib/unittests/ExecutorTest.cpp
+++ streamexecutor/lib/unittests/ExecutorTest.cpp
@@ -54,6 +54,14 @@
     return se::Error::success();
   }
 
+  se::Error registerHostMemory(void *, size_t) override {
+    return se::Error::success();
+  }
+
+  se::Error unregisterHostMemory(void *) override {
+    return se::Error::success();
+  }
+
   se::Error synchronousCopyD2H(const se::GlobalDeviceMemoryBase &DeviceSrc,
                                size_t SrcByteOffset, void *HostDst,
                                size_t DstByteOffset,
@@ -131,6 +139,27 @@
 using llvm::ArrayRef;
 using llvm::MutableArrayRef;
 
+TEST_F(ExecutorTest, AllocateAndFreeDeviceMemory) {
+  se::Expected<se::GlobalDeviceMemory<int>> MaybeMemory =
+      Executor.allocateDeviceMemory<int>(10);
+  // Stop on failure: dereferencing an Expected holding an error is invalid.
+  ASSERT_TRUE(static_cast<bool>(MaybeMemory));
+  EXPECT_NO_ERROR(Executor.freeDeviceMemory(*MaybeMemory));
+}
+
+TEST_F(ExecutorTest, AllocateAndFreeHostMemory) {
+  se::Expected<int *> MaybeMemory = Executor.allocateHostMemory<int>(10);
+  // Stop on failure: dereferencing an Expected holding an error is invalid.
+  ASSERT_TRUE(static_cast<bool>(MaybeMemory));
+  EXPECT_NO_ERROR(Executor.freeHostMemory(*MaybeMemory));
+}
+
+TEST_F(ExecutorTest, RegisterAndUnregisterHostMemory) {
+  std::vector<int> Data(10);
+  EXPECT_NO_ERROR(Executor.registerHostMemory(Data.data(), 10));
+  EXPECT_NO_ERROR(Executor.unregisterHostMemory(Data.data()));
+}
+
 // D2H tests
 
 TEST_F(ExecutorTest, SyncCopyD2HToMutableArrayRefByCount) {