Index: streamexecutor/examples/CUDASaxpy.cpp =================================================================== --- streamexecutor/examples/CUDASaxpy.cpp +++ streamexecutor/examples/CUDASaxpy.cpp @@ -108,25 +108,25 @@ if (Platform->getDeviceCount() == 0) { return EXIT_FAILURE; } - se::Device *Device = getOrDie(Platform->getDevice(0)); + se::Device Device = getOrDie(Platform->getDevice(0)); // Load the kernel onto the device. cg::SaxpyKernel Kernel = - getOrDie(Device->createKernel(cg::SaxpyLoaderSpec)); + getOrDie(Device.createKernel(cg::SaxpyLoaderSpec)); se::RegisteredHostMemory RegisteredX = - getOrDie(Device->registerHostMemory(HostX)); + getOrDie(Device.registerHostMemory(HostX)); se::RegisteredHostMemory RegisteredY = - getOrDie(Device->registerHostMemory(HostY)); + getOrDie(Device.registerHostMemory(HostY)); // Allocate memory on the device. se::GlobalDeviceMemory X = - getOrDie(Device->allocateDeviceMemory(ArraySize)); + getOrDie(Device.allocateDeviceMemory(ArraySize)); se::GlobalDeviceMemory Y = - getOrDie(Device->allocateDeviceMemory(ArraySize)); + getOrDie(Device.allocateDeviceMemory(ArraySize)); // Run operations on a stream. - se::Stream Stream = getOrDie(Device->createStream()); + se::Stream Stream = getOrDie(Device.createStream()); Stream.thenCopyH2D(RegisteredX, X) .thenCopyH2D(RegisteredY, Y) .thenLaunch(ArraySize, 1, Kernel, A, X, Y) Index: streamexecutor/examples/HostSaxpy.cpp =================================================================== --- streamexecutor/examples/HostSaxpy.cpp +++ streamexecutor/examples/HostSaxpy.cpp @@ -62,25 +62,25 @@ if (Platform->getDeviceCount() == 0) { return EXIT_FAILURE; } - se::Device *Device = getOrDie(Platform->getDevice(0)); + se::Device Device = getOrDie(Platform->getDevice(0)); // Load the kernel onto the device. 
cg::SaxpyKernel Kernel = - getOrDie(Device->createKernel(cg::SaxpyLoaderSpec)); + getOrDie(Device.createKernel(cg::SaxpyLoaderSpec)); se::RegisteredHostMemory RegisteredX = - getOrDie(Device->registerHostMemory(HostX)); + getOrDie(Device.registerHostMemory(HostX)); se::RegisteredHostMemory RegisteredY = - getOrDie(Device->registerHostMemory(HostY)); + getOrDie(Device.registerHostMemory(HostY)); // Allocate memory on the device. se::GlobalDeviceMemory X = - getOrDie(Device->allocateDeviceMemory(ArraySize)); + getOrDie(Device.allocateDeviceMemory(ArraySize)); se::GlobalDeviceMemory Y = - getOrDie(Device->allocateDeviceMemory(ArraySize)); + getOrDie(Device.allocateDeviceMemory(ArraySize)); // Run operations on a stream. - se::Stream Stream = getOrDie(Device->createStream()); + se::Stream Stream = getOrDie(Device.createStream()); Stream.thenCopyH2D(RegisteredX, X) .thenCopyH2D(RegisteredY, Y) .thenLaunch(1, 1, Kernel, A, X, Y, ArraySize) Index: streamexecutor/include/streamexecutor/Platform.h =================================================================== --- streamexecutor/include/streamexecutor/Platform.h +++ streamexecutor/include/streamexecutor/Platform.h @@ -31,10 +31,8 @@ /// Gets the number of devices available for this platform. virtual size_t getDeviceCount() const = 0; - /// Gets a pointer to a Device with the given index for this platform. - /// - /// Ownership of the Device instance is NOT transferred to the caller. - virtual Expected getDevice(size_t DeviceIndex) = 0; + /// Gets a Device with the given index for this platform. The returned Device may wrap device state owned by this Platform, so it should not be used after the Platform is destroyed.
+ virtual Expected getDevice(size_t DeviceIndex) = 0; }; } // namespace streamexecutor Index: streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h =================================================================== --- streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h +++ streamexecutor/include/streamexecutor/platforms/host/HostPlatform.h @@ -30,24 +30,21 @@ public: size_t getDeviceCount() const override { return 1; } - Expected getDevice(size_t DeviceIndex) override { + Expected getDevice(size_t DeviceIndex) override { if (DeviceIndex != 0) { return make_error( "Requested device index " + llvm::Twine(DeviceIndex) + " from host platform which only supports device index 0"); } llvm::sys::ScopedLock Lock(Mutex); - if (!TheDevice) { + if (!ThePlatformDevice) ThePlatformDevice = llvm::make_unique(); - TheDevice = llvm::make_unique(ThePlatformDevice.get()); - } - return TheDevice.get(); + return Device(ThePlatformDevice.get()); } private: llvm::sys::Mutex Mutex; std::unique_ptr ThePlatformDevice; - std::unique_ptr TheDevice; }; } // namespace host