diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7299,28 +7299,48 @@
                getCalleeAttrsFromExternalFunction(CLI.Callee))
     CalleeAttrs = *Attrs;
 
-  bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
-
-  MachineFrameInfo &MFI = MF.getFrameInfo();
-  if (RequiresLazySave) {
-    // Set up a lazy save mechanism by storing the runtime live slices
-    // (worst-case N*N) to the TPIDR2 stack object.
-    SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
-                            DAG.getConstant(1, DL, MVT::i32));
-    SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
+  // Writes NumZaSaveSlices to the 16-bit num_za_save_slices field at offset 8
+  // of this function's TPIDR2 block and points TPIDR2_EL0 at the block. The
+  // work is chained after C; the resulting chain is returned.
+  auto setUpTPIDR2Block = [&](SDValue C, SDValue NumZaSaveSlices) {
     unsigned TPIDR2Obj = FuncInfo->getLazySaveTPIDR2Obj();
-
     MachinePointerInfo MPI = MachinePointerInfo::getStack(MF, TPIDR2Obj);
     SDValue TPIDR2ObjAddr = DAG.getFrameIndex(TPIDR2Obj,
         DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
     SDValue BufferPtrAddr =
         DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
                     DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
-    Chain = DAG.getTruncStore(Chain, DL, NN, BufferPtrAddr, MPI, MVT::i16);
-    Chain = DAG.getNode(
-        ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+    C = DAG.getTruncStore(C, DL, NumZaSaveSlices, BufferPtrAddr, MPI, MVT::i16);
+    C = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, C,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         TPIDR2ObjAddr);
+    return C;
+  };
+
+  bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  if (RequiresLazySave) {
+    // Set up a lazy save mechanism by storing the runtime live slices
+    // (worst-case N*N) to the TPIDR2 stack object.
+    SDValue N = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                            DAG.getConstant(1, DL, MVT::i32));
+    SDValue NN = DAG.getNode(ISD::MUL, DL, MVT::i64, N, N);
+    Chain = setUpTPIDR2Block(Chain, NN);
+  }
+
+  // A private-ZA callee that preserves ZA needs no lazy save, but ZA must
+  // still be left dormant on entry to it, so publish a TPIDR2 block with
+  // num_za_save_slices = 0. As with requiresLazySave(), only callers with ZA
+  // state take part (others presumably have no TPIDR2 stack object), and
+  // shared-ZA callees are excluded because they expect ZA to be active.
+  bool HasPreservedZAInterface =
+      CallerAttrs.hasZAState() && CalleeAttrs.hasPrivateZAInterface() &&
+      CalleeAttrs.hasPreservedZAInterface() &&
+      !CalleeAttrs.hasStreamingCompatibleInterface();
+  if (HasPreservedZAInterface) {
+    // NOTE(review): nothing is restored after this call, so correctness
+    // relies on the callee never committing the lazy save.
+    Chain = setUpTPIDR2Block(Chain, DAG.getConstant(0, DL, MVT::i64));
   }
 
   SDValue PStateSM;
@@ -7423,6 +7443,7 @@
       Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
       Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
 
+      MachineFrameInfo &MFI = MF.getFrameInfo();
       int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
       if (isScalable)
         MFI.setStackID(FI, TargetStackID::ScalableVector);
@@ -7778,6 +7799,19 @@
         DAG.getConstant(0, DL, MVT::i64));
   }
 
+  if (HasPreservedZAInterface) {
+    // After the call, re-enable PSTATE.ZA (SMSTART ZA) and reset TPIDR2_EL0
+    // to null; no ZA contents are restored on this path.
+    Result = DAG.getNode(
+        AArch64ISD::SMSTART, DL, MVT::Other, Result,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+    Result = DAG.getNode(
+        ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
+        DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64));
+  }
+
-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || HasPreservedZAInterface) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -76,13 +76,13 @@
   bool hasNewZAInterface() const { return Bitmask & ZA_New; }
   bool hasSharedZAInterface() const { return Bitmask & ZA_Shared; }
   bool hasPrivateZAInterface() const { return !hasSharedZAInterface(); }
-  bool preservesZA() const { return Bitmask & ZA_Preserved; }
+  bool hasPreservedZAInterface() const { return Bitmask & ZA_Preserved; }
   bool hasZAState() const {
     return hasNewZAInterface() || hasSharedZAInterface();
   }
   bool requiresLazySave(const SMEAttrs &Callee) const {
     return hasZAState() && Callee.hasPrivateZAInterface() &&
-           !Callee.preservesZA();
+           !Callee.hasPreservedZAInterface();
   }
 };
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -22,7 +22,7 @@
          "SM_Enabled and SM_Compatible are mutually exclusive");
   assert(!(hasNewZAInterface() && hasSharedZAInterface()) &&
          "ZA_New and ZA_Shared are mutually exclusive");
-  assert(!(hasNewZAInterface() && preservesZA()) &&
+  assert(!(hasNewZAInterface() && hasPreservedZAInterface()) &&
          "ZA_New and ZA_Preserved are mutually exclusive");
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
--- a/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
+++ b/llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
@@ -2,6 +2,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sme < %s | FileCheck %s
 
 declare void @private_za_callee()
+declare void @private_za_preserved_callee() #0
 declare float @llvm.cos.f32(float)
 
 ; Test lazy-save mechanism for a single callee.
@@ -165,3 +166,31 @@
   call void @private_za_callee()
   ret void
 }
+
+; A private-ZA callee marked aarch64_pstate_za_preserved needs no lazy save:
+; the caller publishes a TPIDR2 block with zero live slices instead.
+define void @test_za_preserved_callee() nounwind "aarch64_pstate_za_shared" {
+; CHECK-LABEL: test_za_preserved_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov x9, sp
+; CHECK-NEXT:    msub x8, x8, x8, x9
+; CHECK-NEXT:    mov sp, x8
+; CHECK-NEXT:    stur x8, [x29, #-16]
+; CHECK-NEXT:    sub x8, x29, #16
+; CHECK-NEXT:    sturh wzr, [x29, #-8]
+; CHECK-NEXT:    msr TPIDR2_EL0, x8
+; CHECK-NEXT:    bl private_za_preserved_callee
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @private_za_preserved_callee()
+  ret void
+}
+
+attributes #0 = { "aarch64_pstate_za_preserved" }
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -48,7 +48,7 @@
 
   ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_pstate_za_preserved\"")
                       ->getFunction("foo"))
-                  .preservesZA());
+                  .hasPreservedZAInterface());
 
   // Invalid combinations.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -83,18 +83,18 @@
   ASSERT_FALSE(SA(SA::ZA_Shared).hasPrivateZAInterface());
   ASSERT_TRUE(SA(SA::ZA_Shared).hasSharedZAInterface());
   ASSERT_TRUE(SA(SA::ZA_Shared).hasZAState());
-  ASSERT_FALSE(SA(SA::ZA_Shared).preservesZA());
-  ASSERT_TRUE(SA(SA::ZA_Shared | SA::ZA_Preserved).preservesZA());
+  ASSERT_FALSE(SA(SA::ZA_Shared).hasPreservedZAInterface());
+  ASSERT_TRUE(SA(SA::ZA_Shared | SA::ZA_Preserved).hasPreservedZAInterface());
 
   ASSERT_TRUE(SA(SA::ZA_New).hasPrivateZAInterface());
   ASSERT_TRUE(SA(SA::ZA_New).hasNewZAInterface());
   ASSERT_TRUE(SA(SA::ZA_New).hasZAState());
-  ASSERT_FALSE(SA(SA::ZA_New).preservesZA());
+  ASSERT_FALSE(SA(SA::ZA_New).hasPreservedZAInterface());
 
   ASSERT_TRUE(SA(SA::Normal).hasPrivateZAInterface());
   ASSERT_FALSE(SA(SA::Normal).hasNewZAInterface());
   ASSERT_FALSE(SA(SA::Normal).hasZAState());
-  ASSERT_FALSE(SA(SA::Normal).preservesZA());
+  ASSERT_FALSE(SA(SA::Normal).hasPreservedZAInterface());
 }
 
 TEST(SMEAttributes, Transitions) {