Index: llvm/trunk/include/llvm/IR/Argument.h
===================================================================
--- llvm/trunk/include/llvm/IR/Argument.h
+++ llvm/trunk/include/llvm/IR/Argument.h
@@ -90,6 +90,9 @@
   /// Return true if this argument has the sret attribute.
   bool hasStructRetAttr() const;
 
+  /// Return true if this argument has the inreg attribute.
+  bool hasInRegAttr() const;
+
   /// Return true if this argument has the returned attribute.
   bool hasReturnedAttr() const;
Index: llvm/trunk/lib/IR/Function.cpp
===================================================================
--- llvm/trunk/lib/IR/Function.cpp
+++ llvm/trunk/lib/IR/Function.cpp
@@ -145,6 +145,10 @@
   return hasAttribute(Attribute::StructRet);
 }
 
+bool Argument::hasInRegAttr() const {
+  return hasAttribute(Attribute::InReg);
+}
+
 bool Argument::hasReturnedAttr() const {
   return hasAttribute(Attribute::Returned);
 }
Index: llvm/trunk/lib/Target/AArch64/AArch64CallingConvention.td
===================================================================
--- llvm/trunk/lib/Target/AArch64/AArch64CallingConvention.td
+++ llvm/trunk/lib/Target/AArch64/AArch64CallingConvention.td
@@ -34,7 +34,23 @@
                          CCIfBigEndian>>,
 
-  // An SRet is passed in X8, not X0 like a normal pointer parameter.
+  // In AAPCS, an SRet is passed in X8, not X0 like a normal pointer parameter.
+  // However, on Windows, in some circumstances, the SRet is passed in X0 or X1
+  // instead. The presence of the inreg attribute indicates that the SRet is
+  // passed in the alternative register (X0 or X1), not X8:
+  // - X0 for non-instance methods.
+  // - X1 for instance methods.
+
+  // The "sret" attribute identifies indirect returns.
+  // The "inreg" attribute identifies non-aggregate types.
+  // The position of the "sret" attribute identifies instance/non-instance
+  // methods.
+  // "sret" on argument 0 means non-instance methods.
+  // "sret" on argument 1 means instance methods.
+
+  CCIfInReg>>>>,
+
   CCIfSRet>>,
 
   // Put ByVal arguments directly on the stack. Minimum size and alignment of a
Index: llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/trunk/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3208,6 +3208,26 @@
     }
   }
 
+  // On Windows, InReg pointers must be returned, so record the pointer in a
+  // virtual register at the start of the function so it can be returned in the
+  // epilogue.
+  if (IsWin64) {
+    for (unsigned I = 0, E = Ins.size(); I != E; ++I) {
+      if (Ins[I].Flags.isInReg()) {
+        assert(!FuncInfo->getSRetReturnReg());
+
+        MVT PtrTy = getPointerTy(DAG.getDataLayout());
+        unsigned Reg =
+            MF.getRegInfo().createVirtualRegister(getRegClassFor(PtrTy));
+        FuncInfo->setSRetReturnReg(Reg);
+
+        SDValue Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, Reg, InVals[I]);
+        Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Copy, Chain);
+        break;
+      }
+    }
+  }
+
   unsigned StackArgSize = CCInfo.getNextStackOffset();
   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
   if (DoesCalleeRestoreStack(CallConv, TailCallOpt)) {
@@ -3403,10 +3423,20 @@
   // X86) but less efficient and uglier in LowerCall.
   for (Function::const_arg_iterator i = CallerF.arg_begin(),
                                     e = CallerF.arg_end();
-       i != e; ++i)
+       i != e; ++i) {
     if (i->hasByValAttr())
       return false;
 
+    // On Windows, "inreg" attributes signify non-aggregate indirect returns.
+    // In this case, it is necessary to save/restore X0 in the callee. Tail
+    // call opt interferes with this. So we disable tail call opt when the
+    // caller has an argument with the "inreg" attribute.
+
+    // FIXME: Check whether the callee also has an "inreg" argument.
+    if (i->hasInRegAttr())
+      return false;
+  }
+
   if (getTargetMachine().Options.GuaranteedTailCallOpt)
     return canGuaranteeTCO(CalleeCC) && CCMatch;
@@ -3924,6 +3954,9 @@
                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
                                   const SmallVectorImpl<SDValue> &OutVals,
                                   const SDLoc &DL, SelectionDAG &DAG) const {
+  auto &MF = DAG.getMachineFunction();
+  auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+
   CCAssignFn *RetCC = CallConv == CallingConv::WebKit_JS
                           ? RetCC_AArch64_WebKit_JS
                           : RetCC_AArch64_AAPCS;
@@ -3962,6 +3995,23 @@
     Flag = Chain.getValue(1);
     RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT()));
   }
+
+  // The Windows AArch64 ABI requires that, when returning a struct by value,
+  // we copy the sret argument into X0 for the return.
+  // We saved the argument into a virtual register in the entry block,
+  // so now we copy the value out and into X0.
+  if (unsigned SRetReg = FuncInfo->getSRetReturnReg()) {
+    SDValue Val = DAG.getCopyFromReg(RetOps[0], DL, SRetReg,
+                                     getPointerTy(MF.getDataLayout()));
+
+    unsigned RetValReg = AArch64::X0;
+    Chain = DAG.getCopyToReg(Chain, DL, RetValReg, Val, Flag);
+    Flag = Chain.getValue(1);
+
+    RetOps.push_back(
+        DAG.getRegister(RetValReg, getPointerTy(DAG.getDataLayout())));
+  }
+
   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
   const MCPhysReg *I =
       TRI->getCalleeSavedRegsViaCopy(&DAG.getMachineFunction());
Index: llvm/trunk/lib/Target/AArch64/AArch64MachineFunctionInfo.h
===================================================================
--- llvm/trunk/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ llvm/trunk/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -91,6 +91,11 @@
   /// other stack allocations.
   bool CalleeSaveStackHasFreeSpace = false;
 
+  /// SRetReturnReg - sret lowering includes returning the address of the
+  /// returned struct in a register. This field holds the virtual register into
+  /// which the sret argument is passed.
+  unsigned SRetReturnReg = 0;
+
   /// Has a value when it is known whether or not the function uses a
   /// redzone, and no value otherwise.
   /// Initialized during frame lowering, unless the function has the noredzone
@@ -165,6 +170,9 @@
   unsigned getVarArgsFPRSize() const { return VarArgsFPRSize; }
   void setVarArgsFPRSize(unsigned Size) { VarArgsFPRSize = Size; }
 
+  unsigned getSRetReturnReg() const { return SRetReturnReg; }
+  void setSRetReturnReg(unsigned Reg) { SRetReturnReg = Reg; }
+
   unsigned getJumpTableEntrySize(int Idx) const {
     auto It = JumpTableEntryInfo.find(Idx);
     if (It != JumpTableEntryInfo.end())
Index: llvm/trunk/test/CodeGen/AArch64/arm64-windows-calls.ll
===================================================================
--- llvm/trunk/test/CodeGen/AArch64/arm64-windows-calls.ll
+++ llvm/trunk/test/CodeGen/AArch64/arm64-windows-calls.ll
@@ -0,0 +1,94 @@
+; FIXME: Add tests for global-isel/fast-isel.
+
+; RUN: llc < %s -mtriple=arm64-windows | FileCheck %s
+
+; Returns <= 8 bytes should be in X0.
+%struct.S1 = type { i32, i32 }
+define dso_local i64 @"?f1"() {
+entry:
+; CHECK-LABEL: f1
+; CHECK: str xzr, [sp, #8]
+; CHECK: mov x0, xzr
+
+  %retval = alloca %struct.S1, align 4
+  %a = getelementptr inbounds %struct.S1, %struct.S1* %retval, i32 0, i32 0
+  store i32 0, i32* %a, align 4
+  %b = getelementptr inbounds %struct.S1, %struct.S1* %retval, i32 0, i32 1
+  store i32 0, i32* %b, align 4
+  %0 = bitcast %struct.S1* %retval to i64*
+  %1 = load i64, i64* %0, align 4
+  ret i64 %1
+}
+
+; Returns <= 16 bytes should be in X0/X1.
+%struct.S2 = type { i32, i32, i32, i32 }
+define dso_local [2 x i64] @"?f2"() {
+entry:
+; CHECK-LABEL: f2
+; CHECK: stp xzr, xzr, [sp], #16
+; CHECK: mov x0, xzr
+; CHECK: mov x1, xzr
+
+  %retval = alloca %struct.S2, align 4
+  %a = getelementptr inbounds %struct.S2, %struct.S2* %retval, i32 0, i32 0
+  store i32 0, i32* %a, align 4
+  %b = getelementptr inbounds %struct.S2, %struct.S2* %retval, i32 0, i32 1
+  store i32 0, i32* %b, align 4
+  %c = getelementptr inbounds %struct.S2, %struct.S2* %retval, i32 0, i32 2
+  store i32 0, i32* %c, align 4
+  %d = getelementptr inbounds %struct.S2, %struct.S2* %retval, i32 0, i32 3
+  store i32 0, i32* %d, align 4
+  %0 = bitcast %struct.S2* %retval to [2 x i64]*
+  %1 = load [2 x i64], [2 x i64]* %0, align 4
+  ret [2 x i64] %1
+}
+
+; Returns > 16 bytes are returned indirectly; the sret pointer is passed in X8.
+%struct.S3 = type { i32, i32, i32, i32, i32 }
+define dso_local void @"?f3"(%struct.S3* noalias sret %agg.result) {
+entry:
+; CHECK-LABEL: f3
+; CHECK: stp xzr, xzr, [x8]
+; CHECK: str wzr, [x8, #16]
+
+  %a = getelementptr inbounds %struct.S3, %struct.S3* %agg.result, i32 0, i32 0
+  store i32 0, i32* %a, align 4
+  %b = getelementptr inbounds %struct.S3, %struct.S3* %agg.result, i32 0, i32 1
+  store i32 0, i32* %b, align 4
+  %c = getelementptr inbounds %struct.S3, %struct.S3* %agg.result, i32 0, i32 2
+  store i32 0, i32* %c, align 4
+  %d = getelementptr inbounds %struct.S3, %struct.S3* %agg.result, i32 0, i32 3
+  store i32 0, i32* %d, align 4
+  %e = getelementptr inbounds %struct.S3, %struct.S3* %agg.result, i32 0, i32 4
+  store i32 0, i32* %e, align 4
+  ret void
+}
+
+; InReg arguments to non-instance methods must be passed in X0 and returned
+; in X0.
+%class.B = type { i32 }
+define dso_local void @"?f4"(%class.B* inreg noalias nocapture sret %agg.result) {
+entry:
+; CHECK-LABEL: f4
+; CHECK: mov w8, #1
+; CHECK: str w8, [x0]
+  %X.i = getelementptr inbounds %class.B, %class.B* %agg.result, i64 0, i32 0
+  store i32 1, i32* %X.i, align 4
+  ret void
+}
+
+; InReg arguments to instance methods must be passed in X1 and returned in X0.
+%class.C = type { i8 }
+%class.A = type { i8 }
+
+define dso_local void @"?inst@C"(%class.C* %this, %class.A* inreg noalias sret %agg.result) {
+entry:
+; CHECK-LABEL: inst@C
+; CHECK: str x0, [sp, #8]
+; CHECK: mov x0, x1
+
+  %this.addr = alloca %class.C*, align 8
+  store %class.C* %this, %class.C** %this.addr, align 8
+  %this1 = load %class.C*, %class.C** %this.addr, align 8
+  ret void
+}
Index: llvm/trunk/test/CodeGen/AArch64/arm64-windows-tailcall.ll
===================================================================
--- llvm/trunk/test/CodeGen/AArch64/arm64-windows-tailcall.ll
+++ llvm/trunk/test/CodeGen/AArch64/arm64-windows-tailcall.ll
@@ -0,0 +1,18 @@
+; FIXME: Add tests for global-isel/fast-isel.
+
+; RUN: llc < %s -mtriple=arm64-windows | FileCheck %s
+
+%class.C = type { [1 x i32] }
+
+define dso_local void @"?bar"(%class.C* inreg noalias sret %agg.result) {
+entry:
+; CHECK-LABEL: bar
+; CHECK: mov x19, x0
+; CHECK: bl "?foo"
+; CHECK: mov x0, x19
+
+  tail call void @"?foo"(%class.C* dereferenceable(4) %agg.result)
+  ret void
+}
+
+declare dso_local void @"?foo"(%class.C* dereferenceable(4))
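
Note (illustrative, not part of the patch): the "inreg sret" IR shape exercised by the tests above comes from C++ types that the Microsoft ABI must return indirectly for non-size reasons, e.g. a class with a user-provided copy constructor; the patch's calling-convention comment calls these "non-aggregate" returns. The sketch below assumes the companion clang-side change that emits the "inreg" marking. The names S, Make, and Factory are invented for illustration, and the IR signatures in the comments show the expected shape rather than verbatim compiler output (something like "clang --target=aarch64-pc-windows-msvc -S -emit-llvm" should produce the equivalent).

// Illustrative only: S is returned indirectly under the MSVC C++ ABI even
// though it would fit in a register, because of the user-provided copy ctor.
struct S {
  S();
  S(const S &Other); // non-trivial copy constructor forces an sret return
  int Value;
};

// Free (non-instance) function: the hidden return slot is argument 0, so the
// expected IR is roughly
//   declare dso_local void @"?Make@@..."(%struct.S* inreg noalias sret)
// and the sret pointer arrives in X0 and is returned in X0.
S Make();

struct Factory {
  // Instance method: X0 already carries "this", so the hidden return slot is
  // argument 1, arrives in X1, and the callee still returns its address in X0.
  S Make();
};

This is why the new calling-convention entry keys off both "sret" (indirect return) and "inreg" (return slot travels in X0/X1 rather than X8), and why the argument position of "sret" distinguishes X0 (non-instance) from X1 (instance).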