Index: streamexecutor/include/streamexecutor/Device.h =================================================================== --- streamexecutor/include/streamexecutor/Device.h +++ streamexecutor/include/streamexecutor/Device.h @@ -35,12 +35,11 @@ Expected::value, KernelT>::type> createKernel(const MultiKernelLoaderSpec &Spec) { - Expected> MaybeKernelHandle = - PDevice->createKernel(Spec); + Expected MaybeKernelHandle = PDevice->createKernel(Spec); if (!MaybeKernelHandle) { return MaybeKernelHandle.takeError(); } - return KernelT(Spec.getKernelName(), std::move(*MaybeKernelHandle)); + return KernelT(PDevice, *MaybeKernelHandle, Spec.getKernelName()); } /// Creates a stream object for this device. Index: streamexecutor/include/streamexecutor/Kernel.h =================================================================== --- streamexecutor/include/streamexecutor/Kernel.h +++ streamexecutor/include/streamexecutor/Kernel.h @@ -28,19 +28,37 @@ namespace streamexecutor { -class PlatformKernelHandle; +class PlatformDevice; /// The base class for all kernel types. /// -/// Stores the name of the kernel in both mangled and demangled forms. +/// Stores the name of the kernel in both mangled and demangled forms, and owns +/// a platform-specific kernel handle that is released by calling +/// PlatformDevice::destroyKernel when the object is destroyed. class KernelBase { public: - KernelBase(llvm::StringRef Name); + /// Takes ownership of the non-null PlatformKernelHandle; D must be non-null. + KernelBase(PlatformDevice *D, const void *PlatformKernelHandle, + llvm::StringRef Name); + KernelBase(const KernelBase &Other) = delete; + KernelBase &operator=(const KernelBase &Other) = delete; + + /// Transfers ownership of the handle; Other is left without one. + KernelBase(KernelBase &&Other); + KernelBase &operator=(KernelBase &&Other); + + ~KernelBase(); + + /// Gets the underlying platform-specific handle for this kernel. + const void *getPlatformHandle() const { return PlatformKernelHandle; } const std::string &getName() const { return Name; } const std::string &getDemangledName() const { return DemangledName; } private: + PlatformDevice *PDevice; + const void *PlatformKernelHandle; + std::string Name; std::string DemangledName; }; @@ -51,17 +69,12 @@ /// function.
template class Kernel : public KernelBase { public: - Kernel(llvm::StringRef Name, std::unique_ptr PHandle) - : KernelBase(Name), PHandle(std::move(PHandle)) {} + Kernel(PlatformDevice *D, const void *PlatformKernelHandle, + llvm::StringRef Name) + : KernelBase(D, PlatformKernelHandle, Name) {} Kernel(Kernel &&Other) = default; Kernel &operator=(Kernel &&Other) = default; - - /// Gets the underlying platform-specific handle for this kernel. - PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); } - -private: - std::unique_ptr PHandle; }; } // namespace streamexecutor Index: streamexecutor/include/streamexecutor/PlatformInterfaces.h =================================================================== --- streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -31,34 +31,6 @@ namespace streamexecutor { -class PlatformDevice; - -/// Platform-specific kernel handle. -class PlatformKernelHandle { -public: - explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} - - virtual ~PlatformKernelHandle(); - - PlatformDevice *getDevice() { return PDevice; } - -private: - PlatformDevice *PDevice; -}; - -/// Platform-specific stream handle. -class PlatformStreamHandle { -public: - explicit PlatformStreamHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} - - virtual ~PlatformStreamHandle(); - - PlatformDevice *getDevice() { return PDevice; } - -private: - PlatformDevice *PDevice; -}; - /// Raw executor methods that must be implemented by each platform. /// /// This class defines the platform interface that supports executing work on a @@ -73,19 +45,32 @@ virtual std::string getName() const = 0; /// Creates a platform-specific kernel.
- virtual Expected> + virtual Expected createKernel(const MultiKernelLoaderSpec &Spec) { return make_error("createKernel not implemented for platform " + getName()); } + /// Destroys a kernel handle previously returned by createKernel. + virtual Error destroyKernel(const void *Handle) { + return make_error("destroyKernel not implemented for platform " + + getName()); + } + /// Creates a platform-specific stream. - virtual Expected> createStream() { + virtual Expected createStream() { return make_error("createStream not implemented for platform " + getName()); } + /// Destroys a stream handle previously returned by createStream. + virtual Error destroyStream(const void *Handle) { + return make_error("destroyStream not implemented for platform " + + getName()); + } + /// Launches a kernel on the given stream. - virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize, - GridDimensions GridSize, PlatformKernelHandle *K, + virtual Error launch(const void *PlatformStreamHandle, + BlockDimensions BlockSize, GridDimensions GridSize, + const void *PKernelHandle, const PackedKernelArgumentArrayBase &ArgumentArray) { return make_error("launch not implemented for platform " + getName()); } @@ -94,9 +79,9 @@ /// /// HostDst should have been allocated by allocateHostMemory or registered /// with registerHostMemory. - virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle, - size_t SrcByteOffset, void *HostDst, - size_t DstByteOffset, size_t ByteCount) { + virtual Error copyD2H(const void *PlatformStreamHandle, + const void *DeviceSrcHandle, size_t SrcByteOffset, + void *HostDst, size_t DstByteOffset, size_t ByteCount) { return make_error("copyD2H not implemented for platform " + getName()); } @@ -104,22 +89,23 @@ /// /// HostSrc should have been allocated by allocateHostMemory or registered /// with registerHostMemory.
- virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc, + virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc, size_t SrcByteOffset, const void *DeviceDstHandle, size_t DstByteOffset, size_t ByteCount) { return make_error("copyH2D not implemented for platform " + getName()); } /// Copies data from one device location to another. - virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle, - size_t SrcByteOffset, const void *DeviceDstHandle, - size_t DstByteOffset, size_t ByteCount) { + virtual Error copyD2D(const void *PlatformStreamHandle, + const void *DeviceSrcHandle, size_t SrcByteOffset, + const void *DeviceDstHandle, size_t DstByteOffset, + size_t ByteCount) { return make_error("copyD2D not implemented for platform " + getName()); } /// Blocks the host until the given stream completes all the work enqueued up /// to the point this function is called. - virtual Error blockHostUntilDone(PlatformStreamHandle *S) { + virtual Error blockHostUntilDone(const void *PlatformStreamHandle) { return make_error("blockHostUntilDone not implemented for platform " + getName()); } Index: streamexecutor/include/streamexecutor/Stream.h =================================================================== --- streamexecutor/include/streamexecutor/Stream.h +++ streamexecutor/include/streamexecutor/Stream.h @@ -59,10 +59,13 @@ /// of a stream once it is in an error state. class Stream { public: - explicit Stream(std::unique_ptr PStream); + Stream(PlatformDevice *D, const void *PlatformStreamHandle); - Stream(Stream &&Other) = default; - Stream &operator=(Stream &&Other) = default; + Stream(const Stream &Other) = delete; + Stream &operator=(const Stream &Other) = delete; + + Stream(Stream &&Other); // Leaves Other with a null stream handle. + Stream &operator=(Stream &&Other); ~Stream(); @@ -88,7 +91,7 @@ // // Returns the result of getStatus() after the Stream work completes.
Error blockHostUntilDone() { - setError(PDevice->blockHostUntilDone(ThePlatformStream.get())); + setError(PDevice->blockHostUntilDone(PlatformStreamHandle)); return getStatus(); } @@ -105,7 +108,7 @@ const ParameterTs &... Arguments) { auto ArgumentArray = make_kernel_argument_pack(Arguments...); - setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize, + setError(PDevice->launch(PlatformStreamHandle, BlockSize, GridSize, K.getPlatformHandle(), ArgumentArray)); return *this; } @@ -136,7 +139,7 @@ setError("copying too many elements, " + llvm::Twine(ElementCount) + ", to a host array of element count " + llvm::Twine(Dst.size())); else - setError(PDevice->copyD2H(ThePlatformStream.get(), + setError(PDevice->copyD2H(PlatformStreamHandle, Src.getBaseMemory().getHandle(), Src.getElementOffset() * sizeof(T), Dst.data(), 0, ElementCount * sizeof(T))); @@ -196,10 +199,9 @@ ", to a device array of element count " + llvm::Twine(Dst.getElementCount())); else - setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0, - Dst.getBaseMemory().getHandle(), - Dst.getElementOffset() * sizeof(T), - ElementCount * sizeof(T))); + setError(PDevice->copyH2D( + PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(), + Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); return *this; } @@ -254,7 +256,7 @@ llvm::Twine(Dst.getElementCount())); else setError(PDevice->copyD2D( - ThePlatformStream.get(), Src.getBaseMemory().getHandle(), + PlatformStreamHandle, Src.getBaseMemory().getHandle(), Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(), Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T))); return *this; @@ -342,7 +344,7 @@ PlatformDevice *PDevice; - /// The platform-specific stream handle for this instance. + /// The owned platform-specific stream handle; null after a move. - std::unique_ptr ThePlatformStream; + const void *PlatformStreamHandle; /// Mutex that guards the error state flags.
std::unique_ptr ErrorMessageMutex; @@ -350,9 +352,6 @@ /// First error message for an operation in this stream or empty if there have /// been no errors. llvm::Optional ErrorMessage; - - Stream(const Stream &) = delete; - void operator=(const Stream &) = delete; }; } // namespace streamexecutor Index: streamexecutor/lib/Device.cpp =================================================================== --- streamexecutor/lib/Device.cpp +++ streamexecutor/lib/Device.cpp @@ -28,14 +28,11 @@ Device::~Device() = default; Expected Device::createStream() { - Expected> MaybePlatformStream = - PDevice->createStream(); + Expected MaybePlatformStream = PDevice->createStream(); if (!MaybePlatformStream) { return MaybePlatformStream.takeError(); } - assert((*MaybePlatformStream)->getDevice() == PDevice && - "an executor created a stream with a different stored executor"); - return Stream(std::move(*MaybePlatformStream)); + return Stream(PDevice, *MaybePlatformStream); } } // namespace streamexecutor Index: streamexecutor/lib/Kernel.cpp =================================================================== --- streamexecutor/lib/Kernel.cpp +++ streamexecutor/lib/Kernel.cpp @@ -12,16 +12,56 @@ /// //===----------------------------------------------------------------------===// -#include "streamexecutor/Kernel.h" +#include + #include "streamexecutor/Device.h" +#include "streamexecutor/Kernel.h" #include "streamexecutor/PlatformInterfaces.h" #include "llvm/DebugInfo/Symbolize/Symbolize.h" namespace streamexecutor { -KernelBase::KernelBase(llvm::StringRef Name) - : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName( - Name, nullptr)) {} +KernelBase::KernelBase(PlatformDevice *D, const void *PlatformKernelHandle, + llvm::StringRef Name) + : PDevice(D), PlatformKernelHandle(PlatformKernelHandle), Name(Name), + DemangledName( + llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr)) { + assert(D != nullptr && + "cannot construct a kernel object with a null platform
device"); + assert(PlatformKernelHandle != nullptr && + "cannot construct a kernel object with a null platform kernel handle"); +} + +KernelBase::KernelBase(KernelBase &&Other) + : PDevice(Other.PDevice), PlatformKernelHandle(Other.PlatformKernelHandle), + Name(std::move(Other.Name)), + DemangledName(std::move(Other.DemangledName)) { + Other.PDevice = nullptr; + Other.PlatformKernelHandle = nullptr; +} + +KernelBase &KernelBase::operator=(KernelBase &&Other) { + if (this == &Other) + return *this; + // Release the handle this object currently owns before adopting Other's; + // otherwise it would never be passed to destroyKernel. + if (PlatformKernelHandle) + // TODO(jhen): Handle the error here. + consumeError(PDevice->destroyKernel(PlatformKernelHandle)); + PDevice = Other.PDevice; + PlatformKernelHandle = Other.PlatformKernelHandle; + Name = std::move(Other.Name); + DemangledName = std::move(Other.DemangledName); + Other.PDevice = nullptr; + Other.PlatformKernelHandle = nullptr; + return *this; +} + +KernelBase::~KernelBase() { + if (PlatformKernelHandle) + // TODO(jhen): Handle the error here. + consumeError(PDevice->destroyKernel(PlatformKernelHandle)); +} } // namespace streamexecutor Index: streamexecutor/lib/PlatformInterfaces.cpp =================================================================== --- streamexecutor/lib/PlatformInterfaces.cpp +++ streamexecutor/lib/PlatformInterfaces.cpp @@ -16,8 +16,6 @@ namespace streamexecutor { -PlatformStreamHandle::~PlatformStreamHandle() = default; - PlatformDevice::~PlatformDevice() = default; } // namespace streamexecutor Index: streamexecutor/lib/Stream.cpp =================================================================== --- streamexecutor/lib/Stream.cpp +++ streamexecutor/lib/Stream.cpp @@ -12,14 +12,50 @@ /// //===----------------------------------------------------------------------===// +#include + #include "streamexecutor/Stream.h" namespace streamexecutor { -Stream::Stream(std::unique_ptr PStream) - : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)), - ErrorMessageMutex(llvm::make_unique()) {} +Stream::Stream(PlatformDevice *D, const void *PlatformStreamHandle) + : PDevice(D), PlatformStreamHandle(PlatformStreamHandle), + ErrorMessageMutex(llvm::make_unique()) { + assert(D !=
nullptr && + "cannot construct a stream object with a null platform device"); + assert(PlatformStreamHandle != nullptr && + "cannot construct a stream object with a null platform stream handle"); +} + +Stream::Stream(Stream &&Other) + : PDevice(Other.PDevice), PlatformStreamHandle(Other.PlatformStreamHandle), + ErrorMessageMutex(std::move(Other.ErrorMessageMutex)), + ErrorMessage(std::move(Other.ErrorMessage)) { + Other.PDevice = nullptr; + Other.PlatformStreamHandle = nullptr; +} + +Stream &Stream::operator=(Stream &&Other) { + if (this == &Other) + return *this; + // Release the handle this object currently owns before adopting Other's; + // otherwise it would never be passed to destroyStream. + if (PlatformStreamHandle) + // TODO(jhen): Handle error condition here. + consumeError(PDevice->destroyStream(PlatformStreamHandle)); + PDevice = Other.PDevice; + PlatformStreamHandle = Other.PlatformStreamHandle; + ErrorMessageMutex = std::move(Other.ErrorMessageMutex); + ErrorMessage = std::move(Other.ErrorMessage); + Other.PDevice = nullptr; + Other.PlatformStreamHandle = nullptr; + return *this; +} -Stream::~Stream() = default; +Stream::~Stream() { + if (PlatformStreamHandle) + // TODO(jhen): Handle error condition here. + consumeError(PDevice->destroyStream(PlatformStreamHandle)); +} } // namespace streamexecutor Index: streamexecutor/lib/unittests/SimpleHostPlatformDevice.h =================================================================== --- streamexecutor/lib/unittests/SimpleHostPlatformDevice.h +++ streamexecutor/lib/unittests/SimpleHostPlatformDevice.h @@ -34,9 +34,7 @@ public: std::string getName() const override { return "SimpleHostPlatformDevice"; } - streamexecutor::Expected< - std::unique_ptr> - createStream() override { + streamexecutor::Expected createStream() override { return nullptr; } @@ -69,7 +67,7 @@ return streamexecutor::Error::success(); } - streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S, + streamexecutor::Error copyD2H(const void *StreamHandle, const void *DeviceHandleSrc, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, @@ -80,8 +78,8 @@ return streamexecutor::Error::success(); } - streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S, - const void *HostSrc, size_t SrcByteOffset, +
streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc, + size_t SrcByteOffset, const void *DeviceHandleDst, size_t DstByteOffset, size_t ByteCount) override { @@ -92,7 +90,7 @@ } streamexecutor::Error - copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc, + copyD2D(const void *StreamHandle, const void *DeviceHandleSrc, size_t SrcByteOffset, const void *DeviceHandleDst, size_t DstByteOffset, size_t ByteCount) override { std::memcpy(static_cast(const_cast(DeviceHandleDst)) + Index: streamexecutor/lib/unittests/StreamTest.cpp =================================================================== --- streamexecutor/lib/unittests/StreamTest.cpp +++ streamexecutor/lib/unittests/StreamTest.cpp @@ -34,11 +34,11 @@ class StreamTest : public ::testing::Test { public: StreamTest() - : Device(&PDevice), - Stream(llvm::make_unique(&PDevice)), - HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9}, - HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, - Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, + : DummyPlatformStream(1), Device(&PDevice), + Stream(&PDevice, &DummyPlatformStream), HostA5{0, 1, 2, 3, 4}, + HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, + HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28}, + Host7{29, 30, 31, 32, 33, 34, 35}, DeviceA5(getOrDie(Device.allocateDeviceMemory(5))), DeviceB5(getOrDie(Device.allocateDeviceMemory(5))), DeviceA7(getOrDie(Device.allocateDeviceMemory(7))), @@ -50,6 +50,8 @@ } protected: + int DummyPlatformStream; // Mimicking a platform where the platform stream + // handle is just a stream number. se::test::SimpleHostPlatformDevice PDevice; se::Device Device; se::Stream Stream;