diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H #define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -16,6 +17,9 @@ class RewritePatternSet; namespace arm_sme { +//===----------------------------------------------------------------------===// +// The EnableArmStreaming pass. +//===----------------------------------------------------------------------===// // Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of // the function interface (ABI) and the caller manages PSTATE.SM on entry/exit. // In a locally streaming function PSTATE.SM is kept internal and the callee @@ -30,6 +34,18 @@ createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default, const bool enableZA = false); +//===----------------------------------------------------------------------===// +// The ArmSMETypeConverter type converter.
+//===----------------------------------------------------------------------===// +/// Type converter for lowering ArmSME operations to the LLVM dialect. It +/// behaves like LLVMTypeConverter, except that builtin vector types are left +/// unconverted, so that 2-d scalable vectors such as `vector<[16]x[16]xi8>` +/// (LLVM has no arrays of scalable vectors) reach the ArmSME intrinsics. +class ArmSMETypeConverter : public LLVMTypeConverter { +public: + ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options); +}; + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -17,6 +17,7 @@ MLIRArmNeonDialect MLIRArmSMEDialect MLIRArmSMETransforms + MLIRVectorToArmSME MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" @@ -96,6 +97,8 @@ target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); + arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options); + if (armNeon) { // TODO: we may or may not want to include in-dialect lowering to // LLVM-compatible operations here.
So far, all operations in the dialect @@ -108,7 +111,7 @@ } if (armSME) { configureArmSMELegalizeForExportTarget(target); - populateArmSMELegalizeForLLVMExportPatterns(converter, patterns); + populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns); } if (amx) { configureAMXLegalizeForExportTarget(target); diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp @@ -0,0 +1,26 @@ +//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" + +using namespace mlir; + +arm_sme::ArmSMETypeConverter::ArmSMETypeConverter( + MLIRContext *ctx, const LowerToLLVMOptions &options) + : LLVMTypeConverter(ctx, options) { + // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable + // vectors (common in the context of ArmSME), e.g. + // `vector<[16]x[16]xi8>`, + // entering the LLVM Type converter. LLVM does not support arrays of scalable + // vectors, but in the case of SME such types are effectively eliminated when + // emitting ArmSME LLVM IR intrinsics.
+ // + // TypeConverter tries the most recently added conversion first, so this + // identity conversion overrides LLVMTypeConverter's own vector conversion. + addConversion([](VectorType type) { return type; }); +} diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms + ArmSMETypeConverter.cpp EnableArmStreaming.cpp LegalizeForLLVMExport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms