diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -239,6 +239,7 @@ if (auto splatAttr = attr.dyn_cast()) { llvm::Type *elementType; uint64_t numElements; + bool isScalable = false; if (auto *arrayTy = dyn_cast(llvmType)) { elementType = arrayTy->getElementType(); numElements = arrayTy->getNumElements(); @@ -248,6 +249,7 @@ } else if (auto *sVectorTy = dyn_cast(llvmType)) { elementType = sVectorTy->getElementType(); numElements = sVectorTy->getMinNumElements(); + isScalable = true; } else { llvm_unreachable("unrecognized constant vector type"); } @@ -265,7 +267,7 @@ return nullptr; if (llvmType->isVectorTy()) return llvm::ConstantVector::getSplat( - llvm::ElementCount::get(numElements, /*Scalable=*/false), child); + llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child); if (llvmType->isArrayTy()) { auto *arrayType = llvm::ArrayType::get(elementType, numElements); SmallVector constants(numElements, child); diff --git a/mlir/test/Target/LLVMIR/llvmir-types.mlir b/mlir/test/Target/LLVMIR/llvmir-types.mlir --- a/mlir/test/Target/LLVMIR/llvmir-types.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-types.mlir @@ -90,6 +90,8 @@ llvm.func @return_v4_i32() -> vector<4xi32> // CHECK: declare <4 x float> @return_v4_float() llvm.func @return_v4_float() -> vector<4xf32> +// CHECK: declare @return_vs_4_float() +llvm.func @return_vs_4_float() -> vector<[4]xf32> // CHECK: declare @return_vs_4_i32() llvm.func @return_vs_4_i32() -> !llvm.vec // CHECK: declare @return_vs_8_half() diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -907,6 +907,13 @@ llvm.return %0 : vector<4xf32> } +// CHECK-LABEL: @vector_splat_1d_scalable +llvm.func @vector_splat_1d_scalable() -> vector<[4]xf32> { + // CHECK: ret zeroinitializer + %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<[4]xf32>) : vector<[4]xf32> + llvm.return %0 : vector<[4]xf32> +} + // CHECK-LABEL: @vector_splat_2d llvm.func @vector_splat_2d() -> !llvm.array<4 x vector<16 x f32>> { // CHECK: ret [4 x <16 x float>] zeroinitializer