diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -41,10 +41,9 @@ // TODO(ravishankarm): This is a utility function that should probably be // exposed by the SPIR-V dialect. Keeping it local till the use case arises. static Optional getTypeNumBytes(Type t) { - if (auto integerType = t.dyn_cast()) { - return integerType.getWidth() / 8; - } else if (auto floatType = t.dyn_cast()) { - return floatType.getWidth() / 8; + if (spirv::SPIRVDialect::isValidScalarType(t)) { + auto bitWidth = t.getIntOrFloatBitWidth(); + return bitWidth == 1 ? 1 : (bitWidth / 8); } else if (auto memRefType = t.dyn_cast()) { // TODO: Layout should also be controlled by the ABI attributes. For now // using the layout from MemRef. diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -289,3 +289,12 @@ %0 = std.sitofp %arg0 : i32 to f32 return } + +//===----------------------------------------------------------------------===// +// memref type +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @memref_type({{%.*}}: !spv.ptr [0]>, {{.*}}> +func @memref_type(%arg0 : memref<3xi1>) { + return +}