Index: streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h =================================================================== --- streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h +++ streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h @@ -209,36 +209,12 @@ Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; } - // Pack a SharedDeviceMemoryBase argument. - void PackOneArgument(size_t Index, const SharedDeviceMemoryBase &Argument) { - ++SharedCount; - Addresses[Index] = nullptr; - Sizes[Index] = Argument.getByteCount(); - Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; - } - - // Pack a SharedDeviceMemoryBase pointer argument. - void PackOneArgument(size_t Index, SharedDeviceMemoryBase *Argument) { - ++SharedCount; - Addresses[Index] = nullptr; - Sizes[Index] = Argument->getByteCount(); - Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; - } - - // Pack a const SharedDeviceMemoryBase pointer argument. - void PackOneArgument(size_t Index, const SharedDeviceMemoryBase *Argument) { - ++SharedCount; - Addresses[Index] = nullptr; - Sizes[Index] = Argument->getByteCount(); - Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; - } - // Pack a SharedDeviceMemory argument. template void PackOneArgument(size_t Index, const SharedDeviceMemory &Argument) { ++SharedCount; Addresses[Index] = nullptr; - Sizes[Index] = Argument.getByteCount(); + Sizes[Index] = Argument.getElementCount() * sizeof(T); Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; } @@ -247,7 +223,7 @@ void PackOneArgument(size_t Index, SharedDeviceMemory *Argument) { ++SharedCount; Addresses[Index] = nullptr; - Sizes[Index] = Argument->getByteCount(); + Sizes[Index] = Argument->getElementCount() * sizeof(T); Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; } @@ -256,7 +232,7 @@ void PackOneArgument(size_t Index, const SharedDeviceMemory *Argument) { ++SharedCount; Addresses[Index] = nullptr; - Sizes[Index] = Argument->getByteCount(); + Sizes[Index] = Argument->getElementCount() * sizeof(T); Types[Index] = KernelArgumentType::SHARED_DEVICE_MEMORY; } Index: streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp =================================================================== --- streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp +++ streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp @@ -12,8 +12,11 @@ /// //===----------------------------------------------------------------------===// +#include "SimpleHostPlatformDevice.h" +#include "streamexecutor/Device.h" #include "streamexecutor/DeviceMemory.h" #include "streamexecutor/PackedKernelArgumentArray.h" +#include "streamexecutor/PlatformInterfaces.h" #include "llvm/ADT/Twine.h" @@ -32,21 +35,19 @@ class DeviceMemoryPackingTest : public ::testing::Test { public: DeviceMemoryPackingTest() - : Value(42), Handle(&Value), ByteCount(15), ElementCount(5), - UntypedGlobal(Handle, ByteCount), - TypedGlobal(se::GlobalDeviceMemory::makeFromElementCount( - Handle, ElementCount)), - UntypedShared(ByteCount), + : Device(&PDevice), Value(42), Handle(&Value), ByteCount(15), + ElementCount(5), + TypedGlobal(getOrDie(Device.allocateDeviceMemory(ElementCount))), TypedShared( se::SharedDeviceMemory::makeFromElementCount(ElementCount)) {} + se::test::SimpleHostPlatformDevice PDevice; + se::Device Device; int Value; void *Handle; size_t ByteCount; size_t ElementCount; - se::GlobalDeviceMemoryBase UntypedGlobal; se::GlobalDeviceMemory TypedGlobal; - se::SharedDeviceMemoryBase UntypedShared; se::SharedDeviceMemory TypedShared; }; @@ -73,38 +74,18 @@ EXPECT_EQ(0u, Array.getSharedCount()); } -TEST_F(DeviceMemoryPackingTest, SingleUntypedGlobal) { - auto Array = se::make_kernel_argument_pack(UntypedGlobal); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); - EXPECT_EQ(1u, Array.getArgumentCount()); - EXPECT_EQ(0u, Array.getSharedCount()); -} - -TEST_F(DeviceMemoryPackingTest, SingleUntypedGlobalPointer) { - auto Array = se::make_kernel_argument_pack(&UntypedGlobal); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); - EXPECT_EQ(1u, Array.getArgumentCount()); - EXPECT_EQ(0u, Array.getSharedCount()); -} - -TEST_F(DeviceMemoryPackingTest, SingleConstUntypedGlobalPointer) { - const se::GlobalDeviceMemoryBase *ConstPointer = &UntypedGlobal; - auto Array = se::make_kernel_argument_pack(ConstPointer); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); - EXPECT_EQ(1u, Array.getArgumentCount()); - EXPECT_EQ(0u, Array.getSharedCount()); -} - TEST_F(DeviceMemoryPackingTest, SingleTypedGlobal) { auto Array = se::make_kernel_argument_pack(TypedGlobal); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + Type::GLOBAL_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(0u, Array.getSharedCount()); } TEST_F(DeviceMemoryPackingTest, SingleTypedGlobalPointer) { auto Array = se::make_kernel_argument_pack(&TypedGlobal); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + Type::GLOBAL_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(0u, Array.getSharedCount()); } @@ -112,48 +93,24 @@ TEST_F(DeviceMemoryPackingTest, SingleConstTypedGlobalPointer) { const se::GlobalDeviceMemory *ArgumentPointer = &TypedGlobal; auto Array = se::make_kernel_argument_pack(ArgumentPointer); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); + ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + Type::GLOBAL_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(0u, Array.getSharedCount()); } -TEST_F(DeviceMemoryPackingTest, SingleUntypedShared) { - auto Array = se::make_kernel_argument_pack(UntypedShared); - ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 0); - EXPECT_EQ(1u, Array.getArgumentCount()); - EXPECT_EQ(1u, Array.getSharedCount()); -} - -TEST_F(DeviceMemoryPackingTest, SingleUntypedSharedPointer) { - auto Array = se::make_kernel_argument_pack(&UntypedShared); - ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 0); - EXPECT_EQ(1u, Array.getArgumentCount()); - EXPECT_EQ(1u, Array.getSharedCount()); -} - -TEST_F(DeviceMemoryPackingTest, SingleConstUntypedSharedPointer) { - const se::SharedDeviceMemoryBase *ArgumentPointer = &UntypedShared; - auto Array = se::make_kernel_argument_pack(ArgumentPointer); - ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 0); - EXPECT_EQ(1u, Array.getArgumentCount()); - EXPECT_EQ(1u, Array.getSharedCount()); -} - TEST_F(DeviceMemoryPackingTest, SingleTypedShared) { auto Array = se::make_kernel_argument_pack(TypedShared); - ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 0); + ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), + Type::SHARED_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(1u, Array.getSharedCount()); } TEST_F(DeviceMemoryPackingTest, SingleTypedSharedPointer) { auto Array = se::make_kernel_argument_pack(&TypedShared); - ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 0); + ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), + Type::SHARED_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(1u, Array.getSharedCount()); } @@ -161,42 +118,33 @@ TEST_F(DeviceMemoryPackingTest, SingleConstTypedSharedPointer) { const se::SharedDeviceMemory *ArgumentPointer = &TypedShared; auto Array = se::make_kernel_argument_pack(ArgumentPointer); - ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 0); + ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), + Type::SHARED_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(1u, Array.getSharedCount()); } TEST_F(DeviceMemoryPackingTest, PackSeveralArguments) { - const se::GlobalDeviceMemoryBase *UntypedGlobalPointer = &UntypedGlobal; const se::GlobalDeviceMemory *TypedGlobalPointer = &TypedGlobal; - const se::SharedDeviceMemoryBase *UntypedSharedPointer = &UntypedShared; const se::SharedDeviceMemory *TypedSharedPointer = &TypedShared; - auto Array = se::make_kernel_argument_pack( - Value, UntypedGlobal, &UntypedGlobal, UntypedGlobalPointer, TypedGlobal, - &TypedGlobal, TypedGlobalPointer, UntypedShared, &UntypedShared, - UntypedSharedPointer, TypedShared, &TypedShared, TypedSharedPointer); + auto Array = se::make_kernel_argument_pack(Value, TypedGlobal, &TypedGlobal, + TypedGlobalPointer, TypedShared, + &TypedShared, TypedSharedPointer); ExpectEqual(&Value, sizeof(Value), Type::VALUE, Array, 0); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 1); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 2); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 3); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 4); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 5); - ExpectEqual(Handle, sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 6); - ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 7); - ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 8); - ExpectEqual(nullptr, UntypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 9); - ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 10); - ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 11); - ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, - Array, 12); - EXPECT_EQ(13u, Array.getArgumentCount()); - EXPECT_EQ(6u, Array.getSharedCount()); + ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + Type::GLOBAL_DEVICE_MEMORY, Array, 1); + ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + Type::GLOBAL_DEVICE_MEMORY, Array, 2); + ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + Type::GLOBAL_DEVICE_MEMORY, Array, 3); + ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), + Type::SHARED_DEVICE_MEMORY, Array, 4); + ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), + Type::SHARED_DEVICE_MEMORY, Array, 5); + ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), + Type::SHARED_DEVICE_MEMORY, Array, 6); + EXPECT_EQ(7u, Array.getArgumentCount()); + EXPECT_EQ(3u, Array.getSharedCount()); } } // namespace