diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4818,9 +4818,11 @@
   if (auto *ES = dyn_cast<ExternalSymbolSDNode>(V)) {
     StringRef S(ES->getSymbol());
     if (S == "__arm_sme_state" || S == "__arm_tpidr2_save")
-      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved);
+      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved |
+                      SMEAttrs::ZA_NoLazySave);
     if (S == "__arm_tpidr2_restore")
-      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared);
+      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
+                      SMEAttrs::ZA_NoLazySave);
   }
   return std::nullopt;
 }
@@ -7321,23 +7323,28 @@
     CalleeAttrs = *Attrs;
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
-
-  MachineFrameInfo &MFI = MF.getFrameInfo();
   if (RequiresLazySave) {
-    // Set up a lazy save mechanism by storing the runtime live slices
-    // (worst-case N*N) to the TPIDR2 stack object.
-    SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
-                            DAG.getConstant(1, DL, MVT::i32));
-    SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
-    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
+    SDValue NumZaSaveSlices;
+    if (!CalleeAttrs.preservesZA()) {
+      // Set up a lazy save mechanism by storing the runtime live slices
+      // (worst-case SVL*SVL) to the TPIDR2 stack object.
+      SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                                DAG.getConstant(1, DL, MVT::i32));
+      NumZaSaveSlices = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
+    } else {
+      // The callee preserves ZA, so there are no live slices to save.
+      NumZaSaveSlices = DAG.getConstant(0, DL, MVT::i64);
+    }
 
+    unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
     MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, TPIDR2Obj);
     SDValue TPIDR2ObjAddr = DAG.getFrameIndex(TPIDR2Obj,
                                DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
     SDValue BufferPtrAddr =
         DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
                     DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
-    Chain = DAG.getTruncStore(Chain, DL, NN, BufferPtrAddr, MPI, MVT::i16);
+    Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, BufferPtrAddr, MPI,
+                              MVT::i16);
     Chain = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
@@ -7444,6 +7451,7 @@
 
       Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
       Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
+      MachineFrameInfo &MFI = MF.getFrameInfo();
       int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
       if (isScalable)
         MFI.setStackID(FI, TargetStackID::ScalableVector);
@@ -7757,7 +7765,7 @@
                            PStateSM, false);
   }
 
-  if (RequiresLazySave) {
+  if (RequiresLazySave && !CalleeAttrs.preservesZA()) {
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -7794,6 +7802,14 @@
                          DAG.getConstant(0, DL, MVT::i64));
   }
 
+  if (RequiresLazySave && CalleeAttrs.preservesZA()) {
+    // The callee preserved ZA: nothing to restore, just clear TPIDR2_EL0.
+    Result = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64));
+  }
+
   if (RequiresSMChange || RequiresLazySave) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -35,6 +35,7 @@
     ZA_Shared = 1 << 3,     // aarch64_pstate_sm_shared
     ZA_New = 1 << 4,        // aarch64_pstate_sm_new
     ZA_Preserved = 1 << 5,  // aarch64_pstate_sm_preserved
+    ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
     All = ZA_Preserved - 1
   };
 
@@ -80,9 +81,10 @@
   bool hasZAState() const {
     return hasNewZABody() || hasSharedZAInterface();
   }
+  bool hasNoLazySave() const { return Bitmask & ZA_NoLazySave; }
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
-           !Callee.preservesZA();
+           !Callee.hasNoLazySave();
   }
 };
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -28,8 +28,13 @@
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
   *this = SMEAttrs(CB.getAttributes());
-  if (auto *F = CB.getCalledFunction())
+  if (auto *F = CB.getCalledFunction()) {
     set(SMEAttrs(*F).Bitmask);
+    StringRef FuncName = F->getName();
+    if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state" ||
+        FuncName == "__arm_tpidr2_restore")
+      set(ZA_NoLazySave);
+  }
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -2,6 +2,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sme < %s | FileCheck %s
 
 declare void @private_za_callee()
+declare void @private_za_preserved_callee() #0
 declare float @llvm.cos.f32(float)
 
 ; Test lazy-save mechanism for a single callee.
@@ -165,3 +166,31 @@
   call void @private_za_callee()
   ret void
 }
+
+; Test lazy-save set-up for an aarch64_pstate_za_preserved callee: zero slices are recorded and ZA is not restored.
+define void @test_lazy_save_za_preserved_callee() nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_lazy_save_za_preserved_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    sub x8, x29, #16
+; CHECK-NEXT:    sturh wzr, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEXT:    bl private_za_preserved_callee
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_preserved_callee()
+  ret void
+}
+
+attributes #0 = {
+  "aarch64_pstate_za_preserved"
+}