diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -473,9 +473,8 @@ type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); - assert( - (!type.isScalable() || (type.getRank() == 1)) && - "expected 1-D scalable vector (n-D scalable vectors are not supported)"); + assert(!llvm::is_contained(type.getScalableDims().drop_back(), true) && + "only trailing dimension can be scalable"); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2214,3 +2214,13 @@ %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32> return %0 : vector<8xf32> } + +// ----- + +// CHECK-LABEL: @make_fixed_vector_of_scalable_vector +func.func @make_fixed_vector_of_scalable_vector(%f : f64) -> vector<3x[2]xf64> +{ + // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.array<3 x vector<[2]xf64>> + %res = vector.broadcast %f : f64 to vector<3x[2]xf64> + return %res : vector<3x[2]xf64> +}