Index: parallel-libs/trunk/streamexecutor/include/streamexecutor/DeviceMemory.h =================================================================== --- parallel-libs/trunk/streamexecutor/include/streamexecutor/DeviceMemory.h +++ parallel-libs/trunk/streamexecutor/include/streamexecutor/DeviceMemory.h @@ -71,6 +71,9 @@ /// Gets the number of elements in this slice. size_t getElementCount() const { return ElementCount; } + /// Returns the number of bytes that can fit in this slice. + size_t getByteCount() const { return ElementCount * sizeof(ElemT); } + /// Creates a slice of the memory with the first DropCount elements removed. GlobalDeviceMemorySlice drop_front(size_t DropCount) const { assert(DropCount <= ElementCount && @@ -175,6 +178,9 @@ /// allocation. size_t getElementCount() const { return ByteCount / sizeof(ElemT); } + /// Returns the number of bytes that can fit in this memory buffer. + size_t getByteCount() const { return ByteCount; } + /// Converts this memory object into a slice. GlobalDeviceMemorySlice asSlice() const { return GlobalDeviceMemorySlice(*this); @@ -224,10 +230,13 @@ /// Copy-assignable because it is just an array size. SharedDeviceMemory &operator=(const SharedDeviceMemory &) = default; - /// Returns the number of elements of type ElemT that can fit this memory + /// Returns the number of elements of type ElemT that can fit in this memory /// buffer. size_t getElementCount() const { return ElementCount; } + /// Returns the number of bytes that can fit in this memory buffer. + size_t getByteCount() const { return ElementCount * sizeof(ElemT); } + /// Returns whether this is a single-element memory buffer. bool isScalar() const { return getElementCount() == 1; } Index: parallel-libs/trunk/streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp =================================================================== --- parallel-libs/trunk/streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp +++ parallel-libs/trunk/streamexecutor/lib/unittests/PackedKernelArgumentArrayTest.cpp @@ -101,16 +101,16 @@ TEST_F(DeviceMemoryPackingTest, SingleTypedShared) { auto Array = se::make_kernel_argument_pack(TypedShared); - ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), - Type::SHARED_DEVICE_MEMORY, Array, 0); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(1u, Array.getSharedCount()); } TEST_F(DeviceMemoryPackingTest, SingleTypedSharedPointer) { auto Array = se::make_kernel_argument_pack(&TypedShared); - ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), - Type::SHARED_DEVICE_MEMORY, Array, 0); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(1u, Array.getSharedCount()); } @@ -118,8 +118,8 @@ TEST_F(DeviceMemoryPackingTest, SingleConstTypedSharedPointer) { const se::SharedDeviceMemory *ArgumentPointer = &TypedShared; auto Array = se::make_kernel_argument_pack(ArgumentPointer); - ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), - Type::SHARED_DEVICE_MEMORY, Array, 0); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(1u, Array.getSharedCount()); } @@ -137,12 +137,12 @@ Type::GLOBAL_DEVICE_MEMORY, Array, 2); ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 3); - ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), - Type::SHARED_DEVICE_MEMORY, Array, 4); - ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), - Type::SHARED_DEVICE_MEMORY, Array, 5); - ExpectEqual(nullptr, TypedShared.getElementCount() * sizeof(int), - Type::SHARED_DEVICE_MEMORY, Array, 6); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 4); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 5); + ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, + Array, 6); EXPECT_EQ(7u, Array.getArgumentCount()); EXPECT_EQ(3u, Array.getSharedCount()); }