diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -51,6 +51,9 @@ { for (auto argType : enumerate(funcOp.getType().getInputs())) { auto convertedType = typeConverter.convertType(argType.value()); + if (!convertedType) { + return matchFailure(); + } signatureConverter.addInputs(argType.index(), convertedType); } } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -41,10 +41,18 @@ // TODO(ravishankarm): This is a utility function that should probably be // exposed by the SPIR-V dialect. Keeping it local till the use case arises. static Optional getTypeNumBytes(Type t) { - if (auto integerType = t.dyn_cast()) { - return integerType.getWidth() / 8; - } else if (auto floatType = t.dyn_cast()) { - return floatType.getWidth() / 8; + if (spirv::SPIRVDialect::isValidScalarType(t)) { + auto bitWidth = t.getIntOrFloatBitWidth(); + // According to the SPIR-V spec: + // "There is no physical size or bit pattern defined for values with boolean + // type. If they are stored (in conjunction with OpVariable), they can only + // be used with logical addressing operations, not physical, and only with + // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, + // Private, Function, Input, and Output." + if (bitWidth == 1) { + return llvm::None; + } + return bitWidth / 8; } else if (auto memRefType = t.dyn_cast()) { // TODO: Layout should also be controlled by the ABI attributes. For now // using the layout from MemRef. diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -289,3 +289,12 @@ %0 = std.sitofp %arg0 : i32 to f32 return } + +//===----------------------------------------------------------------------===// +// memref type +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) { +func @memref_type(%arg0: memref<3xi1>) { + return +}