diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -457,6 +457,8 @@ /// * 1-D `vector<axT>` remains as is while, /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to /// `!llvm.array<ax...array<jxvector<kxT>>>`. +/// ATM, scalable vectors are assumed to be always 1-D. This could be relaxed +/// in the future if there's a use case. Type LLVMTypeConverter::convertVectorType(VectorType type) { auto elementType = convertType(type.getElementType()); if (!elementType) @@ -467,8 +469,9 @@ type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); - assert((type.isScalable() == type.allDimsScalable()) && - "expected scalable vector with all dims scalable"); + assert( + (!type.isScalable() || (type.getShape().size() == 1)) && + "expected 1-D scalable vector (n-D scalable vectors are not supported)"); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);