Index: llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -89,6 +89,8 @@
   bool expandCALL_BTI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI);
   bool expandStoreSwiftAsyncContext(MachineBasicBlock &MBB,
                                     MachineBasicBlock::iterator MBBI);
+  MachineBasicBlock *expandRestoreZA(MachineBasicBlock &MBB,
+                                     MachineBasicBlock::iterator MBBI);
   MachineBasicBlock *expandCondSMToggle(MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator MBBI);
 };
@@ -851,6 +853,48 @@
   return true;
 }
 
+MachineBasicBlock *
+AArch64ExpandPseudo::expandRestoreZA(MachineBasicBlock &MBB,
+                                     MachineBasicBlock::iterator MBBI) {
+  MachineInstr &MI = *MBBI;
+  assert((std::next(MBBI) != MBB.end() ||
+          MI.getParent()->successors().begin() !=
+              MI.getParent()->successors().end()) &&
+         "Unexpected unreachable in block that restores ZA");
+
+  // Compare TPIDR2_EL0 value against 0.
+  DebugLoc DL = MI.getDebugLoc();
+  MachineInstrBuilder Cbz = BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBZX))
+                                .add(MI.getOperand(0));
+
+  // Split MBB and create two new blocks:
+  //  - MBB now contains all instructions before RestoreZAPseudo.
+  //  - SMBB contains the RestoreZAPseudo instruction only.
+  //  - EndBB contains all instructions after RestoreZAPseudo.
+  MachineInstr &PrevMI = *std::prev(MBBI);
+  MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
+  MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
+                                 ? *SMBB->successors().begin()
+                                 : SMBB->splitAt(MI, /*UpdateLiveIns*/ true);
+
+  // Add the SMBB label to the CBZ instruction & create a branch to EndBB.
+  Cbz.addMBB(SMBB);
+  BuildMI(&MBB, DL, TII->get(AArch64::B))
+      .addMBB(EndBB);
+  MBB.addSuccessor(EndBB);
+
+  // Replace the pseudo with a call (BL).
+  MachineInstrBuilder MIB =
+      BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::BL));
+  MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
+  for (unsigned I = 2; I < MI.getNumOperands(); ++I)
+    MIB.add(MI.getOperand(I));
+  BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
+
+  MI.eraseFromParent();
+  return EndBB;
+}
+
 MachineBasicBlock *
 AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator MBBI) {
@@ -1371,6 +1415,12 @@
     return expandCALL_BTI(MBB, MBBI);
   case AArch64::StoreSwiftAsyncContext:
     return expandStoreSwiftAsyncContext(MBB, MBBI);
+  case AArch64::RestoreZAPseudo: {
+    auto *NewMBB = expandRestoreZA(MBB, MBBI);
+    if (NewMBB != &MBB)
+      NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
+    return true;
+  }
   case AArch64::MSRpstatePseudo: {
     auto *NewMBB = expandCondSMToggle(MBB, MBBI);
     if (NewMBB != &MBB)
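For reference, once later passes have laid out the blocks, this expansion (together with the TPIDR2_EL0 read and reset emitted during call lowering below) boils down to a short conditional call. A sketch of the resulting assembly, with illustrative register numbers, frame offsets and label names; the exact sequence is what the new tests at the end of this patch check for:

    sub   x0, x29, #16           // x0 = address of the TPIDR2 block
    mrs   x8, TPIDR2_EL0         // value that becomes operand 0 of the pseudo
    cbnz  x8, 1f                 // non-zero: the callee left ZA alone, skip the restore
    bl    __arm_tpidr2_restore   // reload ZA from the lazy-save buffer
1:
    msr   TPIDR2_EL0, xzr        // disarm the lazy save
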
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -60,6 +60,7 @@
 
   SMSTART,
   SMSTOP,
+  RESTORE_ZA,
 
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
@@ -895,6 +896,9 @@
   void addDRTypeForNEON(MVT VT);
   void addQRTypeForNEON(MVT VT);
 
+  unsigned allocateLazySaveBuffer(SDValue &Chain, const SDLoc &DL,
+                                  SelectionDAG &DAG, Register &Reg) const;
+
   SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv,
                                bool isVarArg,
                                const SmallVectorImpl<ISD::InputArg> &Ins,
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2063,6 +2063,7 @@
     break;
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
+    MAKE_CASE(AArch64ISD::RESTORE_ZA)
    MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -5909,6 +5910,50 @@
              : RetCC_AArch64_AAPCS;
 }
 
+/// Returns true if the Function has ZA state and contains at least one call to
+/// a function that requires setting up a lazy-save buffer.
+static bool requiresBufferForLazySave(const Function &F) {
+  SMEAttrs CallerAttrs(F);
+  if (!CallerAttrs.hasZAState())
+    return false;
+
+  for (const BasicBlock &BB : F)
+    for (const Instruction &I : BB)
+      if (const CallInst *Call = dyn_cast<CallInst>(&I))
+        if (CallerAttrs.requiresLazySave(SMEAttrs(*Call)))
+          return true;
+  return false;
+}
+
+unsigned AArch64TargetLowering::allocateLazySaveBuffer(
+    SDValue &Chain, const SDLoc &DL, SelectionDAG &DAG, Register &Reg) const {
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  // Allocate a lazy-save buffer object of size SVL.B * SVL.B (worst-case)
+  SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                          DAG.getConstant(1, DL, MVT::i32));
+  SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
+  SDValue Ops[] = {Chain, NN, DAG.getConstant(1, DL, MVT::i64)};
+  SDVTList VTs = DAG.getVTList(MVT::i64, MVT::Other);
+  SDValue Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL, VTs, Ops);
+  unsigned FI = MFI.CreateVariableSizedObject(Align(1), nullptr);
+  Reg = MF.getRegInfo().createVirtualRegister(getRegClassFor(MVT::i64));
+  Chain = DAG.getCopyToReg(Buffer.getValue(1), DL, Reg, Buffer.getValue(0));
+
+  // Allocate an additional TPIDR2 object on the stack (16 bytes)
+  unsigned TPIDR2Obj = MFI.CreateStackObject(16, Align(16), false);
+
+  // Store the buffer pointer to the TPIDR2 stack object.
+  MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, FI);
+  SDValue Ptr = DAG.getFrameIndex(
+      FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+  Chain = DAG.getStore(Chain, DL, Buffer, Ptr, MPI);
+
+  return TPIDR2Obj;
+}
+
 SDValue AArch64TargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -6252,6 +6297,14 @@
   if (Subtarget->hasCustomCallingConv())
     Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
 
+  if (requiresBufferForLazySave(MF.getFunction())) {
+    // Set up a buffer once and store the buffer in the MachineFunctionInfo.
+    Register Reg;
+    unsigned TPIDR2Obj = allocateLazySaveBuffer(Chain, DL, DAG, Reg);
+    FuncInfo->setLazySaveBufferReg(Reg);
+    FuncInfo->setLazySaveTPIDR2Obj(TPIDR2Obj);
+  }
+
   return Chain;
 }
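The 16-byte TPIDR2 object created above is meant to follow the lazy-save block layout that the SME ABI support routines expect. A rough C view of it (the struct and field names below are descriptive only and do not appear in this patch):

    #include <stdint.h>

    /* Sketch of the TPIDR2 block consumed by __arm_tpidr2_restore. */
    struct tpidr2_block {
      void    *za_save_buffer;     /* the SVL.B x SVL.B byte spill area    */
      uint16_t num_za_save_slices; /* 16-bit field at offset 8; LowerCall
                                      truncstores the slice count here     */
      uint8_t  reserved[6];
    };

LowerCall (below) fills in the slice count and then points TPIDR2_EL0 at this block via the aarch64.sme.set.tpidr2 intrinsic before each call that needs a lazy save.
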
@@ -6834,7 +6887,36 @@
                getCalleeAttrsFromExternalFunction(CLI.Callee))
     CalleeAttrs = *Attrs;
 
-  SDValue InFlag, PStateSM;
+  bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  if (RequiresLazySave) {
+    // Set up a lazy save mechanism by storing the runtime live slices
+    // (worst-case N*N) to the TPIDR2 stack object.
+    SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                            DAG.getConstant(1, DL, MVT::i32));
+    SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
+    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
+
+    if (!TPIDR2Obj) {
+      Register Reg;
+      TPIDR2Obj = allocateLazySaveBuffer(Chain, DL, DAG, Reg);
+    }
+
+    MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, TPIDR2Obj);
+    SDValue TPIDR2ObjAddr = DAG.getFrameIndex(TPIDR2Obj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+    SDValue BufferPtrAddr =
+        DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
+                    DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
+    Chain = DAG.getTruncStore(Chain, DL, NN, BufferPtrAddr, MPI, MVT::i16);
+    Chain = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        TPIDR2ObjAddr);
+  }
+
+  SDValue PStateSM;
   Optional<bool> RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
   if (RequiresSMChange)
     PStateSM = getPStateSM(DAG, Chain, CallerAttrs, DL, MVT::i64);
@@ -6931,7 +7013,6 @@
           StoreSize *= NumParts;
         }
 
-        MachineFrameInfo &MFI = MF.getFrameInfo();
         Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
         Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
         int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
@@ -7090,6 +7171,7 @@
   if (!MemOpChains.empty())
     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
 
+  SDValue InFlag;
   if (RequiresSMChange) {
     SDValue NewChain = changeStreamingMode(DAG, DL, *RequiresSMChange, Chain,
                                            InFlag, PStateSM, true);
@@ -7243,6 +7325,46 @@
     assert(PStateSM && "Expected a PStateSM to be set");
     Result = changeStreamingMode(DAG, DL, !*RequiresSMChange, Result, InFlag,
                                  PStateSM, false);
+  }
+
+  if (RequiresLazySave) {
+    // Unconditionally resume ZA.
+    Result = DAG.getNode(
+        AArch64ISD::SMSTART, DL, MVT::Other, Result,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
+    // Conditionally restore the lazy save using a pseudo node.
+    unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
+    SDValue RegMask = DAG.getRegisterMask(
+        TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+    SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
+        "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+    SDValue TPIDR2_EL0 = DAG.getNode(
+        ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
+
+    // Copy the address of the TPIDR2 block into X0 before 'calling' the
+    // RESTORE_ZA pseudo.
+    SDValue Glue;
+    SDValue TPIDR2Block = DAG.getFrameIndex(
+        FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+    Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
+    Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
+                         {Result, TPIDR2_EL0,
+                          DAG.getRegister(AArch64::X0, MVT::i64),
+                          RestoreRoutine,
+                          RegMask,
+                          Result.getValue(1)});
+
+    // Finally reset the TPIDR2_EL0 register to 0.
+    Result = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64));
+  }
+
+  if (RequiresSMChange || RequiresLazySave) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
Index: llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -184,6 +184,14 @@
   /// or return type
   bool IsSVECC = false;
 
+  /// The virtual register that is the pointer to the lazy save buffer.
+  /// This value is used during ISelLowering.
+  Register LazySaveBufferReg = 0;
+
+  /// The frame-index for the TPIDR2 object used for lazy saves.
+  Register LazySaveTPIDR2Obj = 0;
+
   /// True if the function need unwind information.
   mutable Optional<bool> NeedsDwarfUnwindInfo;
@@ -201,6 +209,12 @@
   bool isSVECC() const { return IsSVECC; };
   void setIsSVECC(bool s) { IsSVECC = s; };
 
+  unsigned getLazySaveBufferReg() const { return LazySaveBufferReg; }
+  void setLazySaveBufferReg(unsigned Reg) { LazySaveBufferReg = Reg; }
+
+  unsigned getLazySaveTPIDR2Obj() const { return LazySaveTPIDR2Obj; }
+  void setLazySaveTPIDR2Obj(unsigned Reg) { LazySaveTPIDR2Obj = Reg; }
+
   void initializeBaseYamlFields(const yaml::AArch64FunctionInfo &YamlMFI);
 
   unsigned getBytesInStackArgArea() const { return BytesInStackArgArea; }
Index: llvm/lib/Target/AArch64/AArch64RegisterInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64RegisterInfo.h
+++ llvm/lib/Target/AArch64/AArch64RegisterInfo.h
@@ -69,6 +69,7 @@
   const uint32_t *getTLSCallPreservedMask() const;
 
   const uint32_t *getSMStartStopCallPreservedMask() const;
+  const uint32_t *SMEABISupportRoutinesCallPreservedMaskFromX0() const;
 
   // Funclets on ARM64 Windows don't preserve any registers.
   const uint32_t *getNoPreservedMask() const override;
Index: llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -325,6 +325,11 @@
   return CSR_AArch64_SMStartStop_RegMask;
 }
 
+const uint32_t *
+AArch64RegisterInfo::SMEABISupportRoutinesCallPreservedMaskFromX0() const {
+  return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask;
+}
+
 const uint32_t *AArch64RegisterInfo::getNoPreservedMask() const {
   return CSR_AArch64_NoRegs_RegMask;
 }
Index: llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -18,6 +18,10 @@
                                   [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
                                   [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                                    SDNPOptInGlue, SDNPOutGlue]>;
+def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
+                                 SDNPOptInGlue]>;
 
 //===----------------------------------------------------------------------===//
 // Add vector elements horizontally or vertically to ZA tile.
@@ -162,6 +166,24 @@
           (ins svcr_op:$pstatefield, timm0_1:$imm,
                GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
   Sched<[WriteSys]>;
+// Pseudo to conditionally restore ZA state. This expands:
+//
+//   pseudonode tpidr2_el0, tpidr2obj, restore_routine
+//
+// Into:
+//
+//   if (tpidr2_el0 == 0)
+//     BL restore_routine, implicit-use tpidr2obj
+//
+def RestoreZAPseudo :
+  Pseudo<(outs),
+         (ins GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, i64imm:$restore_routine, variable_ops), []>,
+  Sched<[]>;
+
+def : Pat<(AArch64_restore_za
+           (i64 GPR64:$tpidr2_el0), (i64 GPR64sp:$tpidr2obj), (i64 texternalsym:$restore_routine)),
+          (RestoreZAPseudo GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, texternalsym:$restore_routine)>;
+
 // Scenario A:
 //
 //   %pstate.before.call = 1
Index: llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -0,0 +1,167 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64 -mattr=+sme < %s | FileCheck %s
+
+declare void @private_za_callee()
+declare float @llvm.cos.f32(float)
+
+; Test lazy-save mechanism for a single callee.
+define void @test_lazy_save_1_callee() nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_lazy_save_1_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_callee()
+  ret void
+}
+
+; Test lazy-save mechanism for multiple callees.
+define void @test_lazy_save_2_callees() nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_lazy_save_2_callees:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp x20, x19, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mul x19, x8, x8
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    sub x8, x8, x19
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    sub x20, x29, #16
+; CHECK-NEXT:    str x8, [x29]
+; CHECK-NEXT:    sturh w19, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x20
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    sturh w19, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x20
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB1_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_4:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x20, x19, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_callee()
+  call void @private_za_callee()
+  ret void
+}
+
+; Test a call of an intrinsic that gets expanded to a library call.
+define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_lazy_save_expanded_intrinsic:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl cosf
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB2_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB2_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res = call float @llvm.cos.f32(float %a)
+  ret float %res
+}
+
+; Test a combination of streaming-compatible -> normal call with lazy-save.
+define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_pstate_za_shared" "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: test_lazy_save_and_conditional_smstart:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    add x29, sp, #64
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #80
+; CHECK-NEXT:    stur x9, [x29, #-64]
+; CHECK-NEXT:    sturh w8, [x29, #-72]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz x19, #0, .LBB3_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB3_2:
+; CHECK-NEXT:    bl private_za_callee
+; CHECK-NEXT:    tbz x19, #0, .LBB3_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB3_4:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #80
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB3_6
+; CHECK-NEXT:  // %bb.5:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB3_6:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    sub sp, x29, #64
+; CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_callee()
+  ret void
+}
Index: llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
+++ llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
@@ -4,17 +4,65 @@
 declare void @private_za_callee()
 
 ; Ensure that we don't use tail call optimization when a lazy-save is required.
-;
-; FIXME: The code below if obviously not yet correct, because it should set up
-; a lazy-save buffer before doing the call, and (conditionally) restore it after
-; the call. But this functionality will follow in a future patch.
 define void @disable_tailcallopt() "aarch64_pstate_za_shared" nounwind {
 ; CHECK-LABEL: disable_tailcallopt:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
 ; CHECK-NEXT:    bl private_za_callee
-; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    sub x0, x29, #16
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
   tail call void @private_za_callee()
   ret void
 }
+
+; Ensure we set up and restore the lazy save correctly for instructions which are lowered to lib calls
+define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_pstate_za_shared" nounwind {
+; CHECK-LABEL: f128_call_za:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    mul x8, x8, x8
+; CHECK-NEXT:    sub x9, x9, x8
+; CHECK-NEXT:    mov sp, x9
+; CHECK-NEXT:    sub x10, x29, #16
+; CHECK-NEXT:    sturh w8, [x29, #-8]
+; CHECK-NEXT:    str x9, [x29]
+; CHECK-NEXT:    msr TPIDR2_EL0, x10
+; CHECK-NEXT:    bl __addtf3
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    add x0, x29, #0
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbnz x8, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    bl __arm_tpidr2_restore
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %res = fadd fp128 %a, %b
+  ret fp128 %res
+}
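Taken together, the code now emitted around every call from a ZA-state function to a private-ZA callee follows the lazy-save protocol checked by the tests above. A C-like outline of that behaviour, as a sketch only (the helper functions below are hypothetical stand-ins for the MSR/MRS/SMSTART instructions the backend emits directly; __arm_tpidr2_restore is the real SME ABI routine):

    /* Hypothetical helpers standing in for direct system-register access. */
    extern void __arm_tpidr2_restore(void *tpidr2_block);
    extern void write_tpidr2(void *block); /* msr TPIDR2_EL0, ...          */
    extern void *read_tpidr2(void);        /* mrs ..., TPIDR2_EL0          */
    extern void smstart_za(void);          /* smstart za                   */

    static void call_with_lazy_save(void *tpidr2_block, void (*callee)(void)) {
      write_tpidr2(tpidr2_block);           /* arm the lazy save                       */
      callee();                             /* callee may dump ZA and clear TPIDR2_EL0 */
      smstart_za();                         /* unconditionally re-enable ZA            */
      if (read_tpidr2() == 0)               /* did the callee commit the save?         */
        __arm_tpidr2_restore(tpidr2_block); /* reload ZA from the buffer               */
      write_tpidr2(0);                      /* reset TPIDR2_EL0                        */
    }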