[mlir][LLVMCommon] Allow converting vectors whose only scalable dim is trailing

Before this change, the vector type conversion in
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp returned failure() for every
scalable vector of rank > 1. After it, conversion fails only when a
non-trailing dimension is scalable, detected by searching
type.getScalableDims().drop_back() for `true`. A vector with fixed outer
dimensions and a scalable trailing dimension, e.g. vector<3x[2]xf64>, is now
accepted and is expected to become !llvm.array<3 x vector<[2]xf64>>.

The new VectorToLLVM test covers this case. It broadcasts an f64 to
vector<3x[2]xf64> and checks for an llvm.mlir.undef of that array type.
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -507,7 +507,8 @@ type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); - if (type.isScalable() && (type.getRank() > 1)) + // Only the trailing dimension can be scalable. + if (llvm::is_contained(type.getScalableDims().drop_back(), true)) return failure(); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2260,3 +2260,13 @@ %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32> return %0 : vector<8xf32> } + +// ----- + +// CHECK-LABEL: @make_fixed_vector_of_scalable_vector +func.func @make_fixed_vector_of_scalable_vector(%f : f64) -> vector<3x[2]xf64> +{ + // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.array<3 x vector<[2]xf64>> + %res = vector.broadcast %f : f64 to vector<3x[2]xf64> + return %res : vector<3x[2]xf64> +}