diff --git a/libc/src/__support/GPU/amdgpu/utils.h b/libc/src/__support/GPU/amdgpu/utils.h
--- a/libc/src/__support/GPU/amdgpu/utils.h
+++ b/libc/src/__support/GPU/amdgpu/utils.h
@@ -20,6 +20,13 @@
 /// The number of threads that execute in lock-step in a lane.
 constexpr const uint64_t LANE_SIZE = __AMDGCN_WAVEFRONT_SIZE;
 
+/// Type aliases to the address spaces used by the AMDGPU backend. We use
+/// 'Shared' instead of 'Local' to maintain consistency with NVPTX.
+template <typename T> using Private = [[clang::address_space(5)]] T;
+template <typename T> using Constant = [[clang::address_space(4)]] T;
+template <typename T> using Shared = [[clang::address_space(3)]] T;
+template <typename T> using Global = [[clang::address_space(1)]] T;
+
 /// Returns the number of workgroups in the 'x' dimension of the grid.
 LIBC_INLINE uint32_t get_num_blocks_x() {
   return __builtin_amdgcn_grid_size_x() / __builtin_amdgcn_workgroup_size_x();
diff --git a/libc/src/__support/GPU/generic/utils.h b/libc/src/__support/GPU/generic/utils.h
--- a/libc/src/__support/GPU/generic/utils.h
+++ b/libc/src/__support/GPU/generic/utils.h
@@ -18,6 +18,12 @@
 
 constexpr const uint64_t LANE_SIZE = 1;
 
+/// Address space type aliases; in the generic build each one is simply 'T'.
+template <typename T> using Private = T;
+template <typename T> using Constant = T;
+template <typename T> using Shared = T;
+template <typename T> using Global = T;
+
 LIBC_INLINE uint32_t get_num_blocks_x() { return 1; }
 
 LIBC_INLINE uint32_t get_num_blocks_y() { return 1; }
diff --git a/libc/src/__support/GPU/nvptx/utils.h b/libc/src/__support/GPU/nvptx/utils.h
--- a/libc/src/__support/GPU/nvptx/utils.h
+++ b/libc/src/__support/GPU/nvptx/utils.h
@@ -19,6 +19,12 @@
 /// The number of threads that execute in lock-step in a warp.
 constexpr const uint64_t LANE_SIZE = 32;
 
+/// Type aliases to the address spaces used by the NVPTX backend.
+template <typename T> using Private = [[clang::address_space(5)]] T;
+template <typename T> using Constant = [[clang::address_space(4)]] T;
+template <typename T> using Shared = [[clang::address_space(3)]] T;
+template <typename T> using Global = [[clang::address_space(1)]] T;
+
 /// Returns the number of CUDA blocks in the 'x' dimension.
 LIBC_INLINE uint32_t get_num_blocks_x() {
   return __nvvm_read_ptx_sreg_nctaid_x();