diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4132,6 +4132,7 @@ static bool mayTailCallThisCC(CallingConv::ID CC) { switch (CC) { case CallingConv::C: + case CallingConv::AArch64_SVE_VectorCall: case CallingConv::PreserveMost: case CallingConv::Swift: return true; @@ -4151,6 +4152,15 @@ MachineFunction &MF = DAG.getMachineFunction(); const Function &CallerF = MF.getFunction(); CallingConv::ID CallerCC = CallerF.getCallingConv(); + + // If this function uses the C calling convention but has an SVE signature, + // then it preserves more registers and should assume the SVE_VectorCall CC. + // The check for matching callee-saved regs will determine whether it is + // eligible for TCO. + if (CallerCC == CallingConv::C && + AArch64RegisterInfo::hasSVEArgsOrReturn(&MF)) + CallerCC = CallingConv::AArch64_SVE_VectorCall; + bool CCMatch = CallerCC == CalleeCC; // When using the Windows calling convention on a non-windows OS, we want @@ -4338,6 +4348,20 @@ bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt; bool IsSibCall = false; + // Check callee args/returns for SVE registers and set calling convention + // accordingly. + if (CallConv == CallingConv::C) { + bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ + return Out.VT.isScalableVector(); + }); + bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){ + return In.VT.isScalableVector(); + }); + + if (CalleeInSVE || CalleeOutSVE) + CallConv = CallingConv::AArch64_SVE_VectorCall; + } + if (IsTailCall) { // Check if it's really possible to do a tail call. IsTailCall = isEligibleForTailCallOptimization( @@ -4691,20 +4715,6 @@ Ops.push_back(DAG.getRegister(RegToPass.first, RegToPass.second.getValueType())); - // Check callee args/returns for SVE registers and set calling convention - // accordingly. 
- if (CallConv == CallingConv::C) { - bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ - return Out.VT.isScalableVector(); - }); - bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){ - return In.VT.isScalableVector(); - }); - - if (CalleeInSVE || CalleeOutSVE) - CallConv = CallingConv::AArch64_SVE_VectorCall; - } - // Add a register mask operand representing the call-preserved registers. const uint32_t *Mask; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h @@ -42,6 +42,8 @@ void UpdateCustomCallPreservedMask(MachineFunction &MF, const uint32_t **Mask) const; + static bool hasSVEArgsOrReturn(const MachineFunction *MF); + /// Code Generation virtual methods... const MCPhysReg *getCalleeSavedRegs(const MachineFunction *MF) const override; const MCPhysReg *getDarwinCalleeSavedRegs(const MachineFunction *MF) const; diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -63,7 +63,7 @@ return true; } -static bool hasSVEArgsOrReturn(const MachineFunction *MF) { +bool AArch64RegisterInfo::hasSVEArgsOrReturn(const MachineFunction *MF) { const Function &F = MF->getFunction(); return isa(F.getReturnType()) || any_of(F.args(), [](const Argument &Arg) { diff --git a/llvm/test/CodeGen/AArch64/sve-tailcall.ll b/llvm/test/CodeGen/AArch64/sve-tailcall.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-tailcall.ll @@ -0,0 +1,107 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=aarch64 -mattr=+sve < %s 2>%t | FileCheck %s +; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t + +; If this check fails please 
read test/CodeGen/AArch64/README for instructions on how to resolve it. +; WARN-NOT: warning + +; Check that a tail call from an SVE function to another SVE function +; can use a tail-call, as the same registers will be preserved by the +; callee. +define @sve_caller_sve_callee() nounwind { +; CHECK-LABEL: sve_caller_sve_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: addvl sp, sp, #-2 +; CHECK-NEXT: str z10, [sp] // 16-byte Folded Spill +; CHECK-NEXT: str z9, [sp, #1, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: //APP +; CHECK-NEXT: //NO_APP +; CHECK-NEXT: ldr z10, [sp] // 16-byte Folded Reload +; CHECK-NEXT: ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: addvl sp, sp, #2 +; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: b sve_callee + tail call void asm sideeffect "", "~{z9},~{z10}"() + %call = tail call @sve_callee() + ret %call +} + +declare @sve_callee() + +; Check that a tail call from an SVE function to a non-SVE function +; does not use a tail-call, because after the call many of the SVE +; registers may be clobbered and need to be restored. +define i32 @sve_caller_non_sve_callee( %arg) nounwind { +; CHECK-LABEL: sve_caller_non_sve_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: stp x29, x30, [sp, #-16]!
// 16-byte Folded Spill +; CHECK-NEXT: addvl sp, sp, #-18 +; CHECK-NEXT: str p15, [sp, #4, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p14, [sp, #5, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p13, [sp, #6, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p12, [sp, #7, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p11, [sp, #8, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p10, [sp, #9, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p9, [sp, #10, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p8, [sp, #11, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p7, [sp, #12, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p6, [sp, #13, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p5, [sp, #14, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str p4, [sp, #15, mul vl] // 2-byte Folded Spill +; CHECK-NEXT: str z23, [sp, #2, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z22, [sp, #3, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z21, [sp, #4, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z20, [sp, #5, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z19, [sp, #6, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z18, [sp, #7, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z17, [sp, #8, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z16, [sp, #9, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z15, [sp, #10, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z14, [sp, #11, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z13, [sp, #12, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z12, [sp, #13, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z11, [sp, #14, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z10, [sp, #15, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z9, [sp, #16, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: str z8, [sp, #17, mul vl] // 16-byte Folded Spill +; CHECK-NEXT: //APP +; CHECK-NEXT: //NO_APP +; CHECK-NEXT: bl non_sve_callee +; CHECK-NEXT: ldr p15, [sp, #4, mul vl] // 
2-byte Folded Reload +; CHECK-NEXT: ldr p14, [sp, #5, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p13, [sp, #6, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p12, [sp, #7, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p11, [sp, #8, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p10, [sp, #9, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p9, [sp, #10, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p8, [sp, #11, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p7, [sp, #12, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p6, [sp, #13, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p5, [sp, #14, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr p4, [sp, #15, mul vl] // 2-byte Folded Reload +; CHECK-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z21, [sp, #4, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z20, [sp, #5, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z19, [sp, #6, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z18, [sp, #7, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z17, [sp, #8, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z16, [sp, #9, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z15, [sp, #10, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z14, [sp, #11, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z13, [sp, #12, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z12, [sp, #13, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z11, [sp, #14, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z10, [sp, #15, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z9, [sp, #16, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: ldr z8, [sp, #17, mul vl] // 16-byte Folded Reload +; CHECK-NEXT: addvl sp, sp, #18 +; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload +; CHECK-NEXT: ret + tail call void asm sideeffect "", "~{z9},~{z10}"() + %call = tail call i32 
@non_sve_callee() + ret i32 %call +} + +declare i32 @non_sve_callee()