diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4814,17 +4814,6 @@
                      Mask);
 }
 
-static std::optional<SMEAttrs> getCalleeAttrsFromExternalFunction(SDValue V) {
-  if (auto *ES = dyn_cast<ExternalSymbolSDNode>(V)) {
-    StringRef S(ES->getSymbol());
-    if (S == "__arm_sme_state" || S == "__arm_tpidr2_save")
-      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved);
-    if (S == "__arm_tpidr2_restore")
-      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared);
-  }
-  return std::nullopt;
-}
-
 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
                                                    SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -7316,28 +7305,32 @@
   SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
   if (CLI.CB)
     CalleeAttrs = SMEAttrs(*CLI.CB);
-  else if (std::optional<SMEAttrs> Attrs =
-               getCalleeAttrsFromExternalFunction(CLI.Callee))
-    CalleeAttrs = *Attrs;
+  else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
+    CalleeAttrs = SMEAttrs(ES->getSymbol());
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
-
-  MachineFrameInfo &MFI = MF.getFrameInfo();
   if (RequiresLazySave) {
-    // Set up a lazy save mechanism by storing the runtime live slices
-    // (worst-case N*N) to the TPIDR2 stack object.
-    SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
-                            DAG.getConstant(1, DL, MVT::i32));
-    SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
-    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
+    SDValue NumZaSaveSlices;
+    if (!CalleeAttrs.preservesZA()) {
+      // Set up a lazy save mechanism by storing the runtime live slices
+      // (worst-case SVL*SVL) to the TPIDR2 stack object.
+      SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                                DAG.getConstant(1, DL, MVT::i32));
+      NumZaSaveSlices = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
+    } else {
+      // The callee preserves ZA, so record zero slices to save.
+      NumZaSaveSlices = DAG.getConstant(0, DL, MVT::i64);
+    }
+    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
 
     MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, TPIDR2Obj);
     SDValue TPIDR2ObjAddr = DAG.getFrameIndex(TPIDR2Obj,
         DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-    SDValue BufferPtrAddr =
+    SDValue NumZaSaveSlicesAddr =
         DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
                     DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
-    Chain = DAG.getTruncStore(Chain, DL, NN, BufferPtrAddr, MPI, MVT::i16);
+    Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
+                              MPI, MVT::i16);
     Chain = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
@@ -7444,6 +7437,7 @@
 
       Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
       Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
+      MachineFrameInfo &MFI = MF.getFrameInfo();
       int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
       if (isScalable)
         MFI.setStackID(FI, TargetStackID::ScalableVector);
@@ -7758,35 +7752,34 @@
   }
 
   if (RequiresLazySave) {
-    // Unconditionally resume ZA.
-    Result = DAG.getNode(
-        AArch64ISD::SMSTART, DL, MVT::Other, Result,
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
-
-    // Conditionally restore the lazy save using a pseudo node.
-    unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
-    SDValue RegMask = DAG.getRegisterMask(
-        TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
-    SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
-        "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
-    SDValue TPIDR2_EL0 = DAG.getNode(
-        ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
-        DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
-
-    // Copy the address of the TPIDR2 block into X0 before 'calling' the
-    // RESTORE_ZA pseudo.
-    SDValue Glue;
-    SDValue TPIDR2Block = DAG.getFrameIndex(
-        FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
-    Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
-    Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
-                         {Result, TPIDR2_EL0,
-                          DAG.getRegister(AArch64::X0, MVT::i64),
-                          RestoreRoutine,
-                          RegMask,
-                          Result.getValue(1)});
-
+    if (!CalleeAttrs.preservesZA()) {
+      // Unconditionally resume ZA.
+      Result = DAG.getNode(
+          AArch64ISD::SMSTART, DL, MVT::Other, Result,
+          DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+          DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
+      // Conditionally restore the lazy save using a pseudo node.
+      unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
+      SDValue RegMask = DAG.getRegisterMask(
+          TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+      SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
+          "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
+      SDValue TPIDR2_EL0 = DAG.getNode(
+          ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
+          DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
+
+      // Copy the address of the TPIDR2 block into X0 before 'calling' the
+      // RESTORE_ZA pseudo.
+      SDValue Glue;
+      SDValue TPIDR2Block = DAG.getFrameIndex(
+          FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+      Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
+      Result = DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
+                           {Result, TPIDR2_EL0,
+                            DAG.getRegister(AArch64::X0, MVT::i64),
+                            RestoreRoutine, RegMask, Result.getValue(1)});
+    }
     // Finally reset the TPIDR2_EL0 register to 0.
     Result = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -35,6 +35,7 @@
     ZA_Shared = 1 << 3,     // aarch64_pstate_sm_shared
     ZA_New = 1 << 4,        // aarch64_pstate_sm_new
     ZA_Preserved = 1 << 5,  // aarch64_pstate_sm_preserved
+    ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
     All = ZA_Preserved - 1
   };
 
@@ -42,6 +43,7 @@
   SMEAttrs(const Function &F) : SMEAttrs(F.getAttributes()) {}
   SMEAttrs(const CallBase &CB);
   SMEAttrs(const AttributeList &L);
+  SMEAttrs(StringRef FuncName);
 
   void set(unsigned M, bool Enable = true);
 
@@ -82,7 +84,7 @@
   }
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
-           !Callee.preservesZA();
+           !(Callee.Bitmask & ZA_NoLazySave);
   }
 };
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -24,12 +24,25 @@
          "ZA_New and ZA_Shared are mutually exclusive");
   assert(!(hasNewZABody() && preservesZA()) &&
          "ZA_New and ZA_Preserved are mutually exclusive");
+  assert(!(hasNewZABody() && (Bitmask & ZA_NoLazySave)) &&
+         "ZA_New and ZA_NoLazySave are mutually exclusive");
 }
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
   *this = SMEAttrs(CB.getAttributes());
   if (auto *F = CB.getCalledFunction())
-    set(SMEAttrs(*F).Bitmask);
+    set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
+}
+
+// Derive attributes for the SME ABI support routines from their symbol name.
+// Any other name yields an empty attribute mask.
+SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
+  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
+    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
+                SMEAttrs::ZA_NoLazySave);
+  if (FuncName == "__arm_tpidr2_restore")
+    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
+                SMEAttrs::ZA_NoLazySave);
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -2,6 +2,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sme < %s | FileCheck %s
 
 declare void @private_za_callee()
+declare void @private_za_preserved_callee() #0
 declare float @llvm.cos.f32(float)
 
 ; Test lazy-save mechanism for a single callee.
@@ -165,3 +166,50 @@
   call void @private_za_callee()
   ret void
 }
+
+; Test lazy-save mechanism for an aarch64_pstate_za_shared caller
+; calling a callee with aarch64_pstate_za_preserved.
+define void @za_shared_caller_za_preserved_callee() nounwind "aarch64_pstate_za_shared" "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: za_shared_caller_za_preserved_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    add x29, sp, #64
+; CHECK-NEXT:    str x19, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur x8, [x29, #-80]
+; CHECK-NEXT:    sub x8, x29, #80
+; CHECK-NEXT:    sturh wzr, [x29, #-72]
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz x19, #0, .LBB4_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB4_2:
+; CHECK-NEXT:    bl private_za_preserved_callee
+; CHECK-NEXT:    tbz x19, #0, .LBB4_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB4_4:
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    sub sp, x29, #64
+; CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x19, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_preserved_callee()
+  ret void
+}
+
+attributes #0 = { "aarch64_pstate_za_preserved" }