Index: llvm/trunk/include/llvm/CodeGen/CallingConvLower.h
===================================================================
--- llvm/trunk/include/llvm/CodeGen/CallingConvLower.h
+++ llvm/trunk/include/llvm/CodeGen/CallingConvLower.h
@@ -158,6 +158,16 @@
   }
 };
 
+/// Describes a register that needs to be forwarded from the prologue to a
+/// musttail call.
+struct ForwardedRegister {
+  ForwardedRegister(unsigned VReg, MCPhysReg PReg, MVT VT)
+      : VReg(VReg), PReg(PReg), VT(VT) {}
+  unsigned VReg;
+  MCPhysReg PReg;
+  MVT VT;
+};
+
 /// CCAssignFn - This function assigns a location for Val, updating State to
 /// reflect the change.  It returns 'true' if it failed to handle Val.
 typedef bool CCAssignFn(unsigned ValNo, MVT ValVT,
@@ -470,6 +480,19 @@
     return PendingLocs;
   }
 
+  /// Compute the remaining unused register parameters that would be used for
+  /// the given value type. This is useful when varargs are passed in the
+  /// registers that normal prototyped parameters would be passed in, or for
+  /// implementing perfect forwarding.
+  void getRemainingRegParmsForType(SmallVectorImpl<MCPhysReg> &Regs, MVT VT,
+                                   CCAssignFn Fn);
+
+  /// Compute the set of registers that need to be preserved and forwarded to
+  /// any musttail calls.
+  void analyzeMustTailForwardedRegisters(
+      SmallVectorImpl<ForwardedRegister> &Forwards, ArrayRef<MVT> RegParmTypes,
+      CCAssignFn Fn);
+
 private:
   /// MarkAllocated - Mark a register and all of its aliases as allocated.
   void MarkAllocated(unsigned Reg);
Index: llvm/trunk/lib/CodeGen/CallingConvLower.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/CallingConvLower.cpp
+++ llvm/trunk/lib/CodeGen/CallingConvLower.cpp
@@ -14,9 +14,11 @@
 
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetLowering.h"
 #include "llvm/Target/TargetRegisterInfo.h"
@@ -178,3 +180,57 @@
     llvm_unreachable(nullptr);
   }
 }
+
+void CCState::getRemainingRegParmsForType(SmallVectorImpl<MCPhysReg> &Regs,
+                                          MVT VT, CCAssignFn Fn) {
+  unsigned SavedStackOffset = StackOffset;
+  unsigned NumLocs = Locs.size();
+
+  // Allocate something of this value type repeatedly with just the inreg flag
+  // set until we get assigned a location in memory.
+  ISD::ArgFlagsTy Flags;
+  Flags.setInReg();
+  bool HaveRegParm = true;
+  while (HaveRegParm) {
+    if (Fn(0, VT, VT, CCValAssign::Full, Flags, *this)) {
+#ifndef NDEBUG
+      dbgs() << "Call has unhandled type " << EVT(VT).getEVTString()
+             << " while computing remaining regparms\n";
+#endif
+      llvm_unreachable(nullptr);
+    }
+    HaveRegParm = Locs.back().isRegLoc();
+  }
+
+  // Copy all the registers from the value locations we added.
+  assert(NumLocs < Locs.size() && "CC assignment failed to add location");
+  for (unsigned I = NumLocs, E = Locs.size(); I != E; ++I)
+    if (Locs[I].isRegLoc())
+      Regs.push_back(MCPhysReg(Locs[I].getLocReg()));
+
+  // Clear the assigned values and stack memory. We leave the registers marked
+  // as allocated so that future queries don't return the same registers, i.e.
+  // when i64 and f64 are both passed in GPRs.
+  StackOffset = SavedStackOffset;
+  Locs.resize(NumLocs);
+}
+
+void CCState::analyzeMustTailForwardedRegisters(
+    SmallVectorImpl<ForwardedRegister> &Forwards, ArrayRef<MVT> RegParmTypes,
+    CCAssignFn Fn) {
+  // Oftentimes calling conventions will not use register parameters for
+  // variadic functions, so we need to assume we're not variadic so that we get
+  // all the registers that might be used in a non-variadic call.
+  SaveAndRestore<bool> SavedVarArg(IsVarArg, false);
+
+  for (MVT RegVT : RegParmTypes) {
+    SmallVector<MCPhysReg, 8> RemainingRegs;
+    getRemainingRegParmsForType(RemainingRegs, RegVT, Fn);
+    const TargetLowering *TL = MF.getSubtarget().getTargetLowering();
+    const TargetRegisterClass *RC = TL->getRegClassFor(RegVT);
+    for (MCPhysReg PReg : RemainingRegs) {
+      unsigned VReg = MF.addLiveIn(PReg, RC);
+      Forwards.push_back(ForwardedRegister(VReg, PReg, RegVT));
+    }
+  }
+}
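The intended usage pattern for the two new CCState helpers is the one the X86 change below follows: during LowerFormalArguments, ask CCState which argument registers a non-variadic call could still consume, record them as ForwardedRegister entries, and keep each value alive in a virtual register until the musttail call site re-binds it. A condensed sketch of that pattern follows; it is not part of the patch, it assumes the usual CCInfo/FuncInfo/DAG locals of a target's LowerFormalArguments, and CC_MyTarget plus the chosen MVTs are placeholders.

    // Illustrative sketch only; mirrors the X86 hunk below with placeholder names.
    if (isVarArg && MFI->hasMustTailInVarArgFunc()) {
      // Register parameter types this target forwards (placeholders).
      SmallVector<MVT, 2> RegParmTypes;
      RegParmTypes.push_back(MVT::i64);
      RegParmTypes.push_back(MVT::v4f32);

      // CCState marks the still-unallocated argument registers as function
      // live-ins and records them as ForwardedRegisters.
      SmallVectorImpl<ForwardedRegister> &Forwards =
          FuncInfo->getForwardedMustTailRegParms(); // stands in for the target's own storage
      CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes,
                                               CC_MyTarget);

      // Copy each live-in physreg into a fresh vreg so the value survives any
      // code emitted between the prologue and the musttail call.
      for (ForwardedRegister &F : Forwards) {
        SDValue RegVal = DAG.getCopyFromReg(Chain, dl, F.VReg, F.VT);
        F.VReg = MF.getRegInfo().createVirtualRegister(getRegClassFor(F.VT));
        Chain = DAG.getCopyToReg(Chain, dl, F.VReg, RegVal);
      }
    }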
Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -2549,11 +2549,19 @@
         MFI->CreateFixedObject(1, StackSize, true));
   }
 
+  // Figure out if XMM registers are in use.
+  bool HaveXMMArgs = Is64Bit && !IsWin64;
+  bool NoImplicitFloatOps = Fn->getAttributes().hasAttribute(
+      AttributeSet::FunctionIndex, Attribute::NoImplicitFloat);
+  assert(!(MF.getTarget().Options.UseSoftFloat && NoImplicitFloatOps) &&
+         "SSE register cannot be used when SSE is disabled!");
+  if (MF.getTarget().Options.UseSoftFloat || NoImplicitFloatOps ||
+      !Subtarget->hasSSE1())
+    HaveXMMArgs = false;
+
   // 64-bit calling conventions support varargs and register parameters, so we
-  // have to do extra work to spill them in the prologue or forward them to
-  // musttail calls.
-  if (Is64Bit && isVarArg &&
-      (MFI->hasVAStart() || MFI->hasMustTailInVarArgFunc())) {
+  // have to do extra work to spill them in the prologue.
+  if (Is64Bit && isVarArg && MFI->hasVAStart()) {
     // Find the first unallocated argument registers.
     ArrayRef<MCPhysReg> ArgGPRs = get64BitArgumentGPRs(CallConv, Subtarget);
     ArrayRef<MCPhysReg> ArgXMMs = get64BitArgumentXMMs(MF, CallConv, Subtarget);
@@ -2583,90 +2591,99 @@
       }
     }
 
-    // Store them to the va_list returned by va_start.
-    if (MFI->hasVAStart()) {
-      if (IsWin64) {
-        const TargetFrameLowering &TFI = *MF.getSubtarget().getFrameLowering();
-        // Get to the caller-allocated home save location.  Add 8 to account
-        // for the return address.
-        int HomeOffset = TFI.getOffsetOfLocalArea() + 8;
-        FuncInfo->setRegSaveFrameIndex(
+    if (IsWin64) {
+      const TargetFrameLowering &TFI = *MF.getSubtarget().getFrameLowering();
+      // Get to the caller-allocated home save location.  Add 8 to account
+      // for the return address.
+      int HomeOffset = TFI.getOffsetOfLocalArea() + 8;
+      FuncInfo->setRegSaveFrameIndex(
           MFI->CreateFixedObject(1, NumIntRegs * 8 + HomeOffset, false));
-        // Fixup to set vararg frame on shadow area (4 x i64).
-        if (NumIntRegs < 4)
-          FuncInfo->setVarArgsFrameIndex(FuncInfo->getRegSaveFrameIndex());
-      } else {
-        // For X86-64, if there are vararg parameters that are passed via
-        // registers, then we must store them to their spots on the stack so
-        // they may be loaded by deferencing the result of va_next.
-        FuncInfo->setVarArgsGPOffset(NumIntRegs * 8);
-        FuncInfo->setVarArgsFPOffset(ArgGPRs.size() * 8 + NumXMMRegs * 16);
-        FuncInfo->setRegSaveFrameIndex(MFI->CreateStackObject(
-            ArgGPRs.size() * 8 + ArgXMMs.size() * 16, 16, false));
-      }
-
-      // Store the integer parameter registers.
-      SmallVector<SDValue, 8> MemOps;
-      SDValue RSFIN = DAG.getFrameIndex(FuncInfo->getRegSaveFrameIndex(),
-                                        getPointerTy());
-      unsigned Offset = FuncInfo->getVarArgsGPOffset();
-      for (SDValue Val : LiveGPRs) {
-        SDValue FIN = DAG.getNode(ISD::ADD, dl, getPointerTy(), RSFIN,
-                                  DAG.getIntPtrConstant(Offset));
-        SDValue Store =
-            DAG.getStore(Val.getValue(1), dl, Val, FIN,
-                         MachinePointerInfo::getFixedStack(
-                             FuncInfo->getRegSaveFrameIndex(), Offset),
-                         false, false, 0);
-        MemOps.push_back(Store);
-        Offset += 8;
-      }
-
-      if (!ArgXMMs.empty() && NumXMMRegs != ArgXMMs.size()) {
-        // Now store the XMM (fp + vector) parameter registers.
-        SmallVector<SDValue, 12> SaveXMMOps;
-        SaveXMMOps.push_back(Chain);
-        SaveXMMOps.push_back(ALVal);
-        SaveXMMOps.push_back(DAG.getIntPtrConstant(
-                               FuncInfo->getRegSaveFrameIndex()));
-        SaveXMMOps.push_back(DAG.getIntPtrConstant(
-                               FuncInfo->getVarArgsFPOffset()));
-        SaveXMMOps.insert(SaveXMMOps.end(), LiveXMMRegs.begin(),
-                          LiveXMMRegs.end());
-        MemOps.push_back(DAG.getNode(X86ISD::VASTART_SAVE_XMM_REGS, dl,
-                                     MVT::Other, SaveXMMOps));
-      }
-
-      if (!MemOps.empty())
-        Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, MemOps);
+      // Fixup to set vararg frame on shadow area (4 x i64).
+      if (NumIntRegs < 4)
+        FuncInfo->setVarArgsFrameIndex(FuncInfo->getRegSaveFrameIndex());
    } else {
-      // Add all GPRs, al, and XMMs to the list of forwards.  We will add then
-      // to the liveout set on a musttail call.
-      assert(MFI->hasMustTailInVarArgFunc());
-      auto &Forwards = FuncInfo->getForwardedMustTailRegParms();
-      typedef X86MachineFunctionInfo::Forward Forward;
-
-      for (unsigned I = 0, E = LiveGPRs.size(); I != E; ++I) {
-        unsigned VReg =
-            MF.getRegInfo().createVirtualRegister(&X86::GR64RegClass);
-        Chain = DAG.getCopyToReg(Chain, dl, VReg, LiveGPRs[I]);
-        Forwards.push_back(Forward(VReg, ArgGPRs[NumIntRegs + I], MVT::i64));
-      }
-
-      if (!ArgXMMs.empty()) {
-        unsigned ALVReg =
-            MF.getRegInfo().createVirtualRegister(&X86::GR8RegClass);
-        Chain = DAG.getCopyToReg(Chain, dl, ALVReg, ALVal);
-        Forwards.push_back(Forward(ALVReg, X86::AL, MVT::i8));
-
-        for (unsigned I = 0, E = LiveXMMRegs.size(); I != E; ++I) {
-          unsigned VReg =
-              MF.getRegInfo().createVirtualRegister(&X86::VR128RegClass);
-          Chain = DAG.getCopyToReg(Chain, dl, VReg, LiveXMMRegs[I]);
-          Forwards.push_back(
-              Forward(VReg, ArgXMMs[NumXMMRegs + I], MVT::v4f32));
-        }
-      }
+      // For X86-64, if there are vararg parameters that are passed via
+      // registers, then we must store them to their spots on the stack so
+      // they may be loaded by dereferencing the result of va_next.
+      FuncInfo->setVarArgsGPOffset(NumIntRegs * 8);
+      FuncInfo->setVarArgsFPOffset(ArgGPRs.size() * 8 + NumXMMRegs * 16);
+      FuncInfo->setRegSaveFrameIndex(MFI->CreateStackObject(
+          ArgGPRs.size() * 8 + ArgXMMs.size() * 16, 16, false));
+    }
+
+    // Store the integer parameter registers.
+    SmallVector<SDValue, 8> MemOps;
+    SDValue RSFIN = DAG.getFrameIndex(FuncInfo->getRegSaveFrameIndex(),
+                                      getPointerTy());
+    unsigned Offset = FuncInfo->getVarArgsGPOffset();
+    for (SDValue Val : LiveGPRs) {
+      SDValue FIN = DAG.getNode(ISD::ADD, dl, getPointerTy(), RSFIN,
+                                DAG.getIntPtrConstant(Offset));
+      SDValue Store =
+          DAG.getStore(Val.getValue(1), dl, Val, FIN,
+                       MachinePointerInfo::getFixedStack(
+                           FuncInfo->getRegSaveFrameIndex(), Offset),
+                       false, false, 0);
+      MemOps.push_back(Store);
+      Offset += 8;
+    }
+
+    if (!ArgXMMs.empty() && NumXMMRegs != ArgXMMs.size()) {
+      // Now store the XMM (fp + vector) parameter registers.
+      SmallVector<SDValue, 12> SaveXMMOps;
+      SaveXMMOps.push_back(Chain);
+      SaveXMMOps.push_back(ALVal);
+      SaveXMMOps.push_back(DAG.getIntPtrConstant(
+                             FuncInfo->getRegSaveFrameIndex()));
+      SaveXMMOps.push_back(DAG.getIntPtrConstant(
+                             FuncInfo->getVarArgsFPOffset()));
+      SaveXMMOps.insert(SaveXMMOps.end(), LiveXMMRegs.begin(),
+                        LiveXMMRegs.end());
+      MemOps.push_back(DAG.getNode(X86ISD::VASTART_SAVE_XMM_REGS, dl,
+                                   MVT::Other, SaveXMMOps));
+    }
+
+    if (!MemOps.empty())
+      Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, MemOps);
+  }
+
+  if (isVarArg && MFI->hasMustTailInVarArgFunc()) {
+    // Find the largest legal vector type.
+    MVT VecVT = MVT::Other;
+    // FIXME: Only some x86_32 calling conventions support AVX512.
+    if (Subtarget->hasAVX512() &&
+        (Is64Bit || (CallConv == CallingConv::X86_VectorCall ||
+                     CallConv == CallingConv::Intel_OCL_BI)))
+      VecVT = MVT::v16f32;
+    else if (Subtarget->hasAVX())
+      VecVT = MVT::v8f32;
+    else if (Subtarget->hasSSE2())
+      VecVT = MVT::v4f32;
+
+    // We forward some GPRs and some vector types.
+    SmallVector<MVT, 2> RegParmTypes;
+    MVT IntVT = Is64Bit ? MVT::i64 : MVT::i32;
+    RegParmTypes.push_back(IntVT);
+    if (VecVT != MVT::Other)
+      RegParmTypes.push_back(VecVT);
+
+    // Compute the set of forwarded registers. The rest are scratch.
+    SmallVectorImpl<ForwardedRegister> &Forwards =
+        FuncInfo->getForwardedMustTailRegParms();
+    CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes, CC_X86);
+
+    // Conservatively forward AL on x86_64, since it might be used for varargs.
+    if (Is64Bit && !CCInfo.isAllocated(X86::AL)) {
+      unsigned ALVReg = MF.addLiveIn(X86::AL, &X86::GR8RegClass);
+      Forwards.push_back(ForwardedRegister(ALVReg, X86::AL, MVT::i8));
+    }
+
+    // Copy all forwards from physical to virtual registers.
+    for (ForwardedRegister &F : Forwards) {
+      // FIXME: Can we use a less constrained schedule?
+      SDValue RegVal = DAG.getCopyFromReg(Chain, dl, F.VReg, F.VT);
+      F.VReg = MF.getRegInfo().createVirtualRegister(getRegClassFor(F.VT));
+      Chain = DAG.getCopyToReg(Chain, dl, F.VReg, RegVal);
    }
  }

@@ -2986,7 +3003,7 @@
                                         DAG.getConstant(NumXMMRegs, MVT::i8)));
   }
 
-  if (Is64Bit && isVarArg && IsMustTail) {
+  if (isVarArg && IsMustTail) {
     const auto &Forwards = X86Info->getForwardedMustTailRegParms();
     for (const auto &F : Forwards) {
       SDValue Val = DAG.getCopyFromReg(Chain, dl, F.VReg, F.VT);
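For reference, the last hunk above ends just before the loop body that consumes these forwards at the call site; that code is unchanged by this patch and therefore not shown. Its assumed shape (a sketch from surrounding context, not part of this diff) is that each ForwardedRegister is read back out of its virtual register and re-bound to the original physical register, so the tail jump sees the argument registers exactly as the thunk received them.

    // Assumed shape of the unchanged consumer loop in the call lowering (sketch only).
    if (isVarArg && IsMustTail) {
      const auto &Forwards = X86Info->getForwardedMustTailRegParms();
      for (const auto &F : Forwards) {
        SDValue Val = DAG.getCopyFromReg(Chain, dl, F.VReg, F.VT);
        RegsToPass.push_back(std::make_pair(unsigned(F.PReg), Val));
      }
    }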
Index: llvm/trunk/lib/Target/X86/X86MachineFunctionInfo.h
===================================================================
--- llvm/trunk/lib/Target/X86/X86MachineFunctionInfo.h
+++ llvm/trunk/lib/Target/X86/X86MachineFunctionInfo.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_LIB_TARGET_X86_X86MACHINEFUNCTIONINFO_H
 #define LLVM_LIB_TARGET_X86_X86MACHINEFUNCTIONINFO_H
 
+#include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineValueType.h"
 #include <vector>
@@ -77,21 +78,10 @@
   /// NumLocalDynamics - Number of local-dynamic TLS accesses.
   unsigned NumLocalDynamics;
 
-public:
-  /// Describes a register that needs to be forwarded from the prologue to a
-  /// musttail call.
-  struct Forward {
-    Forward(unsigned VReg, MCPhysReg PReg, MVT VT)
-        : VReg(VReg), PReg(PReg), VT(VT) {}
-    unsigned VReg;
-    MCPhysReg PReg;
-    MVT VT;
-  };
-
-private:
   /// ForwardedMustTailRegParms - A list of virtual and physical registers
   /// that must be forwarded to every musttail call.
-  std::vector<Forward> ForwardedMustTailRegParms;
+  SmallVector<ForwardedRegister, 1> ForwardedMustTailRegParms;
 
 public:
   X86MachineFunctionInfo() : ForceFramePointer(false),
@@ -168,7 +158,7 @@
   unsigned getNumLocalDynamicTLSAccesses() const { return NumLocalDynamics; }
   void incNumLocalDynamicTLSAccesses() { ++NumLocalDynamics; }
 
-  std::vector<Forward> &getForwardedMustTailRegParms() {
+  SmallVectorImpl<ForwardedRegister> &getForwardedMustTailRegParms() {
     return ForwardedMustTailRegParms;
   }
 };
Index: llvm/trunk/test/CodeGen/X86/musttail-fastcall.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/musttail-fastcall.ll
+++ llvm/trunk/test/CodeGen/X86/musttail-fastcall.ll
@@ -0,0 +1,109 @@
+; RUN: llc < %s -mtriple=i686-pc-win32 -mattr=+sse2 | FileCheck %s --check-prefix=CHECK --check-prefix=SSE2
+; RUN: llc < %s -mtriple=i686-pc-win32 -mattr=+sse2,+avx | FileCheck %s --check-prefix=CHECK --check-prefix=AVX
+; RUN: llc < %s -mtriple=i686-pc-win32 -mattr=+sse2,+avx,+avx512f | FileCheck %s --check-prefix=CHECK --check-prefix=AVX512
+
+; While we don't support varargs with fastcall, we do support forwarding.
+
+@asdf = internal constant [4 x i8] c"asdf"
+
+declare void @puts(i8*)
+
+define i32 @call_fast_thunk() {
+  %r = call x86_fastcallcc i32 (...)* @fast_thunk(i32 inreg 1, i32 inreg 2, i32 3)
+  ret i32 %r
+}
+
+define x86_fastcallcc i32 @fast_thunk(...) {
+  call void @puts(i8* getelementptr ([4 x i8]* @asdf, i32 0, i32 0))
+  %r = musttail call x86_fastcallcc i32 (...)* bitcast (i32 (i32, i32, i32)* @fast_target to i32 (...)*) (...)
+  ret i32 %r
+}
+
+; Check that we spill and fill around the call to puts.
+
+; CHECK-LABEL: @fast_thunk@0:
+; CHECK-DAG: movl %ecx, {{.*}}
+; CHECK-DAG: movl %edx, {{.*}}
+; CHECK: calll _puts
+; CHECK-DAG: movl {{.*}}, %ecx
+; CHECK-DAG: movl {{.*}}, %edx
+; CHECK: jmp @fast_target@12
+
+define x86_fastcallcc i32 @fast_target(i32 inreg %a, i32 inreg %b, i32 %c) {
+  %a0 = add i32 %a, %b
+  %a1 = add i32 %a0, %c
+  ret i32 %a1
+}
+
+; Repeat the test for vectorcall, which has XMM registers.
+
+define i32 @call_vector_thunk() {
+  %r = call x86_vectorcallcc i32 (...)* @vector_thunk(i32 inreg 1, i32 inreg 2, i32 3)
+  ret i32 %r
+}
+
+define x86_vectorcallcc i32 @vector_thunk(...) {
+  call void @puts(i8* getelementptr ([4 x i8]* @asdf, i32 0, i32 0))
+  %r = musttail call x86_vectorcallcc i32 (...)* bitcast (i32 (i32, i32, i32)* @vector_target to i32 (...)*) (...)
+  ret i32 %r
+}
+
+; Check that we spill and fill SSE registers around the call to puts.
+
+; CHECK-LABEL: vector_thunk@@0:
+; CHECK-DAG: movl %ecx, {{.*}}
+; CHECK-DAG: movl %edx, {{.*}}
+
+; SSE2-DAG: movups %xmm0, {{.*}}
+; SSE2-DAG: movups %xmm1, {{.*}}
+; SSE2-DAG: movups %xmm2, {{.*}}
+; SSE2-DAG: movups %xmm3, {{.*}}
+; SSE2-DAG: movups %xmm4, {{.*}}
+; SSE2-DAG: movups %xmm5, {{.*}}
+
+; AVX-DAG: vmovups %ymm0, {{.*}}
+; AVX-DAG: vmovups %ymm1, {{.*}}
+; AVX-DAG: vmovups %ymm2, {{.*}}
+; AVX-DAG: vmovups %ymm3, {{.*}}
+; AVX-DAG: vmovups %ymm4, {{.*}}
+; AVX-DAG: vmovups %ymm5, {{.*}}
+
+; AVX512-DAG: vmovups %zmm0, {{.*}}
+; AVX512-DAG: vmovups %zmm1, {{.*}}
+; AVX512-DAG: vmovups %zmm2, {{.*}}
+; AVX512-DAG: vmovups %zmm3, {{.*}}
+; AVX512-DAG: vmovups %zmm4, {{.*}}
+; AVX512-DAG: vmovups %zmm5, {{.*}}
+
+; CHECK: calll _puts
+
+; SSE2-DAG: movups {{.*}}, %xmm0
+; SSE2-DAG: movups {{.*}}, %xmm1
+; SSE2-DAG: movups {{.*}}, %xmm2
+; SSE2-DAG: movups {{.*}}, %xmm3
+; SSE2-DAG: movups {{.*}}, %xmm4
+; SSE2-DAG: movups {{.*}}, %xmm5
+
+; AVX-DAG: vmovups {{.*}}, %ymm0
+; AVX-DAG: vmovups {{.*}}, %ymm1
+; AVX-DAG: vmovups {{.*}}, %ymm2
+; AVX-DAG: vmovups {{.*}}, %ymm3
+; AVX-DAG: vmovups {{.*}}, %ymm4
+; AVX-DAG: vmovups {{.*}}, %ymm5
+
+; AVX512-DAG: vmovups {{.*}}, %zmm0
+; AVX512-DAG: vmovups {{.*}}, %zmm1
+; AVX512-DAG: vmovups {{.*}}, %zmm2
+; AVX512-DAG: vmovups {{.*}}, %zmm3
+; AVX512-DAG: vmovups {{.*}}, %zmm4
+; AVX512-DAG: vmovups {{.*}}, %zmm5
+
+; CHECK-DAG: movl {{.*}}, %ecx
+; CHECK-DAG: movl {{.*}}, %edx
+; CHECK: jmp vector_target@@12
+
+define x86_vectorcallcc i32 @vector_target(i32 inreg %a, i32 inreg %b, i32 %c) {
+  %a0 = add i32 %a, %b
+  %a1 = add i32 %a0, %c
+  ret i32 %a1
+}
Index: llvm/trunk/test/CodeGen/X86/musttail-varargs.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/musttail-varargs.ll
+++ llvm/trunk/test/CodeGen/X86/musttail-varargs.ll
@@ -5,9 +5,16 @@
 ; pack. Doing a normal call will clobber all argument registers, and we will
 ; spill around it. A simple adjustment should not require any XMM spills.
 
+declare void @llvm.va_start(i8*) nounwind
+
 declare void(i8*, ...)* @get_f(i8* %this)
 
 define void @f_thunk(i8* %this, ...) {
+  ; Use va_start so that we exercise the combination.
+  %ap = alloca [4 x i8*], align 16
+  %ap_i8 = bitcast [4 x i8*]* %ap to i8*
+  call void @llvm.va_start(i8* %ap_i8)
+
   %fptr = call void(i8*, ...)*(i8*)* @get_f(i8* %this)
   musttail call void (i8*, ...)* %fptr(i8* %this, ...)
   ret void