diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
@@ -34,6 +34,9 @@
 
   bool useBarePtrCallConv = false;
   bool useOpaquePointers = true;
+  /// When false, LLVMTypeConverter maps `vector` types to themselves instead
+  /// of converting them via LLVMTypeConverter::convertVectorType.
+  bool enableVectorConversionPatterns = true;
 
   enum class AllocLowering {
     /// Use malloc for for heap allocations.
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -135,6 +135,12 @@
   /// Returns true if using opaque pointers was enabled in the lowering options.
   bool useOpaquePointers() const { return getOptions().useOpaquePointers; }
 
+  /// Returns true if `vector` types are converted via convertVectorType, i.e.
+  /// if vector conversion was enabled in the lowering options.
+  bool enableVectorConversionPatterns() const {
+    return getOptions().enableVectorConversionPatterns;
+  }
+
   /// Creates an LLVM pointer type with the given element type and address
   /// space.
   /// This function is meant to be used in code supporting both typed and opaque
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -37,7 +37,12 @@
   addConversion([&](MemRefType type) { return convertMemRefType(type); });
   addConversion(
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
-  addConversion([&](VectorType type) { return convertVectorType(type); });
+  // With vector conversion disabled, `vector` types are passed through as-is
+  // (identity conversion) instead of being lowered by convertVectorType.
+  if (enableVectorConversionPatterns())
+    addConversion([&](VectorType type) { return convertVectorType(type); });
+  else
+    addConversion([](VectorType type) { return type; });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -457,6 +462,9 @@
 ///  * 1-D `vector<axT>` remains as is while,
 ///  * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
 ///    `!llvm.array<ax...array<jxvector<kxT>>>`.
+/// Scalable vectors are currently only supported when they are 1-D: for an
+/// n-D scalable vector (n > 1) a null type is returned, i.e. the conversion
+/// fails. This could be relaxed in the future if there's a use case.
 Type LLVMTypeConverter::convertVectorType(VectorType type) {
   auto elementType = convertType(type.getElementType());
   if (!elementType)
@@ -467,8 +475,10 @@
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
-  assert((type.isScalable() == type.allDimsScalable()) &&
-         "expected scalable vector with all dims scalable");
+  // n-D scalable vectors (n > 1) are not supported. Fail the conversion rather
+  // than asserting, so release builds do not silently emit a wrong type.
+  if (type.isScalable() && type.getShape().size() > 1)
+    return {};
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -97,6 +97,16 @@
   target.addLegalDialect<memref::MemRefDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
 
+  // ArmSME supports 2-D scalable vectors, e.g. `vector<[16]x[16]xi8>`, which
+  // LLVM has no type for, so they must not reach convertVectorType. The ArmSME
+  // legalization patterns therefore get a type converter that leaves `vector`
+  // types untouched. A copy of `options` is used so that any other use of
+  // `options` in this function is unaffected.
+  LowerToLLVMOptions optionsWithoutVectorPatterns = options;
+  optionsWithoutVectorPatterns.enableVectorConversionPatterns = false;
+  LLVMTypeConverter converterWithoutVectorPatterns(
+      &getContext(), optionsWithoutVectorPatterns);
+
   if (armNeon) {
     // TODO: we may or may not want to include in-dialect lowering to
     // LLVM-compatible operations here. So far, all operations in the dialect
@@ -109,7 +119,8 @@
   }
   if (armSME) {
     configureArmSMELegalizeForExportTarget(target);
-    populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+    populateArmSMELegalizeForLLVMExportPatterns(converterWithoutVectorPatterns,
+                                                patterns);
   }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);