Index: mlir/include/mlir/Conversion/Passes.td =================================================================== --- mlir/include/mlir/Conversion/Passes.td +++ mlir/include/mlir/Conversion/Passes.td @@ -1092,6 +1092,10 @@ "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, + Option<"armSME", "enable-arm-sme", + "bool", /*default=*/"false", + "Enables the use of ArmSME dialect while lowering the vector " + "dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " Index: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td =================================================================== --- mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -119,4 +119,7 @@ def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">; def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">; +def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">; +def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">; + #endif // ARMSME_OPS Index: mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h @@ -0,0 +1,29 @@ +//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H +#define MLIR_DIALECT_ARMSME_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; + +/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM +/// intrinsics. +void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering ArmSME ops to ops that map to LLVM +/// intrinsics. +void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H Index: mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt =================================================================== --- mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -15,6 +15,8 @@ LINK_LIBS PUBLIC MLIRArithDialect MLIRArmNeonDialect + MLIRArmSMEDialect + MLIRArmSMETransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect Index: mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp =================================================================== --- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,8 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -49,6 +51,8 @@ registry.insert(); if (armSVE) registry.insert(); + if (armSME) + registry.insert(); if (amx) registry.insert(); if (x86Vector) @@ -102,6 +106,10 @@ configureArmSVELegalizeForExportTarget(target); populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } + if (armSME) { + configureArmSMELegalizeForExportTarget(target); + populateArmSMELegalizeForLLVMExportPatterns(converter, patterns); + } if (amx) { configureAMXLegalizeForExportTarget(target); populateAMXLegalizeForLLVMExportPatterns(converter, patterns); Index: mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms EnableArmStreaming.cpp + LegalizeForLLVMExport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms @@ -8,6 +9,8 @@ MLIRArmSMETransformsIncGen LINK_LIBS PUBLIC + MLIRArmSMEDialect MLIRFuncDialect + MLIRLLVMCommonConversion MLIRPass ) Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,74 @@ +//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +namespace { +/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func' +/// ops to enable the ZA storage array. +struct EnableZAPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const final { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointToStart(&op.front()); + rewriter.create(op->getLoc()); + rewriter.updateRootInPlace(op, [] {}); + return success(); + } +}; + +/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to +/// disable the ZA storage array. +struct DisableZAPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(func::ReturnOp op, + PatternRewriter &rewriter) const final { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + rewriter.create(op->getLoc()); + rewriter.updateRootInPlace(op, [] {}); + return success(); + } +}; +} // namespace + +void mlir::populateArmSMELegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +void mlir::configureArmSMELegalizeForExportTarget( + LLVMConversionTarget &target) { + target.addLegalOp(); + + // Enable ZA if the function has the 'arm_za' attribute and it hasn't already + // been enabled. + target.addDynamicallyLegalOp([&](func::FuncOp funcOp) { + auto firstOp = funcOp.getBody().front().begin(); + return !funcOp->hasAttr("arm_za") || + isa(firstOp); + }); + + // Disable ZA if the function has the 'arm_za' attribute and it hasn't + // already been disabled. + target.addDynamicallyLegalOp([&](func::ReturnOp returnOp) { + bool hasDisableZA = false; + auto funcOp = returnOp->getParentOp(); + funcOp->walk( + [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; }); + return !funcOp->hasAttr("arm_za") || hasDisableZA; + }); +} Index: mlir/test/Dialect/ArmSME/enable-arm-za.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/ArmSME/enable-arm-za.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s + +// CHECK-LABEL: @arm_za +func.func @arm_za() { + // CHECK: arm_sme.intr.za.enable + // CHECK-NEXT: arm_sme.intr.za.disable + // CHECK-NEXT: return + return +} Index: mlir/test/Target/LLVMIR/arm-sme.mlir =================================================================== --- mlir/test/Target/LLVMIR/arm-sme.mlir +++ mlir/test/Target/LLVMIR/arm-sme.mlir @@ -223,3 +223,14 @@ (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () llvm.return } + +// ----- + +// CHECK-LABEL: @arm_sme_toggle_za +llvm.func @arm_sme_toggle_za() { + // CHECK: call void @llvm.aarch64.sme.za.enable() + "arm_sme.intr.za.enable"() : () -> () + // CHECK: call void @llvm.aarch64.sme.za.disable() + "arm_sme.intr.za.disable"() : () -> () + llvm.return +}