Index: llvm/docs/AArch64SME.rst =================================================================== --- llvm/docs/AArch64SME.rst +++ llvm/docs/AArch64SME.rst @@ -40,6 +40,9 @@ ``aarch64_pstate_za_preserved`` is used for functions with ``__attribute__((arm_preserves_za))`` +``aarch64_expanded_pstate_za`` + is added by the SME ABI pass (see below) to ``aarch64_pstate_za_new`` functions it has already expanded + Clang must ensure that the above attributes are added both to the function's declaration/definition as well as to their call-sites. This is important for calls to attributed function pointers, where there is no @@ -423,8 +426,10 @@ lazy-save mechanism for calls to private-ZA functions (i.e. functions that may either directly or indirectly clobber ZA state). -For this purpose, we'll introduce a new LLVM IR pass that is run just before -SelectionDAG. +For the purpose of handling functions marked with ``aarch64_pstate_za_new``, +we have introduced a new LLVM IR pass (SMEABIPass) that is run just before +SelectionDAG. Every function expanded by this pass is marked with +``aarch64_expanded_pstate_za`` so that it is not expanded a second time. 
Setting up a lazy-save ---------------------- Index: llvm/lib/Target/AArch64/AArch64.h =================================================================== --- llvm/lib/Target/AArch64/AArch64.h +++ llvm/lib/Target/AArch64/AArch64.h @@ -58,6 +58,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass(); FunctionPass *createAArch64CollectLOHPass(); +FunctionPass *createSMEABIPass(); ModulePass *createSVEIntrinsicOptsPass(); InstructionSelector * createAArch64InstructionSelector(const AArch64TargetMachine &, @@ -100,6 +101,7 @@ void initializeFalkorHWPFFixPass(PassRegistry&); void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&); void initializeLDTLSCleanupPass(PassRegistry&); +void initializeSMEABIPass(PassRegistry &); void initializeSVEIntrinsicOptsPass(PassRegistry&); void initializeAArch64StackTaggingPass(PassRegistry&); void initializeAArch64StackTaggingPreRAPass(PassRegistry&); Index: llvm/lib/Target/AArch64/AArch64TargetMachine.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -224,6 +224,7 @@ initializeFalkorHWPFFixPass(*PR); initializeFalkorMarkStridedAccessesLegacyPass(*PR); initializeLDTLSCleanupPass(*PR); + initializeSMEABIPass(*PR); initializeSVEIntrinsicOptsPass(*PR); initializeAArch64SpeculationHardeningPass(*PR); initializeAArch64SLSHardeningPass(*PR); @@ -588,6 +589,11 @@ addPass(createInterleavedAccessPass()); } + // Expand any functions marked with SME attributes which require special + // changes for the calling convention or that require the lazy-saving + // mechanism specified in the SME ABI. + addPass(createSMEABIPass()); + // Add Control Flow Guard checks. 
if (TM->getTargetTriple().isOSWindows()) addPass(createCFGuardCheckPass()); Index: llvm/lib/Target/AArch64/CMakeLists.txt =================================================================== --- llvm/lib/Target/AArch64/CMakeLists.txt +++ llvm/lib/Target/AArch64/CMakeLists.txt @@ -83,6 +83,7 @@ AArch64TargetMachine.cpp AArch64TargetObjectFile.cpp AArch64TargetTransformInfo.cpp + SMEABIPass.cpp SVEIntrinsicOpts.cpp AArch64SIMDInstrOpt.cpp Index: llvm/lib/Target/AArch64/SMEABIPass.cpp =================================================================== --- /dev/null +++ llvm/lib/Target/AArch64/SMEABIPass.cpp @@ -0,0 +1,144 @@ +//===--------- SMEABI - SME ABI-------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements parts of the the SME ABI, such as: +// * Using the lazy-save mechanism before enabling the use of ZA. +// * Setting up the lazy-save mechanism around invokes. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64.h"
+#include "Utils/AArch64BaseInfo.h"
+#include "Utils/AArch64SMEAttributes.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-sme-abi"
+
+namespace {
+struct SMEABI : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+  SMEABI() : FunctionPass(ID) {
+    initializeSMEABIPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+  // getAnalysisUsage is intentionally not overridden: this pass splits the
+  // entry block and adds new blocks, so it must not claim to preserve the CFG.
+
+private:
+  bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
+};
+} // end anonymous namespace
+
+char SMEABI::ID = 0;
+static const char *name = "SME ABI Pass";
+INITIALIZE_PASS_BEGIN(SMEABI, DEBUG_TYPE, name, false, false)
+INITIALIZE_PASS_END(SMEABI, DEBUG_TYPE, name, false, false)
+
+FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+// Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
+static void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
+  auto *TPIDR2SaveTy =
+      FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
+
+  // This must be a function attribute, not a return attribute (index 0).
+  auto Attrs = AttributeList().addFnAttribute(M->getContext(),
+                                              "aarch64_pstate_sm_compatible");
+  FunctionCallee Callee =
+      M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
+  Builder.CreateCall(Callee);
+
+  // A save to TPIDR2 should be followed by clearing TPIDR2_EL0.
+  Function *WriteIntr =
+      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_set_tpidr2);
+  Builder.CreateCall(WriteIntr->getFunctionType(), WriteIntr,
+                     Builder.getInt64(0));
+}
+
+/// This function generates code to commit a lazy save at the beginning of a
+/// function marked with `aarch64_pstate_za_new`. If the value read from
+/// TPIDR2_EL0 is not null on entry to the function then the lazy-saving scheme
+/// is active and we should call __arm_tpidr2_save to commit the lazy save.
+/// Additionally, PSTATE.ZA should be enabled at the beginning of the function
+/// and disabled before returning. The function is then marked as expanded.
+bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
+                                  IRBuilder<> &Builder) {
+  LLVMContext &Context = F->getContext();
+  BasicBlock *OrigBB = &F->getEntryBlock();
+
+  // Create the new blocks for reading TPIDR2_EL0 & enabling ZA state.
+  auto *SaveBB = OrigBB->splitBasicBlock(OrigBB->begin(), "save.za", true);
+  auto *PreludeBB = BasicBlock::Create(Context, "prelude", F, SaveBB);
+
+  // Read TPIDR2_EL0 in PreludeBB & branch to SaveBB if not 0.
+  Builder.SetInsertPoint(PreludeBB);
+  Function *TPIDR2Intr =
+      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_get_tpidr2);
+  auto *TPIDR2 = Builder.CreateCall(TPIDR2Intr->getFunctionType(), TPIDR2Intr,
+                                    {}, "tpidr2");
+  auto *Cmp =
+      Builder.CreateCmp(ICmpInst::ICMP_NE, TPIDR2, Builder.getInt64(0), "cmp");
+  Builder.CreateCondBr(Cmp, SaveBB, OrigBB);
+
+  // Call __arm_tpidr2_save to commit the lazy save, then clear TPIDR2_EL0.
+  Builder.SetInsertPoint(&SaveBB->back());
+  emitTPIDR2Save(M, Builder);
+
+  // Enable pstate.za at the start of the function.
+  Builder.SetInsertPoint(&OrigBB->front());
+  Function *EnableZAIntr =
+      Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_enable);
+  Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);
+
+  // Before each return, disable PSTATE.ZA.
+  for (BasicBlock &BB : *F) {
+    Instruction *T = BB.getTerminator();
+    if (!T || !isa<ReturnInst>(T))
+      continue;
+    Builder.SetInsertPoint(T);
+    Function *DisableZAIntr =
+        Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_disable);
+    Builder.CreateCall(DisableZAIntr->getFunctionType(), DisableZAIntr);
+  }
+
+  F->addFnAttr("aarch64_expanded_pstate_za");
+  return true;
+}
+
+bool SMEABI::runOnFunction(Function &F) {
+  // Skip declarations and functions this pass has already expanded.
+  if (F.isDeclaration() || F.hasFnAttribute("aarch64_expanded_pstate_za"))
+    return false;
+
+  Module *M = F.getParent();
+  IRBuilder<> Builder(F.getContext());
+
+  bool Changed = false;
+  SMEAttrs FnAttrs(F);
+  if (FnAttrs.hasNewZAInterface())
+    Changed |= updateNewZAFunctions(M, &F, Builder);
+
+  return Changed;
+}
Index: llvm/test/CodeGen/AArch64/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -26,6 +26,7 @@
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: AArch64 Stack Tagging
+; CHECK-NEXT: SME ABI Pass
 ; CHECK-NEXT: Exception handling preparation
 ; CHECK-NEXT: Safe Stack instrumentation pass
 ; CHECK-NEXT: Insert stack protectors
Index: llvm/test/CodeGen/AArch64/O3-pipeline.ll
===================================================================
--- llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -92,6 +92,8 @@
 ; CHECK-NEXT: Interleaved Load Combine Pass
 ; CHECK-NEXT: Dominator Tree Construction
 ; CHECK-NEXT: Interleaved Access Pass
+; CHECK-NEXT: SME ABI Pass
+; CHECK-NEXT: Dominator Tree Construction
 ; CHECK-NEXT: Natural Loop Information
 ; CHECK-NEXT: Type Promotion
 ; CHECK-NEXT: CodeGen Prepare
Index: llvm/test/CodeGen/AArch64/sme-new-za-function.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sme-new-za-function.ll
@@ -0,0 +1,63 @@
+; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s
+; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi -aarch64-sme-abi %s | FileCheck %s
+
+declare void @shared_za_callee() "aarch64_pstate_za_shared"
+
+define void @private_za() "aarch64_pstate_za_new" {
+; CHECK-LABEL: @private_za(
+; CHECK-NEXT:  prelude:
+; CHECK-NEXT:    [[TPIDR2:%.*]] = call i64 @llvm.aarch64.sme.get.tpidr2()
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[TPIDR2]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[SAVE_ZA:%.*]], label [[TMP0:%.*]]
+; CHECK:       save.za:
+; CHECK-NEXT:    call void @__arm_tpidr2_save()
+; CHECK-NEXT:    call void @llvm.aarch64.sme.set.tpidr2(i64 0)
+; CHECK-NEXT:    br label [[TMP0]]
+; CHECK:       0:
+; CHECK-NEXT:    call void @llvm.aarch64.sme.za.enable()
+; CHECK-NEXT:    call void @shared_za_callee()
+; CHECK-NEXT:    call void @llvm.aarch64.sme.za.disable()
+; CHECK-NEXT:    ret void
+;
+  call void @shared_za_callee()
+  ret void
+}
+
+define i32 @private_za_multiple_exit(i32 %a, i32 %b, i64 %cond) "aarch64_pstate_za_new" {
+; CHECK-LABEL: @private_za_multiple_exit(
+; CHECK-NEXT:  prelude:
+; CHECK-NEXT:    [[TPIDR2:%.*]] = call i64 @llvm.aarch64.sme.get.tpidr2()
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[TPIDR2]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[SAVE_ZA:%.*]], label [[ENTRY:%.*]]
+; CHECK:       save.za:
+; CHECK-NEXT:    call void @__arm_tpidr2_save()
+; CHECK-NEXT:    call void @llvm.aarch64.sme.set.tpidr2(i64 0)
+; CHECK-NEXT:    br label [[ENTRY]]
+; CHECK:       entry:
+; CHECK-NEXT:    call void @llvm.aarch64.sme.za.enable()
+; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i64 [[COND:%.*]], 1
+; CHECK-NEXT:    br i1 [[TOBOOL]], label [[IF_ELSE:%.*]], label [[IF_END:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    call void @llvm.aarch64.sme.za.disable()
+; CHECK-NEXT:    ret i32 [[ADD]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[A]], [[B]]
+; CHECK-NEXT:    call void @llvm.aarch64.sme.za.disable()
+; CHECK-NEXT:    ret i32 [[SUB]]
+;
+entry:
+  %tobool = icmp eq i64 %cond, 1
+  br i1 %tobool, label %if.else, label %if.end
+
+if.else:
+  %add = add i32 %a, %b
+  ret i32 %add
+
+if.end:
+  %sub = sub i32 %a, %b
+  ret i32 %sub
+}
+
+; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_ATTRS:[0-9]+]]
+; CHECK: attributes #[[TPIDR2_SAVE_ATTRS]] = { "aarch64_pstate_sm_compatible" }