Index: parallel-libs/trunk/streamexecutor/include/streamexecutor/DeviceMemory.h =================================================================== --- parallel-libs/trunk/streamexecutor/include/streamexecutor/DeviceMemory.h +++ parallel-libs/trunk/streamexecutor/include/streamexecutor/DeviceMemory.h @@ -0,0 +1,224 @@ +//===-- DeviceMemory.h - Types representing device memory -------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines types that represent device memory buffers. Two memory +/// spaces are represented here: global and shared. Host code can have a handle +/// to device global memory, and that handle can be used to copy data to and +/// from the device. Host code cannot have a handle to device shared memory +/// because that memory only exists during the execution of a kernel. +/// +/// GlobalDeviceMemoryBase is similar to a pair consisting of a void* pointer +/// and a byte count to tell how much memory is pointed to by that void*. +/// +/// GlobalDeviceMemory is a subclass of GlobalDeviceMemoryBase which keeps +/// track of the type of element to be stored in the device array. It is similar +/// to a pair of a T* pointer and an element count to tell how many elements of +/// type T fit in the memory pointed to by that T*. +/// +/// SharedDeviceMemoryBase is just the size in bytes of a shared memory buffer. +/// +/// SharedDeviceMemory is a subclass of SharedDeviceMemoryBase which knows +/// how many elements of type T it can hold. +/// +/// These classes are useful for keeping track of which memory space a buffer +/// lives in, and the typed subclasses are useful for type-checking. 
+/// +/// The typed subclass will be used by user code, and the untyped base classes +/// will be used for type-unsafe operations inside of StreamExecutor. +/// +//===----------------------------------------------------------------------===// + +#ifndef STREAMEXECUTOR_DEVICEMEMORY_H +#define STREAMEXECUTOR_DEVICEMEMORY_H + +#include + +namespace streamexecutor { + +/// Wrapper around a generic global device memory allocation. +/// +/// This class represents a buffer of untyped bytes in the global memory space +/// of a device. See GlobalDeviceMemory for the corresponding type that +/// includes type information for the elements in its buffer. +/// +/// This is effectively a pair consisting of an opaque handle and a buffer size +/// in bytes. The opaque handle is a platform-dependent handle to the actual +/// memory that is allocated on the device. +/// +/// In some cases, such as in the CUDA platform, the opaque handle may actually +/// be a pointer in the virtual address space and it may be valid to perform +/// arithmetic on it to obtain other device pointers, but this is not the case +/// in general. +/// +/// For example, in the OpenCL platform, the handle is a pointer to a _cl_mem +/// handle object which really is completely opaque to the user. +/// +/// The only fully platform-generic operations on handles are using them to +/// create new GlobalDeviceMemoryBase objects, and comparing them to each other +/// for equality. +class GlobalDeviceMemoryBase { +public: + /// Creates a GlobalDeviceMemoryBase from an optional handle and an optional + /// byte count. + explicit GlobalDeviceMemoryBase(const void *Handle = nullptr, + size_t ByteCount = 0) + : Handle(Handle), ByteCount(ByteCount) {} + + /// Copyable like a pointer. + GlobalDeviceMemoryBase(const GlobalDeviceMemoryBase &) = default; + + /// Copy-assignable like a pointer. 
+ GlobalDeviceMemoryBase &operator=(const GlobalDeviceMemoryBase &) = default; + + /// Returns the size, in bytes, for the backing memory. + size_t getByteCount() const { return ByteCount; } + + /// Gets the internal handle. + /// + /// Warning: the returned pointer does not necessarily point directly into + /// device virtual address space; its meaning is platform-dependent. + const void *getHandle() const { return Handle; } + +private: + const void *Handle; // Platform-dependent value representing allocated memory. + size_t ByteCount; // Size in bytes of this allocation. +}; + +/// Typed wrapper around the "void *"-like GlobalDeviceMemoryBase class. +/// +/// For example, GlobalDeviceMemory<int> is a simple wrapper around +/// GlobalDeviceMemoryBase that represents a buffer of integers stored in global +/// device memory. +template +class GlobalDeviceMemory : public GlobalDeviceMemoryBase { +public: + /// Creates a typed area of GlobalDeviceMemory with a given opaque handle and + /// the given element count. + static GlobalDeviceMemory makeFromElementCount(const void *Handle, + size_t ElementCount) { + return GlobalDeviceMemory(Handle, ElementCount); + } + + /// Creates a typed device memory region from an untyped device memory region. + /// + /// This effectively amounts to a cast from a void* to an ElemT*, but it also + /// manages the difference in the size measurements when + /// GlobalDeviceMemoryBase is measured in bytes and GlobalDeviceMemory is + /// measured in elements. + explicit GlobalDeviceMemory(const GlobalDeviceMemoryBase &Other) + : GlobalDeviceMemoryBase(Other.getHandle(), Other.getByteCount()) {} + + /// Copyable like a pointer. + GlobalDeviceMemory(const GlobalDeviceMemory &) = default; + + /// Copy-assignable like a pointer. + GlobalDeviceMemory &operator=(const GlobalDeviceMemory &) = default; + + /// Returns the number of elements of type ElemT that constitute this + /// allocation. 
+ size_t getElementCount() const { return getByteCount() / sizeof(ElemT); } + +private: + /// Constructs a GlobalDeviceMemory instance from an opaque handle and an + /// element count. + /// + /// This constructor is not public because there is a potential for confusion + /// between the size of the buffer in bytes and the size of the buffer in + /// elements. + /// + /// The static method makeFromElementCount is provided for users of this class + /// because its name makes the meaning of the size parameter clear. + GlobalDeviceMemory(const void *Handle, size_t ElementCount) + : GlobalDeviceMemoryBase(Handle, ElementCount * sizeof(ElemT)) {} +}; + +/// A class to represent the size of a dynamic shared memory buffer on a device. +/// +/// This class maintains no information about the types to be stored in the +/// buffer. For the typed version of this class see SharedDeviceMemory. +/// +/// Shared memory buffers exist only on the device and cannot be manipulated +/// from the host, so instances of this class do not have an opaque handle, only +/// a size. +/// +/// This type of memory is called "local" memory in OpenCL and "shared" memory +/// in CUDA, and both platforms follow the rule that the host code only knows +/// the size of these buffers and does not have a handle to them. +/// +/// The treatment of shared memory in StreamExecutor matches the way it is done +/// in OpenCL, where a kernel takes any number of shared memory sizes as kernel +/// function arguments. +/// +/// In CUDA only one shared memory size argument is allowed per kernel call. +/// StreamExecutor handles this by allowing CUDA kernel signatures that take +/// multiple SharedDeviceMemory arguments, and simply adding together all the +/// shared memory sizes to get the final shared memory size that is used to +/// launch the kernel. +class SharedDeviceMemoryBase { +public: + /// Creates an untyped shared memory array from a byte count. 
+ SharedDeviceMemoryBase(size_t ByteCount) : ByteCount(ByteCount) {} + + /// Copyable because it is just an array size. + SharedDeviceMemoryBase(const SharedDeviceMemoryBase &) = default; + + /// Copy-assignable because it is just an array size. + SharedDeviceMemoryBase &operator=(const SharedDeviceMemoryBase &) = default; + + /// Gets the byte count. + size_t getByteCount() const { return ByteCount; } + +private: + size_t ByteCount; +}; + +/// Typed wrapper around the untyped SharedDeviceMemoryBase class. +/// +/// For example, SharedDeviceMemory<int> is a wrapper around +/// SharedDeviceMemoryBase that represents a buffer of integers stored in shared +/// device memory. +template +class SharedDeviceMemory : public SharedDeviceMemoryBase { +public: + /// Creates a typed area of shared device memory with a given number of + /// elements. + static SharedDeviceMemory makeFromElementCount(size_t ElementCount) { + return SharedDeviceMemory(ElementCount); + } + + /// Copyable because it is just an array size. + SharedDeviceMemory(const SharedDeviceMemory &) = default; + + /// Copy-assignable because it is just an array size. + SharedDeviceMemory &operator=(const SharedDeviceMemory &) = default; + + /// Returns the number of elements of type ElemT that can fit in this memory + /// buffer. + size_t getElementCount() const { return getByteCount() / sizeof(ElemT); } + + /// Returns whether this is a single-element memory buffer. + bool isScalar() const { return getElementCount() == 1; } + +private: + /// Constructs a SharedDeviceMemory instance from an element count. + /// + /// This constructor is not public because there is a potential for confusion + /// between the size of the buffer in bytes and the size of the buffer in + /// elements. + /// + /// The static method makeFromElementCount is provided for users of this class + /// because its name makes the meaning of the size parameter clear. 
+ explicit SharedDeviceMemory(size_t ElementCount) + : SharedDeviceMemoryBase(ElementCount * sizeof(ElemT)) {} +}; + +} // namespace streamexecutor + +#endif // STREAMEXECUTOR_DEVICEMEMORY_H Index: parallel-libs/trunk/streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h =================================================================== --- parallel-libs/trunk/streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h +++ parallel-libs/trunk/streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h @@ -0,0 +1,232 @@ +//===-- PackedKernelArgumentArray.h - Packed kernel arg types ---*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// The types in this file are designed to deal with the fact that device memory +/// kernel arguments are treated differently from other arguments during kernel +/// argument packing. +/// +/// GlobalDeviceMemory arguments are passed to a kernel by passing their +/// opaque handle. SharedDeviceMemory arguments have no associated address, +/// only a size, so the size is the only information that gets passed to the +/// kernel launch. +/// +/// The KernelArgumentType enum is used to keep track of the type of each +/// argument. +/// +/// The PackedKernelArgumentArray class uses template metaprogramming to convert +/// each argument to a PackedKernelArgument with minimal runtime overhead. +/// +/// The design of the PackedKernelArgumentArray class has a few idiosyncrasies +/// due to the fact that parameter packing has been identified as +/// performance-critical in some applications. The packed argument data is +/// stored as a struct of arrays rather than an array of structs because CUDA +/// kernel launches in the CUDA driver API take an array of argument addresses. 
+/// Having created the array of argument addresses here, no further work will +/// need to be done in the CUDA driver layer to unpack and repack the addresses. +/// +/// The shared memory argument count is maintained separately because in the +/// common case where it is zero, the CUDA layer doesn't have to loop through +/// the argument array and sum up all the shared memory sizes. This is another +/// performance optimization that shows up as a quirk in this class interface. +/// +/// The platform-interface kernel launch function will take the following +/// arguments, which are provided by this interface: +/// * argument count, +/// * array of argument addresses, +/// * array of argument sizes, +/// * array of argument types, and +/// * shared memory argument count. +/// This information should be enough to allow any platform to launch the kernel +/// efficiently, although it is probably more information than is needed for any +/// specific platform. +/// +//===----------------------------------------------------------------------===// + +#ifndef STREAMEXECUTOR_PACKEDKERNELARGUMENTARRAY_H +#define STREAMEXECUTOR_PACKEDKERNELARGUMENTARRAY_H + +#include + +#include "streamexecutor/DeviceMemory.h" + +namespace streamexecutor { + +enum class KernelArgumentType { + VALUE, ///< Non-device-memory argument. + GLOBAL_DEVICE_MEMORY, ///< Non-shared device memory argument. + SHARED_DEVICE_MEMORY ///< Shared device memory argument. +}; + +/// An array of packed kernel arguments. +template class PackedKernelArgumentArray { +public: + /// Constructs an instance by packing the specified arguments. + PackedKernelArgumentArray(const ParameterTs &... Arguments) + : SharedCount(0u) { + PackArguments(0, Arguments...); + } + + /// Gets the number of packed arguments. + size_t getArgumentCount() const { return sizeof...(ParameterTs); } + + /// Gets the address of the argument at the given index. 
+ const void *getAddress(size_t Index) const { return Addresses[Index]; } + + /// Gets the size of the argument at the given index. + size_t getSize(size_t Index) const { return Sizes[Index]; } + + /// Gets the type of the argument at the given index. + KernelArgumentType getType(size_t Index) const { return Types[Index]; } + + /// Gets a pointer to the address array. + const void *const *getAddresses() const { return Addresses.data(); } + + /// Gets a pointer to the sizes array. + const size_t *getSizes() const { return Sizes.data(); } + + /// Gets a pointer to the types array. + const KernelArgumentType *getTypes() const { return Types.data(); } + + /// Gets the number of shared device memory arguments. + size_t getSharedCount() const { return SharedCount; } + +private: + // Base case for PackArguments when there are no arguments to pack. + void PackArguments(size_t) {} + + // Induction step for PackArguments. + template + void PackArguments(size_t Index, const T &Argument, + const RemainingParameterTs &... RemainingArguments) { + PackOneArgument(Index, Argument); + PackArguments(Index + 1, RemainingArguments...); + } + + // Pack a normal, non-device-memory argument. + template void PackOneArgument(size_t Index, const T &Argument) { + Addresses[Index] = &Argument; + Sizes[Index] = sizeof(T); + Types[Index] = KernelArgumentType::VALUE; + } + + // Pack a GlobalDeviceMemoryBase argument. + void PackOneArgument(size_t Index, const GlobalDeviceMemoryBase &Argument) { + Addresses[Index] = Argument.getHandle(); + Sizes[Index] = sizeof(void *); + Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; + } + + // Pack a GlobalDeviceMemoryBase pointer argument. + void PackOneArgument(size_t Index, GlobalDeviceMemoryBase *Argument) { + Addresses[Index] = Argument->getHandle(); + Sizes[Index] = sizeof(void *); + Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; + } + + // Pack a const GlobalDeviceMemoryBase pointer argument. 
+ void PackOneArgument(size_t Index, const GlobalDeviceMemoryBase *Argument) { + Addresses[Index] = Argument->getHandle(); + Sizes[Index] = sizeof(void *); + Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; + } + + // Pack a GlobalDeviceMemory argument. + template + void PackOneArgument(size_t Index, const GlobalDeviceMemory &Argument) { + Addresses[Index] = Argument.getHandle(); + Sizes[Index] = sizeof(void *); + Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; + } + + // Pack a GlobalDeviceMemory pointer argument. + template + void PackOneArgument(size_t Index, GlobalDeviceMemory *Argument) { + Addresses[Index] = Argument->getHandle(); + Sizes[Index] = sizeof(void *); + Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; + } + + // Pack a const GlobalDeviceMemory pointer argument. + template + void PackOneArgument(size_t Index, const GlobalDeviceMemory *Argument) { + Addresses[Index] = Argument->getHandle(); + Sizes[Index] = sizeof(void *); + Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; + } + + // Pack a SharedDeviceMemoryBase argument. + void PackOneArgument(size_t Index, const SharedDeviceMemoryBase &Argument) { + ++SharedCount; + Addresses[Index] = nullptr; + Sizes[Index] = Argument.getByteCount(); + Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; + } + + // Pack a SharedDeviceMemoryBase pointer argument. + void PackOneArgument(size_t Index, SharedDeviceMemoryBase *Argument) { + ++SharedCount; + Addresses[Index] = nullptr; + Sizes[Index] = Argument->getByteCount(); + Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; + } + + // Pack a const SharedDeviceMemoryBase pointer argument. + void PackOneArgument(size_t Index, const SharedDeviceMemoryBase *Argument) { + ++SharedCount; + Addresses[Index] = nullptr; + Sizes[Index] = Argument->getByteCount(); + Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; + } + + // Pack a SharedDeviceMemory argument. 
+ template + void PackOneArgument(size_t Index, const SharedDeviceMemory &Argument) { + ++SharedCount; + Addresses[Index] = nullptr; + Sizes[Index] = Argument.getByteCount(); + Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; + } + + // Pack a SharedDeviceMemory pointer argument. + template + void PackOneArgument(size_t Index, SharedDeviceMemory *Argument) { + ++SharedCount; + Addresses[Index] = nullptr; + Sizes[Index] = Argument->getByteCount(); + Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; + } + + // Pack a const SharedDeviceMemory pointer argument. + template + void PackOneArgument(size_t Index, const SharedDeviceMemory *Argument) { + ++SharedCount; + Addresses[Index] = nullptr; + Sizes[Index] = Argument->getByteCount(); + Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; + } + + std::array Addresses; + std::array Sizes; + std::array Types; + size_t SharedCount; +}; + +// Utility template function to call the PackedKernelArgumentArray constructor +// with the template arguments matching the types of the arguments passed to +// this function. +template +PackedKernelArgumentArray +make_kernel_argument_pack(const ParameterTs &... 
Arguments) { + return PackedKernelArgumentArray(Arguments...); +} + +} // namespace streamexecutor + +#endif // STREAMEXECUTOR_PACKEDKERNELARGUMENTARRAY_H Index: parallel-libs/trunk/streamexecutor/lib/unittests/CMakeLists.txt =================================================================== --- parallel-libs/trunk/streamexecutor/lib/unittests/CMakeLists.txt +++ parallel-libs/trunk/streamexecutor/lib/unittests/CMakeLists.txt @@ -17,3 +17,13 @@ ${GTEST_BOTH_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) add_test(KernelSpecTest kernel_spec_test) + +add_executable( + packed_kernel_argument_array_test + PackedKernelArgumentArrayTest.cpp) +target_link_libraries( + packed_kernel_argument_array_test + ${llvm_libs} + ${GTEST_BOTH_LIBRARIES} + ${CMAKE_THREAD_LIBS_INIT}) +add_test(PackedKernelArgumentArrayTest packed_kernel_argument_array_test) Index: parallel-libs/trunk/streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp =================================================================== --- parallel-libs/trunk/streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp +++ parallel-libs/trunk/streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp @@ -0,0 +1,202 @@ +//===-- PackedKernelArgumentArrayTest.cpp - tests for kernel arg packing --===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Unit tests for kernel argument packing. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/DeviceMemory.h" +#include "streamexecutor/PackedKernelArgumentArray.h" + +#include "llvm/ADT/Twine.h" + +#include "gtest/gtest.h" + +namespace { + +namespace se = ::streamexecutor; + +using Type = se::KernelArgumentType; + +// Test fixture class for testing argument packing. 
+// +// Defines one object of each kind of argument to be packed so they don't have +// to be defined separately in each test. +class DeviceMemoryPackingTest : public ::testing::Test { +public: + DeviceMemoryPackingTest() + : Value(42), Handle(&Value), ByteCount(15), ElementCount(5), + UntypedGlobal(Handle, ByteCount), + TypedGlobal(se::GlobalDeviceMemory::makeFromElementCount( + Handle, ElementCount)), + UntypedShared(ByteCount), + TypedShared( + se::SharedDeviceMemory::makeFromElementCount(ElementCount)) {} + + int Value; + void *Handle; + size_t ByteCount; + size_t ElementCount; + se::GlobalDeviceMemoryBase UntypedGlobal; + se::GlobalDeviceMemory TypedGlobal; + se::SharedDeviceMemoryBase UntypedShared; + se::SharedDeviceMemory TypedShared; +}; + +// Utility function to check the expected address, size, and type for a packed +// argument at the given index of a PackedKernelArgumentArray. +template +static void +ExpectEqual(const void *ExpectedAddress, size_t ExpectedSize, Type ExpectedType, + const se::PackedKernelArgumentArray &Observed, + size_t Index) { + SCOPED_TRACE(("Index = " + llvm::Twine(Index)).str()); + EXPECT_EQ(ExpectedAddress, Observed.getAddress(Index)); + EXPECT_EQ(ExpectedAddress, Observed.getAddresses()[Index]); + EXPECT_EQ(ExpectedSize, Observed.getSize(Index)); + EXPECT_EQ(ExpectedSize, Observed.getSizes()[Index]); + EXPECT_EQ(ExpectedType, Observed.getType(Index)); + EXPECT_EQ(ExpectedType, Observed.getTypes()[Index]); +} + +TEST_F(DeviceMemoryPackingTest, SingleValue) { + auto Array = se::make_kernel_argument_pack(Value); + ExpectEqual(&Value, sizeof(Value), Type::VALUE, Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(0u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleUntypedGlobal) { + auto Array = se::make_kernel_argument_pack(UntypedGlobal); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(0u, Array.getSharedCount()); +} + 
+TEST_F(DeviceMemoryPackingTest, SingleUntypedGlobalPointer) { + auto Array = se::make_kernel_argument_pack(&UntypedGlobal); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(0u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleConstUntypedGlobalPointer) { + const se::GlobalDeviceMemoryBase *ConstPointer = &UntypedGlobal; + auto Array = se::make_kernel_argument_pack(ConstPointer); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(0u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleTypedGlobal) { + auto Array = se::make_kernel_argument_pack(TypedGlobal); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(0u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleTypedGlobalPointer) { + auto Array = se::make_kernel_argument_pack(&TypedGlobal); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(0u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleConstTypedGlobalPointer) { + const se::GlobalDeviceMemory *ArgumentPointer = &TypedGlobal; + auto Array = se::make_kernel_argument_pack(ArgumentPointer); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(0u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleUntypedShared) { + auto Array = se::make_kernel_argument_pack(UntypedShared); + ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(1u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleUntypedSharedPointer) { + auto Array = 
se::make_kernel_argument_pack(&UntypedShared); + ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(1u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleConstUntypedSharedPointer) { + const se::SharedDeviceMemoryBase *ArgumentPointer = &UntypedShared; + auto Array = se::make_kernel_argument_pack(ArgumentPointer); + ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(1u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleTypedShared) { + auto Array = se::make_kernel_argument_pack(TypedShared); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(1u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleTypedSharedPointer) { + auto Array = se::make_kernel_argument_pack(&TypedShared); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(1u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, SingleConstTypedSharedPointer) { + const se::SharedDeviceMemory *ArgumentPointer = &TypedShared; + auto Array = se::make_kernel_argument_pack(ArgumentPointer); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); + EXPECT_EQ(1u, Array.getArgumentCount()); + EXPECT_EQ(1u, Array.getSharedCount()); +} + +TEST_F(DeviceMemoryPackingTest, PackSeveralArguments) { + const se::GlobalDeviceMemoryBase *UntypedGlobalPointer = &UntypedGlobal; + const se::GlobalDeviceMemory *TypedGlobalPointer = &TypedGlobal; + const se::SharedDeviceMemoryBase *UntypedSharedPointer = &UntypedShared; + const se::SharedDeviceMemory *TypedSharedPointer = &TypedShared; + auto Array = se::make_kernel_argument_pack( + Value, UntypedGlobal, 
&UntypedGlobal, UntypedGlobalPointer, TypedGlobal, + &TypedGlobal, TypedGlobalPointer, UntypedShared, &UntypedShared, + UntypedSharedPointer, TypedShared, &TypedShared, TypedSharedPointer); + ExpectEqual(&Value, sizeof(Value), Type::VALUE, Array, 0); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 1); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 2); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 3); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 4); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 5); + ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 6); + ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 7); + ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 8); + ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 9); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 10); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 11); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 12); + EXPECT_EQ(13u, Array.getArgumentCount()); + EXPECT_EQ(6u, Array.getSharedCount()); +} + +} // namespace