diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -84,6 +84,30 @@ return success(); } +/// Returns true if types in the given `storageClass` need explicit layout +/// (e.g. member offsets) when used in Shader environments. +static bool needsExplicitLayout(spirv::StorageClass storageClass) { + switch (storageClass) { + case spirv::StorageClass::PhysicalStorageBuffer: + case spirv::StorageClass::PushConstant: + case spirv::StorageClass::StorageBuffer: + case spirv::StorageClass::Uniform: + return true; + default: + return false; + } +} + +/// Wraps `elementType` in a struct (laid out explicitly if `storageClass` +/// needs it) and returns a pointer to it, per Vulkan interface requirements. +static spirv::PointerType +wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { + auto structType = needsExplicitLayout(storageClass) + ? spirv::StructType::get(elementType, /*offsetInfo=*/0) + : spirv::StructType::get(elementType); + return spirv::PointerType::get(structType, storageClass); +} + //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// @@ -392,12 +416,7 @@ auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); - // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with - // workgroup storage class do not need the struct to be laid out explicitly. - auto structType = *storageClass == spirv::StorageClass::Workgroup - ?
spirv::StructType::get(arrayType) - : spirv::StructType::get(arrayType, 0); - return spirv::PointerType::get(structType, *storageClass); + return wrapInStructAndGetPointer(arrayType, *storageClass); } static Type convertMemrefType(const spirv::TargetEnv &targetEnv, @@ -452,9 +471,7 @@ if (!type.hasStaticShape()) { auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize); - // Wrap in a struct to satisfy Vulkan interface requirements. - auto structType = spirv::StructType::get(arrayType, 0); - return spirv::PointerType::get(structType, *storageClass); + return wrapInStructAndGetPointer(arrayType, *storageClass); } Optional memrefSize = getTypeNumBytes(options, type); @@ -470,12 +487,7 @@ auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); - // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with - // workgroup storage class do not need the struct to be laid out explicitly. - auto structType = *storageClass == spirv::StorageClass::Workgroup - ? 
spirv::StructType::get(arrayType) - : spirv::StructType::get(arrayType, 0); - return spirv::PointerType::get(structType, *storageClass); + return wrapInStructAndGetPointer(arrayType, *storageClass); } SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, diff --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir --- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir @@ -9,7 +9,7 @@ // CHECK: spv.func // CHECK-SAME: {{%.*}}: f32 // CHECK-NOT: spv.interface_var_abi - // CHECK-SAME: {{%.*}}: !spv.ptr [0])>, CrossWorkgroup> + // CHECK-SAME: {{%.*}}: !spv.ptr)>, CrossWorkgroup> // CHECK-NOT: spv.interface_var_abi // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -337,13 +337,13 @@ func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0])>, Input> +// CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: func @memref_16bit_Input // NOEMU-SAME: memref<16xf16, 9> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0])>, Output> +// CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: func @memref_16bit_Output // NOEMU-SAME: memref<16xf16, 10> func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return } @@ -451,15 +451,15 @@ } { // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0])>, Input> +// CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: 
spv.func @memref_16bit_Input -// NOEMU-SAME: !spv.ptr [0])>, Input> +// NOEMU-SAME: !spv.ptr)>, Input> func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0])>, Output> +// CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: spv.func @memref_16bit_Output -// NOEMU-SAME: !spv.ptr [0])>, Output> +// NOEMU-SAME: !spv.ptr)>, Output> func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return } } // end module @@ -563,13 +563,13 @@ func @memref_16bit_PushConstant(%arg0: memref) { return } // CHECK-LABEL: spv.func @memref_16bit_Input -// CHECK-SAME: !spv.ptr [0])>, Input> +// CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: func @memref_16bit_Input // NOEMU-SAME: memref func @memref_16bit_Input(%arg3: memref) { return } // CHECK-LABEL: spv.func @memref_16bit_Output -// CHECK-SAME: !spv.ptr [0])>, Output> +// CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: func @memref_16bit_Output // NOEMU-SAME: memref func @memref_16bit_Output(%arg4: memref) { return }