diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -772,8 +772,12 @@ ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \ capabilities.push_back(ref); \ } \ - } break + /* No requirements for other bitwidths */ \ + return; \ + } + // This part only handles the cases where special bitwidths appear in + // interface storage classes. if (storage) { switch (*storage) { STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16); @@ -782,17 +786,17 @@ STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess, StorageUniform16); case StorageClass::Input: - case StorageClass::Output: + case StorageClass::Output: { if (bitwidth == 16) { static const Capability caps[] = {Capability::StorageInputOutput16}; ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); capabilities.push_back(ref); } - break; + return; + } default: break; } - return; } #undef STORAGE_CASE diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -32,25 +32,34 @@ // ----- -// TODO: Uncomment this test when the extension handling correctly -// converts an i16 type to i32 type and handles the load/stores -// correctly.
- -// module attributes { -// spv.target_env = #spv.target_env< -// #spv.vce, -// {max_compute_workgroup_invocations = 128 : i32, -// max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> -// } -// { -// func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) { -// %0 = alloc() : memref<4x5xi16, 3> -// %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3> -// store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3> -// dealloc %0 : memref<4x5xi16, 3> -// return -// } -// } +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> + } +{ + func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<4x5xi16, 3> + %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3> + store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3> + dealloc %0 : memref<4x5xi16, 3> + return + } +} + +// CHECK: spv.globalVariable @__workgroup_mem__{{[0-9]+}} +// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK-LABEL: spv.func @alloc_dealloc_workgroup_mem +// CHECK: %[[VAR:.+]] = spv._address_of @__workgroup_mem__0 +// CHECK: %[[LOC:.+]] = spv.SDiv +// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]] +// CHECK: %{{.+}} = spv.Load "Workgroup" %[[PTR]] : i32 +// CHECK: %[[LOC:.+]] = spv.SDiv +// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]] +// CHECK: %{{.+}} = spv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr +// CHECK: %{{.+}} = spv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr + // -----