Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4131,6 +4131,7 @@ static bool mayTailCallThisCC(CallingConv::ID CC) { switch (CC) { case CallingConv::C: + case CallingConv::AArch64_SVE_VectorCall: case CallingConv::PreserveMost: case CallingConv::Swift: return true; @@ -4150,6 +4151,15 @@ MachineFunction &MF = DAG.getMachineFunction(); const Function &CallerF = MF.getFunction(); CallingConv::ID CallerCC = CallerF.getCallingConv(); + + // If this function uses the C calling convention but has an SVE signature, + // then it preserves more registers and should assume the SVE_VectorCall CC. + // The check for matching callee-saved regs will determine whether it is + // eligible for TCO. + if (CallerCC == CallingConv::C && + AArch64RegisterInfo::hasSVEArgsOrReturn(&MF)) + CallerCC = CallingConv::AArch64_SVE_VectorCall; + bool CCMatch = CallerCC == CalleeCC; // When using the Windows calling convention on a non-windows OS, we want @@ -4337,6 +4347,28 @@ bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt; bool IsSibCall = false; + // Check callee args/returns for SVE registers and set calling convention + // accordingly. + if (CallConv == CallingConv::C) { + bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ + return Out.VT.isScalableVector(); + }); + bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){ + return In.VT.isScalableVector(); + }); + + // We check if the instruction is an invoke, because unwinders may only + // preserve the callee-saved registers as specified in the base ABI over the + // exception edge. This means that the caller needs to preserve additional + // registers for when the call throws an exception and the unwinder has + // tried to recover state. 
+ bool CallNeedsUnwinding = false; + if (const CallBase *CB = CLI.CB) + CallNeedsUnwinding = isa<InvokeInst>(CB) && !CB->doesNotThrow(); + if (!CallNeedsUnwinding && (CalleeInSVE || CalleeOutSVE)) + CallConv = CallingConv::AArch64_SVE_VectorCall; + } + if (IsTailCall) { // Check if it's really possible to do a tail call. IsTailCall = isEligibleForTailCallOptimization( @@ -4690,28 +4722,6 @@ Ops.push_back(DAG.getRegister(RegToPass.first, RegToPass.second.getValueType())); - // Check callee args/returns for SVE registers and set calling convention - // accordingly. - if (CallConv == CallingConv::C) { - bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ - return Out.VT.isScalableVector(); - }); - bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){ - return In.VT.isScalableVector(); - }); - - // We check if the instruction is an invoke, because unwinders may only - // preserve the callee-saved registers as specified in the base ABI over the - // exception edge. This means that the caller needs to preserve additional - // registers for when the call throws an exception and the unwinder has - // tried to recover state. - bool CallNeedsUnwinding = false; - if (const CallBase *CB = CLI.CB) - CallNeedsUnwinding = isa<InvokeInst>(CB) && !CB->doesNotThrow(); - if (!CallNeedsUnwinding && (CalleeInSVE || CalleeOutSVE)) - CallConv = CallingConv::AArch64_SVE_VectorCall; - } - // Add a register mask operand representing the call-preserved registers. const uint32_t *Mask; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); Index: llvm/lib/Target/AArch64/AArch64RegisterInfo.h =================================================================== --- llvm/lib/Target/AArch64/AArch64RegisterInfo.h +++ llvm/lib/Target/AArch64/AArch64RegisterInfo.h @@ -42,6 +42,8 @@ void UpdateCustomCallPreservedMask(MachineFunction &MF, const uint32_t **Mask) const; + static bool hasSVEArgsOrReturn(const MachineFunction *MF); + /// Code Generation virtual methods... 
const MCPhysReg *getCalleeSavedRegs(const MachineFunction *MF) const override; const MCPhysReg *getDarwinCalleeSavedRegs(const MachineFunction *MF) const; Index: llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -63,7 +63,7 @@ return true; } -static bool hasSVEArgsOrReturn(const MachineFunction *MF) { +bool AArch64RegisterInfo::hasSVEArgsOrReturn(const MachineFunction *MF) { const Function &F = MF->getFunction(); return isa(F.getReturnType()) || any_of(F.args(), [](const Argument &Arg) { Index: llvm/test/CodeGen/AArch64/sve-tailcall.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/AArch64/sve-tailcall.ll @@ -0,0 +1,107 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=aarch64 -mattr=+sve < %s 2>%t | FileCheck %s +; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t + +; If this check fails please read test/CodeGen/AArch64/README for instructions on how to resolve it. +; WARN-NOT: warning + +; Check that a tail call from an SVE function to another SVE function +; can use a tail-call, as the same registers will be preserved by the +; callee. +define @sve_caller_sve_callee() nounwind { +; CHECK-LABEL: sve_caller_sve_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: str x29, [sp, #-16]! 
// 8-byte Folded Spill +; CHECK-NEXT: addvl sp, sp, #-2 +; CHECK-NEXT: str z10, [sp] // 16-byte Folded Spill +; CHECK-NEXT: str z9, [sp, #1, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: //APP +; CHECK-NEXT: //NO_APP +; CHECK-NEXT: ldr z10, [sp] // 16-byte Folded Reload +; CHECK-NEXT: ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: addvl sp, sp, #2 +; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: b sve_callee + tail call void asm sideeffect "", "~{z9},~{z10}"() + %call = tail call @sve_callee() + ret %call +} + +declare @sve_callee() + +; Check that a tail call from an SVE function to a non-SVE function +; does not use a tail-call, because after the call many of the SVE +; registers may be clobbered and needs to be restored. +define i32 @sve_caller_non_sve_callee( %arg) nounwind { +; CHECK-LABEL: sve_caller_non_sve_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill +; CHECK-NEXT: addvl sp, sp, #-18 +; CHECK-NEXT: str p15, [sp, #4, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p14, [sp, #5, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p13, [sp, #6, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p12, [sp, #7, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p11, [sp, #8, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p10, [sp, #9, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p9, [sp, #10, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p8, [sp, #11, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p7, [sp, #12, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p6, [sp, #13, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p5, [sp, #14, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p4, [sp, #15, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str z23, [sp, #2, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z22, [sp, #3, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z21, [sp, #4, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z20, [sp, #5, mul vl] 
// 16-byte Folded Spill +; CHECK-NEXT: str z19, [sp, #6, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z18, [sp, #7, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z17, [sp, #8, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z16, [sp, #9, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z15, [sp, #10, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z14, [sp, #11, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z13, [sp, #12, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z12, [sp, #13, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z11, [sp, #14, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z10, [sp, #15, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z9, [sp, #16, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z8, [sp, #17, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: //APP +; CHECK-NEXT: //NO_APP +; CHECK-NEXT: bl non_sve_callee +; CHECK-NEXT: ldr p15, [sp, #4, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p14, [sp, #5, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p13, [sp, #6, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p12, [sp, #7, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p11, [sp, #8, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p10, [sp, #9, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p9, [sp, #10, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p8, [sp, #11, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p7, [sp, #12, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p6, [sp, #13, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p5, [sp, #14, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p4, [sp, #15, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z21, [sp, #4, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z20, [sp, #5, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z19, [sp, #6, mul vl] // 16-byte Folded Reload 
+; CHECK-NEXT: ldr z18, [sp, #7, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z17, [sp, #8, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z16, [sp, #9, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z15, [sp, #10, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z14, [sp, #11, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z13, [sp, #12, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z12, [sp, #13, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z11, [sp, #14, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z10, [sp, #15, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z9, [sp, #16, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z8, [sp, #17, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: addvl sp, sp, #18 +; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload +; CHECK-NEXT: ret + tail call void asm sideeffect "", "~{z9},~{z10}"() + %call = tail call i32 @non_sve_callee() + ret i32 %call +} + +declare i32 @non_sve_callee()