diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1092,6 +1092,10 @@ "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, + Option<"armSME", "enable-arm-sme", + "bool", /*default=*/"false", + "Enables the use of ArmSME dialect while lowering the vector " + "dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -119,4 +119,7 @@ def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">; def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">; +def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">; +def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">; + #endif // ARMSME_OPS diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h @@ -0,0 +1,29 @@ +//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H +#define MLIR_DIALECT_ARMSME_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; + +/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM +/// intrinsics. +void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering ArmSME ops to ops that map to LLVM +/// intrinsics. +void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -15,6 +15,8 @@ LINK_LIBS PUBLIC MLIRArithDialect MLIRArmNeonDialect + MLIRArmSMEDialect + MLIRArmSMETransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,8 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -49,6 +51,8 @@ registry.insert(); if (armSVE) registry.insert(); + if (armSME) + registry.insert(); if (amx) registry.insert(); if (x86Vector) @@ -102,6 +106,10 @@ 
configureArmSVELegalizeForExportTarget(target); populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } + if (armSME) { + configureArmSMELegalizeForExportTarget(target); + populateArmSMELegalizeForLLVMExportPatterns(converter, patterns); + } if (amx) { configureAMXLegalizeForExportTarget(target); populateAMXLegalizeForLLVMExportPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms EnableArmStreaming.cpp + LegalizeForLLVMExport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms @@ -8,6 +9,8 @@ MLIRArmSMETransformsIncGen LINK_LIBS PUBLIC + MLIRArmSMEDialect MLIRFuncDialect + MLIRLLVMCommonConversion MLIRPass ) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,78 @@ +//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +namespace { +/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func' +/// ops to enable the ZA storage array. 
+struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Creates an 'arm_sme::aarch64_sme_za_enable' op as the first op of the
+  /// function's entry block and reports success.
+  ///
+  /// The pattern does not look at the 'arm_za' attribute itself; which
+  /// functions it runs on is decided by the legality callbacks installed in
+  /// 'configureArmSMELegalizeForExportTarget'.
+  LogicalResult matchAndRewrite(func::FuncOp op,
+                                PatternRewriter &rewriter) const final {
+    // A function declaration has no entry block to insert into.
+    if (op.getBody().empty())
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(&op.front());
+    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
+    // The root op itself is left untouched (empty callback); only the new
+    // intrinsic op is added to its body.
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
+/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
+/// disable the ZA storage array.
+struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Creates an 'arm_sme::aarch64_sme_za_disable' op directly in front of the
+  /// matched 'func.return' and reports success.
+  LogicalResult matchAndRewrite(func::ReturnOp op,
+                                PatternRewriter &rewriter) const final {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(op);
+    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
+    // The return op itself is left untouched (empty callback).
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the ZA enable/disable insertion patterns to 'patterns'.
+/// 'converter' is not used by these patterns.
+void mlir::populateArmSMELegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
+}
+
+/// Marks the ZA enable/disable intrinsic ops as legal and installs dynamic
+/// legality checks on 'func.func' and 'func.return' so that the patterns above
+/// fire only where the intrinsics are still missing.
+void mlir::configureArmSMELegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  target.addLegalOp<arm_sme::aarch64_sme_za_enable,
+                    arm_sme::aarch64_sme_za_disable>();
+
+  // Mark 'func.func' ops as legal if either:
+  //   1. no 'arm_za' function attribute is present.
+  //   2. the function is a declaration (it has no body to insert into).
+  //   3. the 'arm_za' function attribute is present and the first op in the
+  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
+  // The callback is stored in the target, so it captures nothing.
+  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp funcOp) {
+    if (!funcOp->hasAttr("arm_za"))
+      return true;
+    // Only touch the entry block once we know the body is non-empty.
+    Region &body = funcOp.getBody();
+    if (body.empty())
+      return true;
+    Block &entryBlock = body.front();
+    return !entryBlock.empty() &&
+           isa<arm_sme::aarch64_sme_za_enable>(&entryBlock.front());
+  });
+
+  // Mark 'func.return' ops as legal if either:
+  //   1. no 'arm_za' function attribute is present on the parent op.
+  //   2. the 'arm_za' function attribute is present and the op immediately
+  //      preceding this return is an 'arm_sme::aarch64_sme_za_disable'
+  //      intrinsic.
+  // Each return is checked on its own: a 'za_disable' in front of one return
+  // must not make the other returns of the same function legal.
+  target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp returnOp) {
+    Operation *funcOp = returnOp->getParentOp();
+    if (!funcOp->hasAttr("arm_za"))
+      return true;
+    Operation *prevOp = returnOp->getPrevNode();
+    return prevOp && isa<arm_sme::aarch64_sme_za_disable>(prevOp);
+  });
+}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
+
+// CHECK-LABEL: @arm_za
+func.func @arm_za() {
+  // ENABLE-ZA: arm_sme.intr.za.enable
+  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
+  // ENABLE-ZA-NEXT: return
+  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
+  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
+  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
+  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -223,3 +223,14 @@
       (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_toggle_za
+llvm.func @arm_sme_toggle_za() {
+  // CHECK: call void @llvm.aarch64.sme.za.enable()
+  "arm_sme.intr.za.enable"() : () -> ()
+  // CHECK: call void @llvm.aarch64.sme.za.disable()
+  "arm_sme.intr.za.disable"() : () -> ()
+  llvm.return
+}