Index: streamexecutor/examples/HostSaxpy.cpp =================================================================== --- streamexecutor/examples/HostSaxpy.cpp +++ streamexecutor/examples/HostSaxpy.cpp @@ -33,8 +33,8 @@ // Wrapper function converts argument addresses to arguments. void SaxpyWrapper(const void *const *ArgumentAddresses) { Saxpy(*static_cast(ArgumentAddresses[0]), - static_cast(const_cast(ArgumentAddresses[1])), - static_cast(const_cast(ArgumentAddresses[2])), + *static_cast(const_cast(ArgumentAddresses[1])), + *static_cast(const_cast(ArgumentAddresses[2])), *static_cast(ArgumentAddresses[3])); } Index: streamexecutor/include/streamexecutor/DeviceMemory.h =================================================================== --- streamexecutor/include/streamexecutor/DeviceMemory.h +++ streamexecutor/include/streamexecutor/DeviceMemory.h @@ -133,6 +133,9 @@ /// Returns an opaque handle to the underlying memory. const void *getHandle() const { return Handle; } + /// Returns the address of the opaque handle as stored by this object. + const void *const *getHandleAddress() const { return &Handle; } + // Cannot copy because the handle must be owned by a single object. GlobalDeviceMemoryBase(const GlobalDeviceMemoryBase &) = delete; GlobalDeviceMemoryBase &operator=(const GlobalDeviceMemoryBase &) = delete; Index: streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h =================================================================== --- streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h +++ streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h @@ -164,31 +164,10 @@ Types[Index] = KernelArgumentType::VALUE; } - // Pack a GlobalDeviceMemoryBase argument. - void PackOneArgument(size_t Index, const GlobalDeviceMemoryBase &Argument) { - Addresses[Index] = Argument.getHandle(); - Sizes[Index] = sizeof(void *); - Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; - } - - // Pack a GlobalDeviceMemoryBase pointer argument. - void PackOneArgument(size_t Index, GlobalDeviceMemoryBase *Argument) { - Addresses[Index] = Argument->getHandle(); - Sizes[Index] = sizeof(void *); - Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; - } - - // Pack a const GlobalDeviceMemoryBase pointer argument. - void PackOneArgument(size_t Index, const GlobalDeviceMemoryBase *Argument) { - Addresses[Index] = Argument->getHandle(); - Sizes[Index] = sizeof(void *); - Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; - } - // Pack a GlobalDeviceMemory argument. template void PackOneArgument(size_t Index, const GlobalDeviceMemory &Argument) { - Addresses[Index] = Argument.getHandle(); + Addresses[Index] = Argument.getHandleAddress(); Sizes[Index] = sizeof(void *); Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; } @@ -196,7 +175,7 @@ // Pack a GlobalDeviceMemory pointer argument. template void PackOneArgument(size_t Index, GlobalDeviceMemory *Argument) { - Addresses[Index] = Argument->getHandle(); + Addresses[Index] = Argument->getHandleAddress(); Sizes[Index] = sizeof(void *); Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; } @@ -204,7 +183,7 @@ // Pack a const GlobalDeviceMemory pointer argument. template void PackOneArgument(size_t Index, const GlobalDeviceMemory *Argument) { - Addresses[Index] = Argument->getHandle(); + Addresses[Index] = Argument->getHandleAddress(); Sizes[Index] = sizeof(void *); Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; } Index: streamexecutor/include/streamexecutor/Stream.h =================================================================== --- streamexecutor/include/streamexecutor/Stream.h +++ streamexecutor/include/streamexecutor/Stream.h @@ -104,6 +104,13 @@ /// SharedDeviceMemory, or they can be primitive types such as int. The /// allowable argument types are determined by the template parameters to the /// Kernel argument. + /// + /// \warning + /// This function passes the addresses of its \p Arguments to the underlying + /// platform launcher. If those addresses become invalidated because another + /// thread touches an argument, this call will fail in strange-looking ways, + /// so be sure that no other threads are touching the arguments to this + /// function until it returns. template Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize, const Kernel &K, Index: streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp =================================================================== --- streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp +++ streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp @@ -76,7 +76,7 @@ TEST_F(DeviceMemoryPackingTest, SingleTypedGlobal) { auto Array = se::make_kernel_argument_pack(TypedGlobal); - ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(0u, Array.getSharedCount()); @@ -84,7 +84,7 @@ TEST_F(DeviceMemoryPackingTest, SingleTypedGlobalPointer) { auto Array = se::make_kernel_argument_pack(&TypedGlobal); - ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(0u, Array.getSharedCount()); @@ -93,7 +93,7 @@ TEST_F(DeviceMemoryPackingTest, SingleConstTypedGlobalPointer) { const se::GlobalDeviceMemory *ArgumentPointer = &TypedGlobal; auto Array = se::make_kernel_argument_pack(ArgumentPointer); - ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 0); EXPECT_EQ(1u, Array.getArgumentCount()); EXPECT_EQ(0u, Array.getSharedCount()); @@ -131,11 +131,11 @@ TypedGlobalPointer, TypedShared, &TypedShared, TypedSharedPointer); ExpectEqual(&Value, sizeof(Value), Type::VALUE, Array, 0); - ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 1); - ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 2); - ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), + ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *), Type::GLOBAL_DEVICE_MEMORY, Array, 3); ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY, Array, 4);