diff --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h
--- a/llvm/include/llvm/IR/CallingConv.h
+++ b/llvm/include/llvm/IR/CallingConv.h
@@ -252,6 +252,10 @@
     /// M68k_INTR - Calling convention used for M68k interrupt routines.
     M68k_INTR = 101,
 
+    /// Preserve X2-X13, X19-X29, SP, Z0-Z31, P0-P15. Only intended for calls
+    /// to SME ABI support routines such as __arm_sme_state.
+    AArch64_SME_ABI_Support_Routines_PreserveMost = 102,
+
     /// The highest possible calling convention ID. Must be some 2^k - 1.
     MaxID = 1023
   };
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -395,6 +395,17 @@
                                                  X19, X20, X21, X22, X23, X24,
                                                  X25, X26, X27, X28, LR, FP)>;
 
+// SME ABI support routines such as __arm_sme_state and
+// __arm_tpidr2_save/restore preserve most registers. X0 and X1 are deliberately
+// not listed: __arm_sme_state returns its results in them, so treating them as
+// preserved would let live values stay in those registers across the call.
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost
+    : CalleeSavedRegs<(add (sequence "Z%u", 0, 31),
+                           (sequence "P%u", 0, 15),
+                           (sequence "X%u", 2, 13),
+                           (sequence "X%u",19, 28),
+                           LR, FP)>;
+
 def CSR_AArch64_AAPCS_SwiftTail
     : CalleeSavedRegs<(sub CSR_AArch64_AAPCS, X20, X22)>;
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -15,6 +15,7 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64ISELLOWERING_H
 
 #include "AArch64.h"
+#include "Utils/AArch64SMEAttributes.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/SelectionDAG.h"
@@ -1161,6 +1162,14 @@
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
+  // Returns the runtime value of PSTATE.SM as a 0/1 value of type VT. When the
+  // function is streaming-compatible, this generates a call to __arm_sme_state.
+  // If OutChain is non-null, it is set to the chain after that call, or to
+  // Chain itself when the value is known at compile time.
+  SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
+                      const SDLoc &DL, EVT VT,
+                      SDValue *OutChain = nullptr) const;
+
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4480,6 +4480,44 @@
   return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
 }
 
+SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
+                                           SMEAttrs Attrs, const SDLoc &DL,
+                                           EVT VT, SDValue *OutChain) const {
+  // Unless a call is emitted below, the incoming chain is passed through.
+  if (OutChain)
+    *OutChain = Chain;
+
+  if (Attrs.hasStreamingInterface() || Attrs.hasStreamingBody())
+    return DAG.getConstant(1, DL, VT);
+
+  if (Attrs.hasNonStreamingInterfaceAndBody())
+    return DAG.getConstant(0, DL, VT);
+
+  assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
+
+  SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
+                                         getPointerTy(DAG.getDataLayout()));
+  Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
+  Type *RetTy = StructType::get(Int64Ty, Int64Ty);
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  ArgListTy Args;
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost, RetTy, Callee,
+      std::move(Args));
+  std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
+  // Hand back the call's output chain; if it were dropped, the call would not
+  // be ordered before later chained operations.
+  if (OutChain)
+    *OutChain = CallResult.second;
+
+  // Both i64 results come back merged; bit 0 of the first one is PSTATE.SM.
+  SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
+  SDValue PStateSM = DAG.getNode(ISD::AND, DL, MVT::i64,
+                                 CallResult.first.getOperand(0), Mask);
+  // Return the same type as the constant cases above.
+  return DAG.getZExtOrTrunc(PStateSM, DL, VT);
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
                                                       SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -4511,13 +4549,16 @@
     return DAG.getMergeValues({MS.getValue(0), MS.getValue(2)}, DL);
   }
   case Intrinsic::aarch64_sme_get_pstatesm: {
-    SDValue Chain = Op.getOperand(0);
-    SDValue MRS = DAG.getNode(
-        AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
-        Chain, DAG.getConstant(AArch64SysReg::SVCR, DL, MVT::i64));
-    SDValue Mask = DAG.getConstant(/* PSTATE.SM */ 1, DL, MVT::i64);
-    SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, MRS, Mask);
-    return DAG.getMergeValues({And, Chain}, DL);
+    SDValue Chain = Op.getOperand(0);
+    SMEAttrs Attrs =
+        SMEAttrs::getFromFunction(DAG.getMachineFunction().getFunction());
+    // Return the chain produced by getPStateSM rather than the incoming one,
+    // so that any __arm_sme_state call stays ordered before later chained
+    // operations.
+    SDValue OutChain;
+    SDValue PStateSM =
+        getPStateSM(DAG, Chain, Attrs, DL, Op.getValueType(), &OutChain);
+    return DAG.getMergeValues({PStateSM, OutChain}, DL);
   }
   }
 }
@@ -5784,6 +5825,7 @@
     return CC_AArch64_Win64_CFGuard_Check;
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost:
     return CC_AArch64_AAPCS;
   }
 }
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -90,6 +90,12 @@
     return CSR_AArch64_AAVPCS_SaveList;
   if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
     return CSR_AArch64_SVE_AAPCS_SaveList;
+  if (MF->getFunction().getCallingConv() ==
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost is "
+        "only supported to improve calls to SME ACLE save/restore/state "
+        "functions, and is not intended to be used beyond that scope.");
   if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering()
           ->supportSwiftError() &&
       MF->getFunction().getAttributes().hasAttrSomewhere(
@@ -122,6 +128,12 @@
   if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
     report_fatal_error(
         "Calling convention SVE_VectorCall is unsupported on Darwin.");
+  if (MF->getFunction().getCallingConv() ==
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost is "
+        "only supported to improve calls to SME ACLE save/restore/state "
+        "functions, and is not intended to be used beyond that scope.");
   if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS)
     return MF->getInfo<AArch64FunctionInfo>()->isSplitCSR()
                ? CSR_Darwin_AArch64_CXX_TLS_PE_SaveList
@@ -192,6 +204,10 @@
   if (CC == CallingConv::AArch64_SVE_VectorCall)
     report_fatal_error(
         "Calling convention SVE_VectorCall is unsupported on Darwin.");
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost is "
+        "unsupported on Darwin.");
   if (CC == CallingConv::CFGuard_Check)
     report_fatal_error(
         "Calling convention CFGuard_Check is unsupported on Darwin.");
@@ -229,6 +245,8 @@
   if (CC == CallingConv::AArch64_SVE_VectorCall)
     return SCS ? CSR_AArch64_SVE_AAPCS_SCS_RegMask
                : CSR_AArch64_SVE_AAPCS_RegMask;
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost)
+    return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_RegMask;
   if (CC == CallingConv::CFGuard_Check)
     return CSR_Win_AArch64_CFGuard_Check_RegMask;
   if (MF.getSubtarget<AArch64Subtarget>().getTargetLowering()
@@ -479,6 +497,7 @@
     return HasReg(CC_AArch64_Win64_CFGuard_Check_ArgRegs, Reg);
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost:
     return HasReg(CC_AArch64_AAPCS_ArgRegs, Reg);
   }
 }
diff --git a/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll b/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
--- a/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
+++ b/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
@@ -1,14 +1,46 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs -stop-after=finalize-isel < %s | FileCheck %s --check-prefix=CHECK-CSRMASK
 
-define i64 @is_streaming() {
-; CHECK-LABEL: is_streaming:
+define i64 @get_pstatesm_normal() nounwind {
+; CHECK-LABEL: get_pstatesm_normal:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mrs x8, SVCR
-; CHECK-NEXT:    and x0, x8, #0x1
+; CHECK-NEXT:    mov x0, xzr
 ; CHECK-NEXT:    ret
   %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
   ret i64 %pstate
 }
 
+define i64 @get_pstatesm_streaming() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: get_pstatesm_streaming:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    ret
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
+  ret i64 %pstate
+}
+
+define i64 @get_pstatesm_locally_streaming() nounwind "aarch64_pstate_sm_body" {
+; CHECK-LABEL: get_pstatesm_locally_streaming:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    ret
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
+  ret i64 %pstate
+}
+
+define i64 @get_pstatesm_streaming_compatible() nounwind "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: get_pstatesm_streaming_compatible:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x0, x0, #0x1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-CSRMASK-LABEL: name: get_pstatesm_streaming_compatible
+; CHECK-CSRMASK: BL &__arm_sme_state, csr_aarch64_sme_abi_support_routines_preservemost
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
+  ret i64 %pstate
+}
+
 declare i64 @llvm.aarch64.sme.get.pstatesm()