[mlir][spirv] Convert one-element vectors through the scalar type converter.

A one-element vector now yields the *converted* scalar type (e.g. i8 -> i32
when 8-bit storage is unavailable) instead of the raw element type. The
element type is looked up with dyn_cast so that element types that are not
SPIR-V scalars fail the conversion (nullptr) instead of tripping the cast<>
assertion, which now runs before the CompositeType::isValid() guard.

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
@@ -239,8 +240,13 @@
                               const SPIRVConversionOptions &options,
                               VectorType type,
                               Optional<spirv::StorageClass> storageClass = {}) {
+  // The element type is inspected before the CompositeType::isValid() guard
+  // below, so it may not be a SPIR-V scalar; fail the conversion in that case.
+  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
+  if (!scalarType)
+    return nullptr;
   if (type.getRank() <= 1 && type.getNumElements() == 1)
-    return type.getElementType();
+    return convertScalarType(targetEnv, options, scalarType, storageClass);
 
   if (!spirv::CompositeType::isValid(type)) {
     // TODO: Vector types with more than four elements can be translated into
@@ -260,9 +266,8 @@
       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
     return type;
 
-  auto elementType = convertScalarType(
-      targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
-      storageClass);
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
   if (elementType)
     return VectorType::get(type.getShape(), elementType);
   return nullptr;
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -207,6 +207,10 @@
   %arg1: vector<3xf64>
 ) { return }
 
+// CHECK-LABEL: spirv.func @one_element_vector
+// CHECK-SAME: %{{.+}}: i32
+func.func @one_element_vector(%arg0: vector<1xi8>) { return }
+
 } // end module
 
 // -----