diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -440,8 +440,16 @@ return nullptr; } + Optional arrayElemSize = getTypeNumBytes(options, arrayElemType); + if (!arrayElemSize) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot deduce converted element size\n"); + return nullptr; + } + if (!type.hasStaticShape()) { - auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize); + auto arrayType = + spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize); // Wrap in a struct to satisfy Vulkan interface requirements. auto structType = spirv::StructType::get(arrayType, 0); return spirv::PointerType::get(structType, *storageClass); @@ -456,12 +464,6 @@ auto arrayElemCount = *memrefSize / *elementSize; - Optional arrayElemSize = getTypeNumBytes(options, arrayElemType); - if (!arrayElemSize) { - LLVM_DEBUG(llvm::dbgs() - << type << " illegal: cannot deduce converted element size\n"); - return nullptr; - } auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -296,6 +296,8 @@ // An i1 is store in 8-bit, so 5xi1 has 40 bits, which is stored in 2xi32. // CHECK-LABEL: spv.func @memref_1bit_type // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// NOEMU-LABEL: func @memref_1bit_type +// NOEMU-SAME: memref<5xi1> func @memref_1bit_type(%arg0: memref<5xi1>) { return } // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer @@ -509,12 +511,68 @@ // CHECK-SAME: memref<*xi32> func @unranked_memref(%arg0: memref<*xi32>) { return } +// Check that dynamic dims on i1 are not supported. +// CHECK-LABEL: func @memref_1bit_type +// CHECK-SAME: memref +func @memref_1bit_type(%arg0: memref) { return } + // CHECK-LABEL: func @dynamic_dim_memref // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @dynamic_dim_memref(%arg0: memref<8x?xi32>, - %arg1: memref) -{ return } + %arg1: memref) { return } + +// Check that using non-32-bit scalar types in interface storage classes +// requires special capability and extension: convert them to 32-bit if not +// satisfied. + +// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// NOEMU-LABEL: func @memref_8bit_StorageBuffer +// NOEMU-SAME: memref +func @memref_8bit_StorageBuffer(%arg0: memref) { return } + +// CHECK-LABEL: spv.func @memref_8bit_Uniform +// CHECK-SAME: !spv.ptr [0])>, Uniform> +// NOEMU-LABEL: func @memref_8bit_Uniform +// NOEMU-SAME: memref +func @memref_8bit_Uniform(%arg0: memref) { return } + +// CHECK-LABEL: spv.func @memref_8bit_PushConstant +// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// NOEMU-LABEL: func @memref_8bit_PushConstant +// NOEMU-SAME: memref +func @memref_8bit_PushConstant(%arg0: memref) { return } + +// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +// NOEMU-LABEL: func @memref_16bit_StorageBuffer +// NOEMU-SAME: memref +func @memref_16bit_StorageBuffer(%arg0: memref) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Uniform +// CHECK-SAME: !spv.ptr [0])>, Uniform> +// NOEMU-LABEL: func @memref_16bit_Uniform +// NOEMU-SAME: memref +func @memref_16bit_Uniform(%arg0: memref) { return } + +// CHECK-LABEL: spv.func @memref_16bit_PushConstant +// CHECK-SAME: !spv.ptr [0])>, PushConstant> +// NOEMU-LABEL: func @memref_16bit_PushConstant +// NOEMU-SAME: memref +func @memref_16bit_PushConstant(%arg0: memref) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Input +// CHECK-SAME: !spv.ptr [0])>, Input> +// NOEMU-LABEL: func @memref_16bit_Input +// NOEMU-SAME: memref +func @memref_16bit_Input(%arg3: memref) { return } + +// CHECK-LABEL: spv.func @memref_16bit_Output +// CHECK-SAME: !spv.ptr [0])>, Output> +// NOEMU-LABEL: func @memref_16bit_Output +// NOEMU-SAME: memref +func @memref_16bit_Output(%arg4: memref) { return } } // end module