diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -141,6 +141,8 @@
   kw_arm_aapcs_vfpcc,
   kw_aarch64_vector_pcs,
   kw_aarch64_sve_vector_pcs,
+  kw_aarch64_sme_preservemost_from_x0,
+  kw_aarch64_sme_preservemost_from_x2,
   kw_msp430_intrcc,
   kw_avr_intrcc,
   kw_avr_signalcc,
diff --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h
--- a/llvm/include/llvm/IR/CallingConv.h
+++ b/llvm/include/llvm/IR/CallingConv.h
@@ -252,6 +252,15 @@
     /// M68k_INTR - Calling convention used for M68k interrupt routines.
     M68k_INTR = 101,
 
+    /// Used for calls to SME ABI support routines such as
+    /// __arm_tpidr2_save/restore.
+    /// Preserve X0-X13, X19-X29, SP, Z0-Z31, P0-P15.
+    AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 = 102,
+
+    /// Used for calls to the SME ABI support routine __arm_sme_state.
+    /// Preserve X2-X15, X19-X29, SP, Z0-Z31, P0-P15.
+    AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 = 103,
+
     /// The highest possible calling convention ID. Must be some 2^k - 1.
     MaxID = 1023
   };
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -597,6 +597,8 @@
   KEYWORD(arm_aapcs_vfpcc);
   KEYWORD(aarch64_vector_pcs);
   KEYWORD(aarch64_sve_vector_pcs);
+  KEYWORD(aarch64_sme_preservemost_from_x0);
+  KEYWORD(aarch64_sme_preservemost_from_x2);
   KEYWORD(msp430_intrcc);
   KEYWORD(avr_intrcc);
   KEYWORD(avr_signalcc);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -1875,6 +1875,8 @@
 ///   ::= 'arm_aapcs_vfpcc'
 ///   ::= 'aarch64_vector_pcs'
 ///   ::= 'aarch64_sve_vector_pcs'
+///   ::= 'aarch64_sme_preservemost_from_x0'
+///   ::= 'aarch64_sme_preservemost_from_x2'
 ///   ::= 'msp430_intrcc'
 ///   ::= 'avr_intrcc'
 ///   ::= 'avr_signalcc'
@@ -1925,6 +1927,12 @@
   case lltok::kw_aarch64_sve_vector_pcs:
     CC = CallingConv::AArch64_SVE_VectorCall;
     break;
+  case lltok::kw_aarch64_sme_preservemost_from_x0:
+    CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0;
+    break;
+  case lltok::kw_aarch64_sme_preservemost_from_x2:
+    CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2;
+    break;
   case lltok::kw_msp430_intrcc:  CC = CallingConv::MSP430_INTR; break;
   case lltok::kw_avr_intrcc:     CC = CallingConv::AVR_INTR; break;
   case lltok::kw_avr_signalcc:   CC = CallingConv::AVR_SIGNAL; break;
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -312,6 +312,12 @@
   case CallingConv::AArch64_SVE_VectorCall:
     Out << "aarch64_sve_vector_pcs";
     break;
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+    Out << "aarch64_sme_preservemost_from_x0";
+    break;
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
+    Out << "aarch64_sme_preservemost_from_x2";
+    break;
   case CallingConv::MSP430_INTR:   Out << "msp430_intrcc"; break;
   case CallingConv::AVR_INTR:      Out << "avr_intrcc "; break;
   case CallingConv::AVR_SIGNAL:    Out << "avr_signalcc "; break;
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -395,6 +395,23 @@
                                                  X19, X20, X21, X22, X23, X24,
                                                  X25, X26, X27, X28, LR, FP)>;
 
+// SME ABI support routines such as __arm_tpidr2_save/restore preserve most
+// registers.
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0
+    : CalleeSavedRegs<(add (sequence "Z%u", 0, 31),
+                           (sequence "P%u", 0, 15),
+                           (sequence "X%u", 0, 13),
+                           (sequence "X%u",19, 28),
+                           LR, FP)>;
+
+// The SME ABI support routine __arm_sme_state preserves most registers.
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2
+    : CalleeSavedRegs<(add (sequence "Z%u", 0, 31),
+                           (sequence "P%u", 0, 15),
+                           (sequence "X%u", 2, 15),
+                           (sequence "X%u",19, 28),
+                           LR, FP)>;
+
 def CSR_AArch64_AAPCS_SwiftTail
     : CalleeSavedRegs<(sub CSR_AArch64_AAPCS, X20, X22)>;
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -15,6 +15,7 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64ISELLOWERING_H
 
 #include "AArch64.h"
+#include "Utils/AArch64SMEAttributes.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/SelectionDAG.h"
@@ -1161,6 +1162,11 @@
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
+  // Returns the runtime value for PSTATE.SM. When the function is streaming-
+  // compatible, this generates a call to __arm_sme_state.
+  SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
+                      SDLoc DL, EVT VT) const;
+
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4480,6 +4480,35 @@
   return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
 }
 
+SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
+                                           SMEAttrs Attrs, SDLoc DL,
+                                           EVT VT) const {
+  if (Attrs.hasStreamingInterfaceOrBody())
+    return DAG.getConstant(1, DL, VT);
+
+  if (Attrs.hasNonStreamingInterfaceAndBody())
+    return DAG.getConstant(0, DL, VT);
+
+  assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
+
+  SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
+                                         getPointerTy(DAG.getDataLayout()));
+  Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
+  Type *RetTy = StructType::get(Int64Ty, Int64Ty);
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  ArgListTy Args;
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
+      RetTy, Callee, std::move(Args));
+  std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
+  // PSTATE.SM is bit 0 of the first value returned by __arm_sme_state.
+  SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
+  SDValue PStateSM = DAG.getNode(ISD::AND, DL, MVT::i64,
+                                 CallResult.first.getOperand(0), Mask);
+  // Match the constant cases above, which are materialized directly in VT.
+  return DAG.getZExtOrTrunc(PStateSM, DL, VT);
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
                                                       SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -4511,13 +4540,10 @@
     return DAG.getMergeValues({MS.getValue(0), MS.getValue(2)}, DL);
   }
   case Intrinsic::aarch64_sme_get_pstatesm: {
-    SDValue Chain = Op.getOperand(0);
-    SDValue MRS = DAG.getNode(
-        AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
-        Chain, DAG.getConstant(AArch64SysReg::SVCR, DL, MVT::i64));
-    SDValue Mask = DAG.getConstant(/* PSTATE.SM */ 1, DL, MVT::i64);
-    SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, MRS, Mask);
-    return DAG.getMergeValues({And, Chain}, DL);
+    SDValue Chain = Op.getOperand(0);
+    SMEAttrs Attrs = SMEAttrs(DAG.getMachineFunction().getFunction());
+    SDValue PStateSM = getPStateSM(DAG, Chain, Attrs, DL, Op.getValueType());
+    return DAG.getMergeValues({PStateSM, Chain}, DL);
   }
   }
 }
@@ -5784,6 +5810,8 @@
     return CC_AArch64_Win64_CFGuard_Check;
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   }
 }
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -90,6 +90,18 @@
     return CSR_AArch64_AAVPCS_SaveList;
   if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
     return CSR_AArch64_SVE_AAPCS_SaveList;
+  if (MF->getFunction().getCallingConv() ==
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+        "only supported to improve calls to SME ACLE save/restore/disable-za "
+        "functions, and is not intended to be used beyond that scope.");
+  if (MF->getFunction().getCallingConv() ==
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+        "only supported to improve calls to SME ACLE __arm_sme_state "
+        "and is not intended to be used beyond that scope.");
   if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering()
           ->supportSwiftError() &&
       MF->getFunction().getAttributes().hasAttrSomewhere(
@@ -122,6 +134,18 @@
   if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
     report_fatal_error(
         "Calling convention SVE_VectorCall is unsupported on Darwin.");
+  if (MF->getFunction().getCallingConv() ==
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+        "only supported to improve calls to SME ACLE save/restore/disable-za "
+        "functions, and is not intended to be used beyond that scope.");
+  if (MF->getFunction().getCallingConv() ==
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+        "only supported to improve calls to SME ACLE __arm_sme_state "
+        "and is not intended to be used beyond that scope.");
   if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS)
     return MF->getInfo<AArch64FunctionInfo>()->isSplitCSR()
                ? CSR_Darwin_AArch64_CXX_TLS_PE_SaveList
@@ -192,6 +216,14 @@
   if (CC == CallingConv::AArch64_SVE_VectorCall)
     report_fatal_error(
         "Calling convention SVE_VectorCall is unsupported on Darwin.");
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+        "unsupported on Darwin.");
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+        "unsupported on Darwin.");
   if (CC == CallingConv::CFGuard_Check)
     report_fatal_error(
         "Calling convention CFGuard_Check is unsupported on Darwin.");
@@ -229,6 +261,10 @@
   if (CC == CallingConv::AArch64_SVE_VectorCall)
     return SCS ? CSR_AArch64_SVE_AAPCS_SCS_RegMask
                : CSR_AArch64_SVE_AAPCS_RegMask;
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask;
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2_RegMask;
   if (CC == CallingConv::CFGuard_Check)
     return CSR_Win_AArch64_CFGuard_Check_RegMask;
   if (MF.getSubtarget<AArch64Subtarget>().getTargetLowering()
@@ -479,6 +515,8 @@
     return HasReg(CC_AArch64_Win64_CFGuard_Check_ArgRegs, Reg);
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return HasReg(CC_AArch64_AAPCS_ArgRegs, Reg);
   }
 }
diff --git a/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll b/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
--- a/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
+++ b/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
@@ -1,14 +1,46 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs -stop-after=finalize-isel < %s | FileCheck %s --check-prefix=CHECK-CSRMASK
 
-define i64 @is_streaming() {
-; CHECK-LABEL: is_streaming:
+define i64 @get_pstatesm_normal() nounwind {
+; CHECK-LABEL: get_pstatesm_normal:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mrs x8, SVCR
-; CHECK-NEXT:    and x0, x8, #0x1
+; CHECK-NEXT:    mov x0, xzr
 ; CHECK-NEXT:    ret
   %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
   ret i64 %pstate
 }
 
+define i64 @get_pstatesm_streaming() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: get_pstatesm_streaming:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    ret
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
+  ret i64 %pstate
+}
+
+define i64 @get_pstatesm_locally_streaming() nounwind "aarch64_pstate_sm_body" {
+; CHECK-LABEL: get_pstatesm_locally_streaming:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    ret
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
+  ret i64 %pstate
+}
+
+define i64 @get_pstatesm_streaming_compatible() nounwind "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: get_pstatesm_streaming_compatible:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x0, x0, #0x1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-CSRMASK-LABEL: name: get_pstatesm_streaming_compatible
+; CHECK-CSRMASK: BL &__arm_sme_state, csr_aarch64_sme_abi_support_routines_preservemost_from_x2
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
+  ret i64 %pstate
+}
+
 declare i64 @llvm.aarch64.sme.get.pstatesm()
diff --git a/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll b/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll
@@ -0,0 +1,37 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs -stop-after=finalize-isel < %s | FileCheck %s --check-prefix=CHECK-CSRMASK
+
+; Test that the PCS attribute is accepted and uses the correct register mask.
+;
+
+define void @test_sme_calling_convention_x0() nounwind {
+; CHECK-LABEL: test_sme_calling_convention_x0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-CSRMASK-LABEL: name: test_sme_calling_convention_x0
+; CHECK-CSRMASK: BL @__arm_tpidr2_save, csr_aarch64_sme_abi_support_routines_preservemost_from_x0
+  call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save()
+  ret void
+}
+
+define i64 @test_sme_calling_convention_x2() nounwind {
+; CHECK-LABEL: test_sme_calling_convention_x2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-CSRMASK-LABEL: name: test_sme_calling_convention_x2
+; CHECK-CSRMASK: BL @__arm_sme_state, csr_aarch64_sme_abi_support_routines_preservemost_from_x2
+  %pstate = call aarch64_sme_preservemost_from_x2 {i64, i64} @__arm_sme_state()
+  %pstate.sm = extractvalue {i64, i64} %pstate, 0
+  ret i64 %pstate.sm
+}
+
+declare void @__arm_tpidr2_save()
+declare {i64, i64} @__arm_sme_state()