Index: streamexecutor/CMakeLists.txt
===================================================================
--- streamexecutor/CMakeLists.txt
+++ streamexecutor/CMakeLists.txt
@@ -24,6 +24,23 @@
     include_directories(${LLVM_INCLUDE_DIRS})
     add_definitions(${LLVM_DEFINITIONS})
 
+    # Get the LLVM cxxflags by using llvm-config.
+    #
+    # This is necessary to get -fno-rtti if LLVM is compiled that way.
+    execute_process(
+        COMMAND
+            "${LLVM_BINARY_DIR}/bin/llvm-config"
+            --cxxflags
+        RESULT_VARIABLE
+            LLVM_CONFIG_RESULT
+        OUTPUT_VARIABLE
+            LLVM_CXXFLAGS
+        OUTPUT_STRIP_TRAILING_WHITESPACE)
+    if(NOT LLVM_CONFIG_RESULT EQUAL 0)
+        message(WARNING "llvm-config --cxxflags failed (${LLVM_CONFIG_RESULT}); LLVM compile flags such as -fno-rtti will not be applied")
+    endif()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS}")
+
     # Find the libraries that correspond to the LLVM components
     # that we wish to use
     llvm_map_components_to_libnames(llvm_libs support symbolize)
Index: streamexecutor/include/streamexecutor/Interfaces.h
===================================================================
--- streamexecutor/include/streamexecutor/Interfaces.h
+++ /dev/null
@@ -1,29 +0,0 @@
-//===-- Interfaces.h - Interfaces to platform-specific impls ----*- C++ -*-===//
-//
-//                     The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// Interfaces to platform-specific StreamExecutor type implementations.
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef STREAMEXECUTOR_INTERFACES_H
-#define STREAMEXECUTOR_INTERFACES_H
-
-namespace streamexecutor {
-
-/// Methods supported by device kernel function objects on all platforms.
-class KernelInterface {
-  // TODO(jhen): Add methods.
-};
-
-// TODO(jhen): Add other interfaces such as Stream.
-
-} // namespace streamexecutor
-
-#endif // STREAMEXECUTOR_INTERFACES_H
Index: streamexecutor/include/streamexecutor/LaunchDimensions.h
===================================================================
--- /dev/null
+++ streamexecutor/include/streamexecutor/LaunchDimensions.h
@@ -0,0 +1,47 @@
+//===-- LaunchDimensions.h - Kernel block and grid sizes --------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Structures to hold sizes for blocks and grids which are used as parameters
+/// for kernel launches.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef STREAMEXECUTOR_LAUNCHDIMENSIONS_H
+#define STREAMEXECUTOR_LAUNCHDIMENSIONS_H
+
+namespace streamexecutor {
+
+/// The dimensions of a device block of execution.
+///
+/// A block is made up of an array of X by Y by Z threads.
+struct BlockDimensions {
+  BlockDimensions(unsigned X = 1, unsigned Y = 1, unsigned Z = 1)
+      : X(X), Y(Y), Z(Z) {}
+
+  unsigned X;
+  unsigned Y;
+  unsigned Z;
+};
+
+/// The dimensions of a device grid of execution.
+///
+/// A grid is made up of an array of X by Y by Z blocks.
+struct GridDimensions {
+  GridDimensions(unsigned X = 1, unsigned Y = 1, unsigned Z = 1)
+      : X(X), Y(Y), Z(Z) {}
+
+  unsigned X;
+  unsigned Y;
+  unsigned Z;
+};
+
+} // namespace streamexecutor
+
+#endif // STREAMEXECUTOR_LAUNCHDIMENSIONS_H
Index: streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h
===================================================================
--- streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h
+++ streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h
@@ -47,6 +47,12 @@
 /// efficiently, although it is probably more information than is needed for any
 /// specific platform.
/// +/// The PackedKernelArgumentArrayBase class has no template parameters, so it +/// does not benefit from compile-time type checking. However, since it has no +/// template parameters, it can be passed as an argument to virtual functions, +/// and this allows it to be passed to functions that use virtual function +/// overloading to handle platform-specific kernel launching. +/// //===----------------------------------------------------------------------===// #ifndef STREAMEXECUTOR_PACKEDKERNELARGUMENTARRAY_H @@ -64,39 +70,81 @@ SHARED_DEVICE_MEMORY /// Shared device memory argument. }; -/// An array of packed kernel arguments. -template class PackedKernelArgumentArray { +/// An array of packed kernel arguments without compile-time type information. +/// +/// This un-templated base class is useful because packed kernel arguments must +/// at some point be passed to a virtual function that performs +/// platform-specific kernel launches. Such a virtual function cannot be +/// templated to handle all specializations of the +/// PackedKernelArgumentArray<...> class template, so, instead, references to +/// PackedKernelArgumentArray<...> are passed as references to this base class. +class PackedKernelArgumentArrayBase { public: - /// Constructs an instance by packing the specified arguments. - PackedKernelArgumentArray(const ParameterTs &... Arguments) - : SharedCount(0u) { - PackArguments(0, Arguments...); - } + virtual ~PackedKernelArgumentArrayBase(); /// Gets the number of packed arguments. - size_t getArgumentCount() const { return sizeof...(ParameterTs); } + size_t getArgumentCount() const { return ArgumentCount; } /// Gets the address of the argument at the given index. - const void *getAddress(size_t Index) const { return Addresses[Index]; } + const void *getAddress(size_t Index) const { return AddressesData[Index]; } /// Gets the size of the argument at the given index. 
- size_t getSize(size_t Index) const { return Sizes[Index]; } + size_t getSize(size_t Index) const { return SizesData[Index]; } /// Gets the type of the argument at the given index. - KernelArgumentType getType(size_t Index) const { return Types[Index]; } + KernelArgumentType getType(size_t Index) const { return TypesData[Index]; } /// Gets a pointer to the address array. - const void *const *getAddresses() const { return Addresses.data(); } + const void *const *getAddresses() const { return AddressesData; } /// Gets a pointer to the sizes array. - const size_t *getSizes() const { return Sizes.data(); } + const size_t *getSizes() const { return SizesData; } /// Gets a pointer to the types array. - const KernelArgumentType *getTypes() const { return Types.data(); } + const KernelArgumentType *getTypes() const { return TypesData; } /// Gets the number of shared device memory arguments. size_t getSharedCount() const { return SharedCount; } +protected: + PackedKernelArgumentArrayBase(size_t ArgumentCount) + : ArgumentCount(ArgumentCount), SharedCount(0u) {} + + size_t ArgumentCount; + size_t SharedCount; + const void *const *AddressesData; + size_t *SizesData; + KernelArgumentType *TypesData; +}; + +/// An array of packed kernel arguments with compile-time type information. +/// +/// This is used by the platform-independent StreamExecutor code to pack +/// arguments in a compile-time type-safe way. In order to actually launch a +/// kernel on a specific platform, however, a reference to this class will have +/// to be passed to a virtual, platform-specific kernel launch function. Such a +/// reference will be passed as a reference to the base class rather than a +/// reference to this subclass itself because a virtual function cannot be +/// templated in such a way to maintain the template parameter types of the +/// subclass. 
+template +class PackedKernelArgumentArray : public PackedKernelArgumentArrayBase { +public: + /// Constructs an instance by packing the specified arguments. + /// + /// Rather than using this constructor directly, consider using the + /// make_kernel_argument_pack function instead, to get the compiler to infer + /// the parameter types for you. + PackedKernelArgumentArray(const ParameterTs &... Arguments) + : PackedKernelArgumentArrayBase(sizeof...(ParameterTs)) { + AddressesData = Addresses.data(); + SizesData = Sizes.data(); + TypesData = Types.data(); + PackArguments(0, Arguments...); + } + + ~PackedKernelArgumentArray() override = default; + private: // Base case for PackArguments when there are no arguments to pack. void PackArguments(size_t) {} @@ -215,7 +263,6 @@ std::array Addresses; std::array Sizes; std::array Types; - size_t SharedCount; }; // Utility template function to call the PackedKernelArgumentArray constructor Index: streamexecutor/include/streamexecutor/PlatformInterfaces.h =================================================================== --- /dev/null +++ streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -0,0 +1,113 @@ +//===-- PlatformInterfaces.h - Interfaces to platform impls -----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Interfaces to platform-specific implementations. +/// +/// The general pattern is that the functions in these interfaces take raw +/// handle types as parameters. This means that these types and functions are +/// not intended for public use. Instead, corresponding methods in public types +/// like Stream, StreamExecutor, and Kernel use C++ templates to create +/// type-safe public interfaces. 
Those public functions do the type-unsafe work +/// of extracting raw handles from their arguments and forwarding those handles +/// to the methods defined in this file in the proper format. +/// +//===----------------------------------------------------------------------===// + +#ifndef STREAMEXECUTOR_PLATFORMINTERFACES_H +#define STREAMEXECUTOR_PLATFORMINTERFACES_H + +#include "streamexecutor/DeviceMemory.h" +#include "streamexecutor/Kernel.h" +#include "streamexecutor/LaunchDimensions.h" +#include "streamexecutor/PackedKernelArgumentArray.h" +#include "streamexecutor/Utils/Error.h" + +namespace streamexecutor { + +class PlatformStreamExecutor; + +/// Methods supported by device kernel function objects on all platforms. +class KernelInterface { + // TODO(jhen): Add methods. +}; + +/// Platform-specific stream handle. +class PlatformStreamHandle { +public: + explicit PlatformStreamHandle( + PlatformStreamExecutor *ParentExecutor = nullptr) + : ParentExecutor(ParentExecutor) {} + + virtual ~PlatformStreamHandle(); + + PlatformStreamExecutor *getParentExecutor() { return ParentExecutor; } + + void setParentExecutor(PlatformStreamExecutor *PExecutor) { + ParentExecutor = PExecutor; + } + +private: + PlatformStreamExecutor *ParentExecutor; +}; + +/// Raw executor methods that must be implemented by each platform. +/// +/// This class defines the platform interface that supports executing work on a +/// device. +/// +/// The public StreamExecutor and Stream classes have the type-safe versions of +/// the functions in this interface. +class PlatformStreamExecutor { +public: + virtual ~PlatformStreamExecutor(); + + virtual std::string getName() const = 0; + + /// Creates a platform-specific stream. + virtual Expected> createStream() = 0; + + /// Launches a kernel on the given stream. 
+  virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
+                       GridDimensions GridSize, const KernelBase &Kernel,
+                       const PackedKernelArgumentArrayBase &ArgumentArray) {
+    return make_error("launch not implemented for platform " + getName());
+  }
+
+  /// Copies data from the device to the host.
+  virtual Error memcpyD2H(PlatformStreamHandle *S,
+                          const GlobalDeviceMemoryBase &DeviceSrc,
+                          void *HostDst, size_t ByteCount) {
+    return make_error("memcpyD2H not implemented for platform " + getName());
+  }
+
+  /// Copies data from the host to the device.
+  virtual Error memcpyH2D(PlatformStreamHandle *S, const void *HostSrc,
+                          GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) {
+    return make_error("memcpyH2D not implemented for platform " + getName());
+  }
+
+  /// Copies data from one device location to another.
+  virtual Error memcpyD2D(PlatformStreamHandle *S,
+                          const GlobalDeviceMemoryBase &DeviceSrc,
+                          GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) {
+    return make_error("memcpyD2D not implemented for platform " + getName());
+  }
+
+  /// Blocks the host until the given stream completes all the work enqueued up
+  /// to the point this function is called.
+  virtual Error blockHostUntilDone(PlatformStreamHandle *S) {
+    return make_error("blockHostUntilDone not implemented for platform " +
+                      getName());
+  }
+};
+
+} // namespace streamexecutor
+
+#endif // STREAMEXECUTOR_PLATFORMINTERFACES_H
Index: streamexecutor/include/streamexecutor/Stream.h
===================================================================
--- /dev/null
+++ streamexecutor/include/streamexecutor/Stream.h
@@ -0,0 +1,246 @@
+//===-- Stream.h - A stream of execution ------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+///
+/// A Stream instance represents a queue of sequential, host-asynchronous work
+/// to be performed on a device.
+///
+/// To enqueue work on a device, first create a StreamExecutor instance for a
+/// given device and then use that StreamExecutor to create a Stream instance.
+/// The Stream instance will perform its work on the device managed by the
+/// StreamExecutor that created it.
+///
+/// The various "then" methods of the Stream object, such as thenMemcpyH2D and
+/// thenLaunch, may be used to enqueue work on the Stream, and the
+/// blockHostUntilDone() method may be used to block the host code until the
+/// Stream has completed all its work.
+///
+/// Multiple Stream instances can be created for the same StreamExecutor. This
+/// allows several independent streams of computation to be performed
+/// simultaneously on a single device.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef STREAMEXECUTOR_STREAM_H
+#define STREAMEXECUTOR_STREAM_H
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "streamexecutor/DeviceMemory.h"
+#include "streamexecutor/Kernel.h"
+#include "streamexecutor/LaunchDimensions.h"
+#include "streamexecutor/PackedKernelArgumentArray.h"
+#include "streamexecutor/PlatformInterfaces.h"
+#include "streamexecutor/Utils/Error.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/RWMutex.h"
+
+namespace streamexecutor {
+
+/// Represents a stream of dependent computations on a device.
+///
+/// The operations within a stream execute sequentially and asynchronously until
+/// blockHostUntilDone() is invoked, which synchronously joins host code with
+/// the execution of the stream.
+///
+/// If any given operation fails when entraining work for the stream, isOK()
+/// will indicate that an error has occurred and getStatus() will get the first
+/// error that occurred on the stream. There is no way to clear the error state
+/// of a stream once it is in an error state.
+class Stream {
+public:
+  explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream);
+
+  ~Stream();
+
+  /// Returns whether any error has occurred while entraining work on this
+  /// stream.
+  bool isOK() const {
+    llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
+    return !ErrorMessage;
+  }
+
+  /// Returns the status created by the first error that occurred while
+  /// entraining work on this stream.
+  Error getStatus() const {
+    llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
+    if (!ErrorMessage)
+      return Error::success();
+    return make_error(*ErrorMessage);
+  }
+
+  /// Entrains onto the stream of operations a kernel launch with the given
+  /// arguments.
+  ///
+  /// These arguments can be device memory types like GlobalDeviceMemory<T> and
+  /// SharedDeviceMemory<T>, or they can be primitive types such as int. The
+  /// allowable argument types are determined by the template parameters to the
+  /// TypedKernel argument.
+  template <typename... ParameterTs>
+  Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
+                     const TypedKernel<ParameterTs...> &Kernel,
+                     const ParameterTs &... Arguments) {
+    auto ArgumentArray =
+        make_kernel_argument_pack<ParameterTs...>(Arguments...);
+    setError(PlatformExecutor->launch(ThePlatformStream.get(), BlockSize,
+                                      GridSize, Kernel, ArgumentArray));
+    return *this;
+  }
+
+  /// Entrain onto the stream a memcpy of a given number of elements from a
+  /// device source to a host destination.
+  ///
+  /// HostDst must be a pointer to host memory allocated by
+  /// StreamExecutor::allocateHostMemory or otherwise allocated and then
+  /// registered with StreamExecutor::registerHostMemory.
+  template <typename T>
+  Stream &thenMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc,
+                        llvm::MutableArrayRef<T> HostDst, size_t ElementCount) {
+    if (ElementCount > DeviceSrc.getElementCount())
+      setError("copying too many elements, " + llvm::Twine(ElementCount) +
+               ", from device memory array of size " +
+               llvm::Twine(DeviceSrc.getElementCount()));
+    else if (ElementCount > HostDst.size())
+      setError("copying too many elements, " + llvm::Twine(ElementCount) +
+               ", to host array of size " + llvm::Twine(HostDst.size()));
+    else
+      setError(PlatformExecutor->memcpyD2H(ThePlatformStream.get(), DeviceSrc,
+                                           HostDst.data(),
+                                           ElementCount * sizeof(T)));
+    return *this;
+  }
+
+  /// Same as thenMemcpyD2H above, but copies the entire source to the
+  /// destination.
+  template <typename T>
+  Stream &thenMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc,
+                        llvm::MutableArrayRef<T> HostDst) {
+    return thenMemcpyD2H(DeviceSrc, HostDst, DeviceSrc.getElementCount());
+  }
+
+  /// Entrain onto the stream a memcpy of a given number of elements from a host
+  /// source to a device destination.
+  ///
+  /// HostSrc must be a pointer to host memory allocated by
+  /// StreamExecutor::allocateHostMemory or otherwise allocated and then
+  /// registered with StreamExecutor::registerHostMemory.
+  template <typename T>
+  Stream &thenMemcpyH2D(llvm::ArrayRef<T> HostSrc,
+                        GlobalDeviceMemory<T> *DeviceDst, size_t ElementCount) {
+    if (ElementCount > HostSrc.size())
+      setError("copying too many elements, " + llvm::Twine(ElementCount) +
+               ", from host array of size " + llvm::Twine(HostSrc.size()));
+    else if (ElementCount > DeviceDst->getElementCount())
+      setError("copying too many elements, " + llvm::Twine(ElementCount) +
+               ", to device memory array of size " +
+               llvm::Twine(DeviceDst->getElementCount()));
+    else
+      setError(PlatformExecutor->memcpyH2D(ThePlatformStream.get(),
+                                           HostSrc.data(), DeviceDst,
+                                           ElementCount * sizeof(T)));
+    return *this;
+  }
+
+  /// Same as thenMemcpyH2D above, but copies the entire source to the
+  /// destination.
+  template <typename T>
+  Stream &thenMemcpyH2D(llvm::ArrayRef<T> HostSrc,
+                        GlobalDeviceMemory<T> *DeviceDst) {
+    return thenMemcpyH2D(HostSrc, DeviceDst, HostSrc.size());
+  }
+
+  /// Entrain onto the stream a memcpy of a given number of elements from a
+  /// device source to a device destination.
+  template <typename T>
+  Stream &thenMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc,
+                        GlobalDeviceMemory<T> *DeviceDst, size_t ElementCount) {
+    if (ElementCount > DeviceSrc.getElementCount())
+      setError("copying too many elements, " + llvm::Twine(ElementCount) +
+               ", from device memory array of size " +
+               llvm::Twine(DeviceSrc.getElementCount()));
+    else if (ElementCount > DeviceDst->getElementCount())
+      setError("copying too many elements, " + llvm::Twine(ElementCount) +
+               ", to device memory array of size " +
+               llvm::Twine(DeviceDst->getElementCount()));
+    else
+      setError(PlatformExecutor->memcpyD2D(ThePlatformStream.get(), DeviceSrc,
+                                           DeviceDst,
+                                           ElementCount * sizeof(T)));
+    return *this;
+  }
+
+  /// Same as thenMemcpyD2D above, but copies the entire source to the
+  /// destination.
+  template <typename T>
+  Stream &thenMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc,
+                        GlobalDeviceMemory<T> *DeviceDst) {
+    return thenMemcpyD2D(DeviceSrc, DeviceDst, DeviceSrc.getElementCount());
+  }
+
+  /// Blocks the host code, waiting for the operations entrained on the stream
+  /// (enqueued up to this point in program execution) to complete.
+  ///
+  /// Returns true if there are no errors on the stream.
+  bool blockHostUntilDone() {
+    // Record any failure, then report the stream's overall state. Note that an
+    // Error converts to true when it holds a failure, not when it succeeds.
+    setError(PlatformExecutor->blockHostUntilDone(ThePlatformStream.get()));
+    return isOK();
+  }
+
+private:
+  /// Sets the error state from an Error object.
+  ///
+  /// Does not overwrite the error if it is already set. E is always consumed,
+  /// because destroying an unchecked failure aborts in assertion builds.
+  void setError(Error &&E) {
+    if (E) {
+      std::string Message = consumeAndGetMessage(std::move(E));
+      llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
+      if (!ErrorMessage)
+        ErrorMessage = std::move(Message);
+    }
+  }
+
+  /// Sets the error state from an error message.
+  ///
+  /// Does not overwrite the error if it is already set.
+  void setError(const llvm::Twine &Message) {
+    llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
+    if (!ErrorMessage)
+      ErrorMessage = Message.str();
+  }
+
+  /// The PlatformStreamExecutor that supports the operations of this stream.
+  PlatformStreamExecutor *PlatformExecutor;
+
+  /// The platform-specific stream handle for this instance.
+  std::unique_ptr<PlatformStreamHandle> ThePlatformStream;
+
+  /// Mutex that guards the error state flags.
+  ///
+  /// Mutable so that it can be obtained via const reader lock.
+  mutable llvm::sys::RWMutex ErrorMessageMutex;
+
+  /// First error message for an operation in this stream or empty if there have
+  /// been no errors.
+  llvm::Optional<std::string> ErrorMessage;
+
+  Stream(const Stream &) = delete;
+  void operator=(const Stream &) = delete;
+};
+
+} // namespace streamexecutor
+
+#endif // STREAMEXECUTOR_STREAM_H
Index: streamexecutor/include/streamexecutor/StreamExecutor.h
===================================================================
--- streamexecutor/include/streamexecutor/StreamExecutor.h
+++ streamexecutor/include/streamexecutor/StreamExecutor.h
@@ -16,14 +16,20 @@
 #ifndef STREAMEXECUTOR_STREAMEXECUTOR_H
 #define STREAMEXECUTOR_STREAMEXECUTOR_H
 
+#include "streamexecutor/KernelSpec.h"
 #include "streamexecutor/Utils/Error.h"
 
 namespace streamexecutor {
 
 class KernelInterface;
+class PlatformStreamExecutor;
+class Stream;
 
 class StreamExecutor {
 public:
+  explicit StreamExecutor(PlatformStreamExecutor *PlatformExecutor);
+  virtual ~StreamExecutor();
+
   /// Gets the kernel implementation for the underlying platform.
   virtual Expected<std::unique_ptr<KernelInterface>>
   getKernelImplementation(const MultiKernelLoaderSpec &Spec) {
@@ -31,7 +37,10 @@
     return nullptr;
   }
 
-  // TODO(jhen): Add other methods.
+  Expected<std::unique_ptr<Stream>> createStream();
+
+private:
+  PlatformStreamExecutor *PlatformExecutor;
 };
 
 } // namespace streamexecutor
Index: streamexecutor/lib/CMakeLists.txt
===================================================================
--- streamexecutor/lib/CMakeLists.txt
+++ streamexecutor/lib/CMakeLists.txt
@@ -7,7 +7,11 @@
     streamexecutor
     $<TARGET_OBJECTS:utils>
     Kernel.cpp
-    KernelSpec.cpp)
+    KernelSpec.cpp
+    PackedKernelArgumentArray.cpp
+    PlatformInterfaces.cpp
+    Stream.cpp
+    StreamExecutor.cpp)
 target_link_libraries(streamexecutor ${llvm_libs})
 
 if(STREAM_EXECUTOR_UNIT_TESTS)
Index: streamexecutor/lib/Kernel.cpp
===================================================================
--- streamexecutor/lib/Kernel.cpp
+++ streamexecutor/lib/Kernel.cpp
@@ -13,7 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "streamexecutor/Kernel.h"
-#include "streamexecutor/Interfaces.h"
+#include "streamexecutor/PlatformInterfaces.h"
 #include "streamexecutor/StreamExecutor.h"
 
 #include "llvm/DebugInfo/Symbolize/Symbolize.h"
Index: streamexecutor/lib/PackedKernelArgumentArray.cpp
===================================================================
--- /dev/null
+++ streamexecutor/lib/PackedKernelArgumentArray.cpp
@@ -0,0 +1,21 @@
+//===-- PackedKernelArgumentArray.cpp - Packed argument array impl --------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implementation details for classes from PackedKernelArgumentArray.h.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/PackedKernelArgumentArray.h"
+
+namespace streamexecutor {
+
+PackedKernelArgumentArrayBase::~PackedKernelArgumentArrayBase() = default;
+
+} // namespace streamexecutor
Index: streamexecutor/lib/PlatformInterfaces.cpp
===================================================================
--- /dev/null
+++ streamexecutor/lib/PlatformInterfaces.cpp
@@ -0,0 +1,23 @@
+//===-- PlatformInterfaces.cpp - Platform interface implementations -------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implementation file for PlatformInterfaces.h.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/PlatformInterfaces.h"
+
+namespace streamexecutor {
+
+PlatformStreamHandle::~PlatformStreamHandle() = default;
+
+PlatformStreamExecutor::~PlatformStreamExecutor() = default;
+
+} // namespace streamexecutor
Index: streamexecutor/lib/Stream.cpp
===================================================================
--- /dev/null
+++ streamexecutor/lib/Stream.cpp
@@ -0,0 +1,32 @@
+//===-- Stream.cpp - General stream implementation ------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the implementation details for a general stream object.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/Stream.h"
+
+#include <cassert>
+
+namespace streamexecutor {
+
+Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream)
+    : PlatformExecutor(PStream ? PStream->getParentExecutor() : nullptr),
+      ThePlatformStream(std::move(PStream)) {
+  // Every operation on this stream goes through PlatformExecutor and passes
+  // ThePlatformStream along, so neither may be null.
+  assert(ThePlatformStream && "Stream requires a platform stream handle");
+  assert(PlatformExecutor && "platform stream handle has no parent executor");
+}
+
+Stream::~Stream() = default;
+
+} // namespace streamexecutor
Index: streamexecutor/lib/StreamExecutor.cpp
===================================================================
--- /dev/null
+++ streamexecutor/lib/StreamExecutor.cpp
@@ -0,0 +1,45 @@
+//===-- StreamExecutor.cpp - StreamExecutor implementation ----------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implementation of StreamExecutor class internals.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/StreamExecutor.h"
+
+#include "streamexecutor/PlatformInterfaces.h"
+#include "streamexecutor/Stream.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+namespace streamexecutor {
+
+StreamExecutor::StreamExecutor(PlatformStreamExecutor *PlatformExecutor)
+    : PlatformExecutor(PlatformExecutor) {}
+
+StreamExecutor::~StreamExecutor() = default;
+
+Expected<std::unique_ptr<Stream>> StreamExecutor::createStream() {
+  if (!PlatformExecutor)
+    return make_error("cannot create a stream without a platform executor");
+  Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =
+      PlatformExecutor->createStream();
+  if (!MaybePlatformStream)
+    return MaybePlatformStream.takeError();
+  // Both setParentExecutor below and the Stream constructor dereference the
+  // handle, so a null handle reported as success must become an error here.
+  if (!*MaybePlatformStream)
+    return make_error("platform " + PlatformExecutor->getName() +
+                      " returned a null stream");
+  (*MaybePlatformStream)->setParentExecutor(PlatformExecutor);
+  return llvm::make_unique<Stream>(std::move(*MaybePlatformStream));
+}
+
+} // namespace streamexecutor
Index: streamexecutor/lib/unittests/CMakeLists.txt
===================================================================
--- streamexecutor/lib/unittests/CMakeLists.txt
+++ streamexecutor/lib/unittests/CMakeLists.txt
@@ -23,7 +23,19 @@
PackedKernelArgumentArrayTest.cpp) target_link_libraries( packed_kernel_argument_array_test + streamexecutor ${llvm_libs} ${GTEST_BOTH_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) add_test(PackedKernelArgumentArrayTest packed_kernel_argument_array_test) + +add_executable( + stream_test + StreamTest.cpp) +target_link_libraries( + stream_test + streamexecutor + ${llvm_libs} + ${GTEST_BOTH_LIBRARIES} + ${CMAKE_THREAD_LIBS_INIT}) +add_test(StreamTest stream_test) Index: streamexecutor/lib/unittests/KernelTest.cpp =================================================================== --- streamexecutor/lib/unittests/KernelTest.cpp +++ streamexecutor/lib/unittests/KernelTest.cpp @@ -14,9 +14,9 @@ #include -#include "streamexecutor/Interfaces.h" #include "streamexecutor/Kernel.h" #include "streamexecutor/KernelSpec.h" +#include "streamexecutor/PlatformInterfaces.h" #include "streamexecutor/StreamExecutor.h" #include "llvm/ADT/STLExtras.h" @@ -42,7 +42,8 @@ class MockStreamExecutor : public se::StreamExecutor { public: MockStreamExecutor() - : Unique(llvm::make_unique()), Raw(Unique.get()) {} + : se::StreamExecutor(nullptr), + Unique(llvm::make_unique()), Raw(Unique.get()) {} // Moves the unique pointer into the returned se::Expected instance. // Index: streamexecutor/lib/unittests/StreamTest.cpp =================================================================== --- /dev/null +++ streamexecutor/lib/unittests/StreamTest.cpp @@ -0,0 +1,116 @@ +//===-- StreamTest.cpp - Tests for Stream ---------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the unit tests for Stream code. 
+/// +//===----------------------------------------------------------------------===// + +#include + +#include "streamexecutor/Kernel.h" +#include "streamexecutor/KernelSpec.h" +#include "streamexecutor/PlatformInterfaces.h" +#include "streamexecutor/Stream.h" +#include "streamexecutor/StreamExecutor.h" + +#include "gtest/gtest.h" + +namespace { + +namespace se = ::streamexecutor; + +/// Mock PlatformStreamExecutor that performs asynchronous memcpy operations by +/// ignoring the stream argument and calling std::memcpy on device memory +/// handles. +class MockPlatformStreamExecutor : public se::PlatformStreamExecutor { +public: + ~MockPlatformStreamExecutor() override {} + + std::string getName() const override { return "MockPlatformStreamExecutor"; } + + se::Expected> + createStream() override { + return nullptr; + } + + se::Error memcpyD2H(se::PlatformStreamHandle *, + const se::GlobalDeviceMemoryBase &DeviceSrc, + void *HostDst, size_t ByteCount) override { + std::memcpy(HostDst, DeviceSrc.getHandle(), ByteCount); + return se::Error::success(); + } + + se::Error memcpyH2D(se::PlatformStreamHandle *, const void *HostSrc, + se::GlobalDeviceMemoryBase *DeviceDst, + size_t ByteCount) override { + std::memcpy(const_cast(DeviceDst->getHandle()), HostSrc, ByteCount); + return se::Error::success(); + } + + se::Error memcpyD2D(se::PlatformStreamHandle *, + const se::GlobalDeviceMemoryBase &DeviceSrc, + se::GlobalDeviceMemoryBase *DeviceDst, + size_t ByteCount) override { + std::memcpy(const_cast(DeviceDst->getHandle()), + DeviceSrc.getHandle(), ByteCount); + return se::Error::success(); + } +}; + +/// Test fixture to hold objects used by tests. +class StreamTest : public ::testing::Test { +public: + StreamTest() + : DeviceA(se::GlobalDeviceMemory::makeFromElementCount(HostA, 10)), + DeviceB(se::GlobalDeviceMemory::makeFromElementCount(HostB, 10)), + Stream(llvm::make_unique(&PlatformExecutor)) { + } + +protected: + // Device memory is backed by host arrays. 
+ int HostA[10]; + se::GlobalDeviceMemory DeviceA; + int HostB[10]; + se::GlobalDeviceMemory DeviceB; + + // Host memory to be used as actual host memory. + int Host[10]; + + MockPlatformStreamExecutor PlatformExecutor; + se::Stream Stream; +}; + +TEST_F(StreamTest, MemcpyCorrectSize) { + Stream.thenMemcpyH2D(llvm::ArrayRef(Host), &DeviceA); + EXPECT_TRUE(Stream.isOK()); + + Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef(Host)); + EXPECT_TRUE(Stream.isOK()); + + Stream.thenMemcpyD2D(DeviceA, &DeviceB); + EXPECT_TRUE(Stream.isOK()); +} + +TEST_F(StreamTest, MemcpyH2DTooManyElements) { + Stream.thenMemcpyH2D(llvm::ArrayRef(Host), &DeviceA, 20); + EXPECT_FALSE(Stream.isOK()); +} + +TEST_F(StreamTest, MemcpyD2HTooManyElements) { + Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef(Host), 20); + EXPECT_FALSE(Stream.isOK()); +} + +TEST_F(StreamTest, MemcpyD2DTooManyElements) { + Stream.thenMemcpyD2D(DeviceA, &DeviceB, 20); + EXPECT_FALSE(Stream.isOK()); +} + +} // namespace