Index: mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(Transforms)
Index: mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSME)
+add_public_tablegen_target(MLIRArmSMETransformsIncGen)
+
+add_mlir_doc(Passes ArmSMEPasses ./ -gen-pass-doc)
Index: mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -0,0 +1,43 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class RewritePatternSet;
+
+namespace arm_sme {
+// Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
+// the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
+// In a locally streaming function PSTATE.SM is kept internal and the callee
+// manages it on entry/exit.
+enum class ArmStreaming { Default = 0, Locally = 1 };
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+
+/// Pass to enable Armv9 Streaming SVE mode.
+std::unique_ptr<Pass>
+createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default);
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+
+} // namespace arm_sme
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
Index: mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -0,0 +1,40 @@
+//===-- Passes.td - ArmSME pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def EnableArmStreaming
+    : Pass<"enable-arm-streaming", "mlir::func::FuncOp"> {
+  let summary = "Enable Armv9 Streaming SVE mode";
+  let description = [{
+    Enables the Armv9 Streaming SVE mode [1] for func.func ops by annotating
+    them with attributes. See options for more details.
+
+    [1] https://developer.arm.com/documentation/ddi0616/aa
+  }];
+  let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
+  let options = [
+    Option<"mode", "mode", "mlir::arm_sme::ArmStreaming",
+           /*default=*/"mlir::arm_sme::ArmStreaming::Default",
+           "Select how streaming-mode is managed at the function-level.",
+           [{::llvm::cl::values(
+                clEnumValN(mlir::arm_sme::ArmStreaming::Default, "default",
+                           "Streaming mode is part of the function interface "
+                           "(ABI), caller manages PSTATE.SM on entry/exit."),
+                clEnumValN(mlir::arm_sme::ArmStreaming::Locally, "locally",
+                           "Streaming mode is internal to the function, callee "
+                           "manages PSTATE.SM on entry/exit.")
+           )}]>,
+  ];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
+#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
Index: mlir/include/mlir/Dialect/CMakeLists.txt
===================================================================
--- mlir/include/mlir/Dialect/CMakeLists.txt
+++ mlir/include/mlir/Dialect/CMakeLists.txt
@@ -4,6 +4,7 @@
 add_subdirectory(Arith)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
+add_subdirectory(ArmSME)
 add_subdirectory(Async)
 add_subdirectory(Bufferization)
 add_subdirectory(Complex)
Index: mlir/include/mlir/InitAllPasses.h
===================================================================
--- mlir/include/mlir/InitAllPasses.h
+++ mlir/include/mlir/InitAllPasses.h
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/Async/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/Transforms/Passes.h"
@@ -77,6 +78,7 @@
   tosa::registerTosaOptPasses();
   transform::registerTransformPasses();
   vector::registerVectorPasses();
+  arm_sme::registerArmSMEPasses();
 
   // Dialect pipelines
   sparse_tensor::registerSparseTensorPipelines();
Index: mlir/lib/Dialect/ArmSME/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/ArmSME/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(Transforms)
Index: mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmSMETransforms
+  EnableArmStreaming.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
+
+  DEPENDS
+  MLIRArmSMETransformsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRPass
+  )
Index: mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -0,0 +1,78 @@
+//===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
+// (SSVE) mode [1][2] by adding either of the following attributes to
+// 'func.func' ops:
+//
+//   * 'arm_streaming' (default)
+//   * 'arm_locally_streaming'
+//
+// Streaming-mode is part of the interface (ABI) for functions with the
+// first attribute and it's the responsibility of the caller to manage
+// PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
+// backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
+// streaming functions.
+//
+// In locally streaming functions PSTATE.SM is kept internal and managed by
+// the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
+// 'smstop sm' in the prologue / epilogue for functions with this
+// attribute.
+//
+// [1] https://developer.arm.com/documentation/ddi0616/aa
+// [2] https://llvm.org/docs/AArch64SME.html
+// [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
+// [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+#define DEBUG_TYPE "enable-arm-streaming"
+
+namespace mlir {
+namespace arm_sme {
+#define GEN_PASS_DEF_ENABLEARMSTREAMING
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace arm_sme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr char kArmStreamingAttr[] = "arm_streaming";
+static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
+
+namespace {
+/// Adds the unit attribute selected by the 'mode' option to the 'func.func' op
+/// the pass runs on: 'arm_streaming' or 'arm_locally_streaming'.
+struct EnableArmStreamingPass
+    : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
+  EnableArmStreamingPass(ArmStreaming mode) { this->mode = mode; }
+  void runOnOperation() override {
+    // Map the streaming mode onto the name of the attribute to attach.
+    std::string attr;
+    switch (mode) {
+    case ArmStreaming::Default:
+      attr = kArmStreamingAttr;
+      break;
+    case ArmStreaming::Locally:
+      attr = kArmLocallyStreamingAttr;
+      break;
+    }
+    getOperation()->setAttr(attr, UnitAttr::get(&getContext()));
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass>
+mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode) {
+  return std::make_unique<EnableArmStreamingPass>(mode);
+}
Index: mlir/lib/Dialect/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/CMakeLists.txt
+++ mlir/lib/Dialect/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_subdirectory(Arith)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
+add_subdirectory(ArmSME)
 add_subdirectory(Async)
 add_subdirectory(AMX)
 add_subdirectory(Bufferization)
Index: mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -enable-arm-streaming=mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
+
+// CHECK-LABEL: @arm_streaming
+// CHECK-SAME: attributes {arm_streaming}
+// CHECK-LOCALLY-LABEL: @arm_streaming
+// CHECK-LOCALLY-SAME: attributes {arm_locally_streaming}
+func.func @arm_streaming() { return }