Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3148,6 +3148,17 @@
     // We currently pass all varargs at 8-byte alignment.
     StackOffset = ((StackOffset + 7) & ~7);
     FuncInfo->setVarArgsStackIndex(MFI.CreateFixedObject(4, StackOffset, true));
+
+    if (MFI.hasMustTailInVarArgFunc()) {
+      SmallVector<MVT, 2> RegParmTypes;
+      RegParmTypes.push_back(MVT::i64);
+      RegParmTypes.push_back(MVT::f128);
+      // Compute the set of forwarded registers. The rest are scratch.
+      SmallVectorImpl<ForwardedRegister> &Forwards =
+          FuncInfo->getForwardedMustTailRegParms();
+      CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes,
+                                               CC_AArch64_AAPCS);
+    }
   }
 
   unsigned StackArgSize = CCInfo.getNextStackOffset();
@@ -3608,6 +3619,14 @@
   SmallVector<SDValue, 8> MemOpChains;
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
+  if (IsVarArg && CLI.CS && CLI.CS.isMustTailCall()) {
+    const auto &Forwards = FuncInfo->getForwardedMustTailRegParms();
+    for (const auto &F : Forwards) {
+      SDValue Val = DAG.getCopyFromReg(Chain, DL, F.VReg, F.VT);
+      RegsToPass.push_back(std::make_pair(unsigned(F.PReg), Val));
+    }
+  }
+
   // Walk the register/memloc assignments, inserting copies/loads.
   for (unsigned i = 0, realArgIdx = 0, e = ArgLocs.size(); i != e;
        ++i, ++realArgIdx) {
Index: lib/Target/AArch64/AArch64MachineFunctionInfo.h
===================================================================
--- lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/MC/MCLinkerOptimizationHint.h"
 #include <cassert>
@@ -97,6 +98,9 @@
   /// attribute, in which case it is set to false at construction.
   Optional<bool> HasRedZone;
 
+  /// ForwardedMustTailRegParms - A list of virtual and physical registers
+  /// that must be forwarded to every musttail call.
+  SmallVector<ForwardedRegister, 1> ForwardedMustTailRegParms;
 public:
   AArch64FunctionInfo() = default;
 
@@ -209,6 +213,10 @@
     LOHRelated.insert(Args.begin(), Args.end());
   }
 
+  SmallVectorImpl<ForwardedRegister> &getForwardedMustTailRegParms() {
+    return ForwardedMustTailRegParms;
+  }
+
 private:
   // Hold the lists of LOHs.
   MILOHContainer LOHContainerSet;
Index: test/CodeGen/AArch64/vararg-tallcall.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/vararg-tallcall.ll
@@ -0,0 +1,31 @@
+; RUN: llc -mtriple=aarch64-windows-msvc %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu %s -o - | FileCheck %s
+
+target datalayout = "e-m:w-p:64:64-i32:32-i64:64-i128:128-n32:64-S128"
+
+%class.X = type { i8 }
+%struct.B = type { i32 (...)** }
+
+$"??_9B@@$BA@AA" = comdat any
+
+; Function Attrs: noinline optnone
+define linkonce_odr void @"??_9B@@$BA@AA"(%struct.B* %this, ...) #1 comdat align 2 {
+entry:
+  %this.addr = alloca %struct.B*, align 8
+  store %struct.B* %this, %struct.B** %this.addr, align 8
+  %this1 = load %struct.B*, %struct.B** %this.addr, align 8
+  %0 = bitcast %struct.B* %this1 to void (%struct.B*, ...)***
+  %vtable = load void (%struct.B*, ...)**, void (%struct.B*, ...)*** %0, align 8
+  %vfn = getelementptr inbounds void (%struct.B*, ...)*, void (%struct.B*, ...)** %vtable, i64 0
+  %1 = load void (%struct.B*, ...)*, void (%struct.B*, ...)** %vfn, align 8
+  musttail call void (%struct.B*, ...) %1(%struct.B* %this1, ...)
+  ret void
+                                                  ; No predecessors!
+  ret void
+}
+
+attributes #1 = { noinline optnone "thunk" }
+
+; CHECK: ldr x8, [x0]
+; CHECK: ldr x8, [x8]
+; CHECK: br x8
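
For context, the test case above is the kind of IR clang emits for an MSVC-ABI vcall thunk of a variadic virtual method: the thunk cannot know which concrete callee it will reach, so it must perfect-forward every argument register through a musttail call, which is what the forwarded-register bookkeeping in this patch makes possible on AArch64. A minimal C++ reproducer is sketched below; the type and function names are illustrative assumptions, not taken from the patch.

// Hypothetical reproducer (assumed source, not part of this patch).
// Built with, e.g.: clang++ --target=aarch64-windows-msvc -S thunk.cpp
struct B {
  virtual void f(int, ...);
};

// Taking the address of a variadic virtual method forces clang to emit
// a vcall thunk (the "??_9B@@$BA@AA" symbol in the test) whose body is
// a musttail call through the vtable, forwarding all argument registers.
void (B::*get_f())(int, ...) {
  return &B::f;
}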