diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2052,6 +2052,26 @@
           "Attributes 'minsize and optnone' are incompatible!", V);
   }
 
+  // Reject mutually exclusive AArch64 SME attribute combinations.
+  if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled")) {
+    Check(!Attrs.hasFnAttr("aarch64_pstate_sm_compatible"),
+          "Attributes 'aarch64_pstate_sm_enabled and "
+          "aarch64_pstate_sm_compatible' are incompatible!",
+          V);
+  }
+
+  if (Attrs.hasFnAttr("aarch64_pstate_za_new")) {
+    Check(!Attrs.hasFnAttr("aarch64_pstate_za_preserved"),
+          "Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_preserved' "
+          "are incompatible!",
+          V);
+
+    Check(!Attrs.hasFnAttr("aarch64_pstate_za_shared"),
+          "Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_shared' "
+          "are incompatible!",
+          V);
+  }
+
   if (Attrs.hasFnAttr(Attribute::JumpTable)) {
     const GlobalValue *GV = cast<GlobalValue>(V);
     Check(GV->hasGlobalUnnamedAddr(),
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -0,0 +1,95 @@
+//===-- AArch64SMEAttributes.h - Helper for SME attributes ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
+#define LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/IR/Function.h"
+
+namespace llvm {
+
+class Function;
+class CallBase;
+class AttributeList;
+
+/// SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
+/// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM. It
+/// has interfaces to query whether a streaming mode change or lazy-save
+/// mechanism is required when going from one function to another (e.g. through
+/// a call).
+class SMEAttrs {
+  unsigned Bitmask;
+
+  SMEAttrs(unsigned Mask = 0) : Bitmask(0) { set(Mask); }
+
+public:
+  // Enum with bitmasks for each individual SME feature.
+  enum Mask : unsigned {
+    Normal = 0,
+    SM_Enabled = 1 << 0,    // aarch64_pstate_sm_enabled
+    SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible
+    SM_Body = 1 << 2,       // aarch64_pstate_sm_body
+    ZA_Shared = 1 << 3,     // aarch64_pstate_za_shared
+    ZA_New = 1 << 4,        // aarch64_pstate_za_new
+    ZA_Preserved = 1 << 5,  // aarch64_pstate_za_preserved
+    All = (ZA_Preserved << 1) - 1
+  };
+
+  static SMEAttrs get(unsigned Bitmask) { return SMEAttrs(Bitmask); }
+  static SMEAttrs getNormal() { return SMEAttrs(Normal); }
+  static SMEAttrs getFromFunction(const Function &F) {
+    return getFromAttrList(F.getAttributes());
+  }
+  static SMEAttrs getFromCallBase(const CallBase &CB);
+  static SMEAttrs getFromAttrList(const AttributeList &L);
+
+  void set(unsigned M, bool Enable = true);
+
+  // Interfaces to query PSTATE.SM
+  bool hasStreamingBody() const { return Bitmask & SM_Body; }
+  bool hasStreamingInterface() const { return Bitmask & SM_Enabled; }
+  bool hasStreamingCompatibleInterface() const {
+    return Bitmask & SM_Compatible;
+  }
+  bool hasNonStreamingInterface() const {
+    return !hasStreamingInterface() && !hasStreamingCompatibleInterface();
+  }
+  bool hasNonStreamingInterfaceAndBody() const {
+    return hasNonStreamingInterface() && !hasStreamingBody();
+  }
+
+  /// \return None if no change to PSTATE.SM is needed when transferring
+  /// control from this function to \p Callee, otherwise the streaming mode
+  /// PSTATE.SM must be switched to (true = streaming, false = non-streaming).
+  /// If \p BodyOverridesInterface is set (e.g. when the transition is not
+  /// through a call, such as when considering inlining), a streaming body in
+  /// \p Callee takes precedence over its interface.
+  Optional<bool> requiresSMChange(const SMEAttrs &Callee,
+                                  bool BodyOverridesInterface = false) const;
+
+  // Interfaces to query PSTATE.ZA
+  bool hasNewZAInterface() const { return Bitmask & ZA_New; }
+  bool hasSharedZAInterface() const { return Bitmask & ZA_Shared; }
+  bool hasPrivateZAInterface() const { return !hasSharedZAInterface(); }
+  bool preservesZA() const { return Bitmask & ZA_Preserved; }
+  bool hasZAState() const {
+    return hasNewZAInterface() || hasSharedZAInterface();
+  }
+  /// \return true if a call from this function to \p Callee requires the
+  /// lazy-save mechanism: this function has ZA state, and \p Callee has a
+  /// private-ZA interface and does not preserve ZA.
+  bool requiresLazySave(const SMEAttrs &Callee) const {
+    return hasZAState() && Callee.hasPrivateZAInterface() &&
+           !Callee.preservesZA();
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -0,0 +1,81 @@
+//===-- AArch64SMEAttributes.cpp - Helper for interpreting SME attributes -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64SMEAttributes.h"
+#include "llvm/ADT/None.h"
+#include "llvm/IR/InstrTypes.h"
+#include <cassert>
+
+using namespace llvm;
+
+void SMEAttrs::set(unsigned M, bool Enable) {
+  if (Enable)
+    Bitmask |= M;
+  else
+    Bitmask &= ~M;
+
+  assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) &&
+         "SM_Enabled and SM_Compatible are mutually exclusive");
+  assert(!(hasNewZAInterface() && hasSharedZAInterface()) &&
+         "ZA_New and ZA_Shared are mutually exclusive");
+  assert(!(hasNewZAInterface() && preservesZA()) &&
+         "ZA_New and ZA_Preserved are mutually exclusive");
+}
+
+SMEAttrs SMEAttrs::getFromCallBase(const CallBase &CB) {
+  // Combine the attributes on the call site with those on the callee, when
+  // the called function is known directly.
+  SMEAttrs Attrs = getFromAttrList(CB.getAttributes());
+  if (auto *F = CB.getCalledFunction())
+    Attrs = SMEAttrs(getFromFunction(*F).Bitmask | Attrs.Bitmask);
+  return Attrs;
+}
+
+SMEAttrs SMEAttrs::getFromAttrList(const AttributeList &Attrs) {
+  unsigned Bitmask = 0;
+  if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
+    Bitmask |= SM_Enabled;
+  if (Attrs.hasFnAttr("aarch64_pstate_sm_compatible"))
+    Bitmask |= SM_Compatible;
+  if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
+    Bitmask |= SM_Body;
+  if (Attrs.hasFnAttr("aarch64_pstate_za_shared"))
+    Bitmask |= ZA_Shared;
+  if (Attrs.hasFnAttr("aarch64_pstate_za_new"))
+    Bitmask |= ZA_New;
+  if (Attrs.hasFnAttr("aarch64_pstate_za_preserved"))
+    Bitmask |= ZA_Preserved;
+  return SMEAttrs(Bitmask);
+}
+
+Optional<bool> SMEAttrs::requiresSMChange(const SMEAttrs &Callee,
+                                          bool BodyOverridesInterface) const {
+  // If the transition is not through a call (e.g. when considering inlining)
+  // and Callee has a streaming body, then we can ignore the interface of
+  // Callee.
+  if (BodyOverridesInterface && Callee.hasStreamingBody()) {
+    bool IsStreaming = hasStreamingInterface() || hasStreamingBody();
+    return IsStreaming ? None : Optional<bool>(true);
+  }
+
+  // A streaming-compatible callee never requires a change.
+  if (Callee.hasStreamingCompatibleInterface())
+    return None;
+
+  // Both non-streaming
+  if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
+    return None;
+
+  // Both streaming
+  if ((hasStreamingInterface() || hasStreamingBody()) &&
+      Callee.hasStreamingInterface())
+    return None;
+
+  // Otherwise PSTATE.SM must be switched to the mode of Callee's interface.
+  return Callee.hasStreamingInterface();
+}
diff --git a/llvm/lib/Target/AArch64/Utils/CMakeLists.txt b/llvm/lib/Target/AArch64/Utils/CMakeLists.txt
--- a/llvm/lib/Target/AArch64/Utils/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/Utils/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_llvm_component_library(LLVMAArch64Utils
   AArch64BaseInfo.cpp
+  AArch64SMEAttributes.cpp
 
   LINK_COMPONENTS
   Support
+  Core
 
   ADD_TO_COMPONENT
   AArch64
diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Verifier/sme-attributes.ll
@@ -0,0 +1,10 @@
+; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s
+
+declare void @sm_attrs() "aarch64_pstate_sm_enabled" "aarch64_pstate_sm_compatible";
+; CHECK: Attributes 'aarch64_pstate_sm_enabled and aarch64_pstate_sm_compatible' are incompatible!
+
+declare void @za_preserved() "aarch64_pstate_za_new" "aarch64_pstate_za_preserved";
+; CHECK: Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_preserved' are incompatible!
+
+declare void @za_shared() "aarch64_pstate_za_new" "aarch64_pstate_za_shared";
+; CHECK: Attributes 'aarch64_pstate_za_new and aarch64_pstate_za_shared' are incompatible!
diff --git a/llvm/unittests/Target/AArch64/CMakeLists.txt b/llvm/unittests/Target/AArch64/CMakeLists.txt
--- a/llvm/unittests/Target/AArch64/CMakeLists.txt
+++ b/llvm/unittests/Target/AArch64/CMakeLists.txt
@@ -5,8 +5,10 @@
 
 set(LLVM_LINK_COMPONENTS
   AArch64CodeGen
   AArch64Desc
   AArch64Info
+  AArch64Utils
+  AsmParser
   CodeGen
   Core
   GlobalISel
@@ -21,6 +23,7 @@
   InstSizes.cpp
   DecomposeStackOffsetTest.cpp
   MatrixRegisterAliasing.cpp
+  SMEAttributesTest.cpp
   )
 
 set_property(TARGET AArch64Tests PROPERTY FOLDER "Tests/UnitTests/TargetTests")
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -0,0 +1,229 @@
+//===- SMEAttributesTest.cpp - Unit tests for SMEAttrs --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils/AArch64SMEAttributes.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using SA = SMEAttrs;
+
+// Parses IR into a fresh Module. The LLVMContext is a function-local static so
+// that it outlives every Module returned from here.
+static std::unique_ptr<Module> parseIR(const char *IR) {
+  static LLVMContext C;
+  SMDiagnostic Err;
+  return parseAssemblyString(IR, Err, C);
+}
+
+TEST(SMEAttributes, Constructors) {
+  ASSERT_TRUE(
+      SA::getFromFunction(*parseIR("declare void @foo()")->getFunction("foo"))
+          .hasNonStreamingInterfaceAndBody());
+
+  ASSERT_TRUE(SA::getFromFunction(
+                  *parseIR("declare void @foo() \"aarch64_pstate_sm_body\"")
+                       ->getFunction("foo"))
+                  .hasNonStreamingInterface());
+
+  ASSERT_TRUE(SA::getFromFunction(
+                  *parseIR("declare void @foo() \"aarch64_pstate_sm_enabled\"")
+                       ->getFunction("foo"))
+                  .hasStreamingInterface());
+
+  ASSERT_TRUE(SA::getFromFunction(
+                  *parseIR("declare void @foo() \"aarch64_pstate_sm_body\"")
+                       ->getFunction("foo"))
+                  .hasStreamingBody());
+
+  ASSERT_TRUE(
+      SA::getFromFunction(
+          *parseIR("declare void @foo() \"aarch64_pstate_sm_compatible\"")
+               ->getFunction("foo"))
+          .hasStreamingCompatibleInterface());
+
+  ASSERT_TRUE(SA::getFromFunction(
+                  *parseIR("declare void @foo() \"aarch64_pstate_za_shared\"")
+                       ->getFunction("foo"))
+                  .hasSharedZAInterface());
+
+  ASSERT_TRUE(SA::getFromFunction(
+                  *parseIR("declare void @foo() \"aarch64_pstate_za_new\"")
+                       ->getFunction("foo"))
+                  .hasNewZAInterface());
+
+  ASSERT_TRUE(
+      SA::getFromFunction(
+          *parseIR("declare void @foo() \"aarch64_pstate_za_preserved\"")
+               ->getFunction("foo"))
+          .preservesZA());
+
+  // Invalid combinations.
+  EXPECT_DEBUG_DEATH(SA::get(SA::SM_Enabled | SA::SM_Compatible),
+                     "SM_Enabled and SM_Compatible are mutually exclusive");
+  EXPECT_DEBUG_DEATH(SA::get(SA::ZA_New | SA::ZA_Shared),
+                     "ZA_New and ZA_Shared are mutually exclusive");
+  EXPECT_DEBUG_DEATH(SA::get(SA::ZA_New | SA::ZA_Preserved),
+                     "ZA_New and ZA_Preserved are mutually exclusive");
+
+  // Test that the set() methods equally check validity.
+  EXPECT_DEBUG_DEATH(SA::get(SA::SM_Enabled).set(SA::SM_Compatible),
+                     "SM_Enabled and SM_Compatible are mutually exclusive");
+  EXPECT_DEBUG_DEATH(SA::get(SA::SM_Compatible).set(SA::SM_Enabled),
+                     "SM_Enabled and SM_Compatible are mutually exclusive");
+}
+
+TEST(SMEAttributes, Basics) {
+  // Test PSTATE.SM interfaces.
+  ASSERT_TRUE(SA::get(SA::Normal).hasNonStreamingInterfaceAndBody());
+  ASSERT_TRUE(SA::get(SA::SM_Enabled).hasStreamingInterface());
+  ASSERT_TRUE(SA::get(SA::SM_Body).hasStreamingBody());
+  ASSERT_TRUE(SA::get(SA::SM_Body).hasNonStreamingInterface());
+  ASSERT_FALSE(SA::get(SA::SM_Body).hasNonStreamingInterfaceAndBody());
+  ASSERT_FALSE(SA::get(SA::SM_Body).hasStreamingInterface());
+  ASSERT_TRUE(SA::get(SA::SM_Compatible).hasStreamingCompatibleInterface());
+  ASSERT_TRUE(SA::get(SA::SM_Compatible | SA::SM_Body)
+                  .hasStreamingCompatibleInterface());
+  ASSERT_TRUE(SA::get(SA::SM_Compatible | SA::SM_Body).hasStreamingBody());
+  ASSERT_FALSE(
+      SA::get(SA::SM_Compatible | SA::SM_Body).hasNonStreamingInterface());
+
+  // Test PSTATE.ZA interfaces.
+  ASSERT_FALSE(SA::get(SA::ZA_Shared).hasPrivateZAInterface());
+  ASSERT_TRUE(SA::get(SA::ZA_Shared).hasSharedZAInterface());
+  ASSERT_TRUE(SA::get(SA::ZA_Shared).hasZAState());
+  ASSERT_FALSE(SA::get(SA::ZA_Shared).preservesZA());
+  ASSERT_TRUE(SA::get(SA::ZA_Shared | SA::ZA_Preserved).preservesZA());
+
+  ASSERT_TRUE(SA::get(SA::ZA_New).hasPrivateZAInterface());
+  ASSERT_TRUE(SA::get(SA::ZA_New).hasNewZAInterface());
+  ASSERT_TRUE(SA::get(SA::ZA_New).hasZAState());
+  ASSERT_FALSE(SA::get(SA::ZA_New).preservesZA());
+
+  ASSERT_TRUE(SA::get(SA::Normal).hasPrivateZAInterface());
+  ASSERT_FALSE(SA::get(SA::Normal).hasNewZAInterface());
+  ASSERT_FALSE(SA::get(SA::Normal).hasZAState());
+  ASSERT_FALSE(SA::get(SA::Normal).preservesZA());
+}
+
+TEST(SMEAttributes, Transitions) {
+  // Normal -> Normal
+  ASSERT_FALSE(SA::get(SA::Normal).requiresSMChange(SA::get(SA::Normal)));
+  // Normal -> Normal + LocallyStreaming
+  ASSERT_FALSE(
+      SA::get(SA::Normal).requiresSMChange(SA::get(SA::Normal | SA::SM_Body)));
+  ASSERT_EQ(*SA::get(SA::Normal)
+                 .requiresSMChange(SA::get(SA::Normal | SA::SM_Body),
+                                   /*BodyOverridesInterface=*/true),
+            true);
+
+  // Normal -> Streaming
+  ASSERT_EQ(*SA::get(SA::Normal).requiresSMChange(SA::get(SA::SM_Enabled)),
+            true);
+  // Normal -> Streaming + LocallyStreaming
+  ASSERT_EQ(*SA::get(SA::Normal)
+                 .requiresSMChange(SA::get(SA::SM_Enabled | SA::SM_Body)),
+            true);
+  ASSERT_EQ(*SA::get(SA::Normal)
+                 .requiresSMChange(SA::get(SA::SM_Enabled | SA::SM_Body),
+                                   /*BodyOverridesInterface=*/true),
+            true);
+
+  // Normal -> Streaming-compatible
+  ASSERT_FALSE(
+      SA::get(SA::Normal).requiresSMChange(SA::get(SA::SM_Compatible)));
+  // Normal -> Streaming-compatible + LocallyStreaming
+  ASSERT_FALSE(SA::get(SA::Normal)
+                   .requiresSMChange(SA::get(SA::SM_Compatible | SA::SM_Body)));
+  ASSERT_EQ(*SA::get(SA::Normal)
+                 .requiresSMChange(SA::get(SA::SM_Compatible | SA::SM_Body),
+                                   /*BodyOverridesInterface=*/true),
+            true);
+
+  // Streaming -> Normal
+  ASSERT_EQ(*SA::get(SA::SM_Enabled).requiresSMChange(SA::get(SA::Normal)),
+            false);
+  // Streaming -> Normal + LocallyStreaming
+  ASSERT_EQ(*SA::get(SA::SM_Enabled)
+                 .requiresSMChange(SA::get(SA::Normal | SA::SM_Body)),
+            false);
+  ASSERT_FALSE(SA::get(SA::SM_Enabled)
+                   .requiresSMChange(SA::get(SA::Normal | SA::SM_Body),
+                                     /*BodyOverridesInterface=*/true));
+
+  // Streaming -> Streaming
+  ASSERT_FALSE(
+      SA::get(SA::SM_Enabled).requiresSMChange(SA::get(SA::SM_Enabled)));
+  // Streaming -> Streaming + LocallyStreaming
+  ASSERT_FALSE(SA::get(SA::SM_Enabled)
+                   .requiresSMChange(SA::get(SA::SM_Enabled | SA::SM_Body)));
+  ASSERT_FALSE(SA::get(SA::SM_Enabled)
+                   .requiresSMChange(SA::get(SA::SM_Enabled | SA::SM_Body),
+                                     /*BodyOverridesInterface=*/true));
+
+  // Streaming -> Streaming-compatible
+  ASSERT_FALSE(
+      SA::get(SA::SM_Enabled).requiresSMChange(SA::get(SA::SM_Compatible)));
+  // Streaming -> Streaming-compatible + LocallyStreaming
+  ASSERT_FALSE(SA::get(SA::SM_Enabled)
+                   .requiresSMChange(SA::get(SA::SM_Compatible | SA::SM_Body)));
+  ASSERT_FALSE(SA::get(SA::SM_Enabled)
+                   .requiresSMChange(SA::get(SA::SM_Compatible | SA::SM_Body),
+                                     /*BodyOverridesInterface=*/true));
+
+  // Streaming-compatible -> Normal
+  ASSERT_EQ(*SA::get(SA::SM_Compatible).requiresSMChange(SA::get(SA::Normal)),
+            false);
+  // Streaming-compatible -> Normal + LocallyStreaming
+  ASSERT_EQ(*SA::get(SA::SM_Compatible)
+                 .requiresSMChange(SA::get(SA::Normal | SA::SM_Body)),
+            false);
+  ASSERT_EQ(*SA::get(SA::SM_Compatible)
+                 .requiresSMChange(SA::get(SA::Normal | SA::SM_Body),
+                                   /*BodyOverridesInterface=*/true),
+            true);
+
+  // Streaming-compatible -> Streaming
+  ASSERT_EQ(
+      *SA::get(SA::SM_Compatible).requiresSMChange(SA::get(SA::SM_Enabled)),
+      true);
+  // Streaming-compatible -> Streaming + LocallyStreaming
+  ASSERT_EQ(*SA::get(SA::SM_Compatible)
+                 .requiresSMChange(SA::get(SA::SM_Enabled | SA::SM_Body)),
+            true);
+  ASSERT_EQ(*SA::get(SA::SM_Compatible)
+                 .requiresSMChange(SA::get(SA::SM_Enabled | SA::SM_Body),
+                                   /*BodyOverridesInterface=*/true),
+            true);
+
+  // Streaming-compatible -> Streaming-compatible
+  ASSERT_FALSE(
+      SA::get(SA::SM_Compatible).requiresSMChange(SA::get(SA::SM_Compatible)));
+  // Streaming-compatible -> Streaming-compatible + LocallyStreaming
+  ASSERT_FALSE(SA::get(SA::SM_Compatible)
+                   .requiresSMChange(SA::get(SA::SM_Compatible | SA::SM_Body)));
+  ASSERT_EQ(*SA::get(SA::SM_Compatible)
+                 .requiresSMChange(SA::get(SA::SM_Compatible | SA::SM_Body),
+                                   /*BodyOverridesInterface=*/true),
+            true);
+}
+
+TEST(SMEAttributes, LazySave) {
+  // A lazy-save is only needed when the caller has ZA state and the callee
+  // has a private-ZA interface that does not preserve ZA.
+  ASSERT_TRUE(SA::get(SA::ZA_Shared).requiresLazySave(SA::get(SA::Normal)));
+  ASSERT_TRUE(SA::get(SA::ZA_New).requiresLazySave(SA::get(SA::Normal)));
+  ASSERT_FALSE(
+      SA::get(SA::ZA_Shared).requiresLazySave(SA::get(SA::ZA_Preserved)));
+  ASSERT_FALSE(SA::get(SA::ZA_Shared).requiresLazySave(SA::get(SA::ZA_Shared)));
+  ASSERT_FALSE(SA::get(SA::Normal).requiresLazySave(SA::get(SA::Normal)));
+}