Index: lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- lib/Target/RISCV/RISCVISelLowering.h
+++ lib/Target/RISCV/RISCVISelLowering.h
@@ -28,7 +28,8 @@
     CALL,
     SELECT_CC,
     BuildPairF64,
-    SplitF64
+    SplitF64,
+    TAIL
   };
 }
@@ -100,6 +101,9 @@
   SDValue lowerVASTART(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
+
+  bool IsEligibleForTailCallOptimization(CCState &CCInfo,
+      CallLoweringInfo &CLI, MachineFunction &MF) const;
 };
 }
Index: lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- lib/Target/RISCV/RISCVISelLowering.cpp
+++ lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18,6 +18,7 @@
 #include "RISCVRegisterInfo.h"
 #include "RISCVSubtarget.h"
 #include "RISCVTargetMachine.h"
+#include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -36,6 +37,8 @@

 #define DEBUG_TYPE "riscv-lower"

+STATISTIC(NumTailCalls, "Number of tail calls");
+
 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                                          const RISCVSubtarget &STI)
     : TargetLowering(TM), Subtarget(STI) {
@@ -1071,6 +1074,75 @@
   return Chain;
 }

+/// IsEligibleForTailCallOptimization - Check whether the call is eligible
+/// for tail call optimization.
+/// Note: This is modelled after ARM's IsEligibleForTailCallOptimization.
+bool RISCVTargetLowering::IsEligibleForTailCallOptimization(
+    CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF) const {
+
+  auto &Callee = CLI.Callee;
+  auto CalleeCC = CLI.CallConv;
+  auto IsVarArg = CLI.IsVarArg;
+  auto &Outs = CLI.Outs;
+  auto &Caller = MF.getFunction();
+  auto CallerCC = Caller.getCallingConv();
+
+  // Do not tail call opt functions with the "disable-tail-calls" attribute.
+  if (Caller.getFnAttribute("disable-tail-calls").getValueAsString() == "true")
+    return false;
+
+  // Do not tail call opt functions with varargs.
+  if (IsVarArg)
+    return false;
+
+  // Do not tail call opt if the stack is used to pass parameters.
+  if (CCInfo.getNextStackOffset() != 0)
+    return false;
+
+  // Exception-handling functions need a special set of instructions to
+  // indicate a return to the hardware. Tail-calling another function would
+  // probably break this.
+  // TODO: The "interrupt" attribute isn't currently defined by RISC-V. This
+  // should be expanded as new function attributes are introduced.
+  if (Caller.hasFnAttribute("interrupt"))
+    return false;
+
+  // Do not tail call opt if either the caller or callee uses struct return
+  // semantics.
+  auto IsCallerStructRet = Caller.hasStructRetAttr();
+  auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
+  if (IsCallerStructRet || IsCalleeStructRet)
+    return false;
+
+  // Externally-defined functions with weak linkage should not be
+  // tail-called. The behaviour of branch instructions in this situation (as
+  // used for tail calls) is implementation-defined, so we cannot rely on the
+  // linker replacing the tail call with a return.
+  if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
+    const GlobalValue *GV = G->getGlobal();
+    if (GV->hasExternalWeakLinkage())
+      return false;
+  }
+
+  // The callee has to preserve all registers the caller needs to preserve.
+  const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
+  const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
+  if (CalleeCC != CallerCC) {
+    const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
+    if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
+      return false;
+  }
+
+  // Byval parameters hand the function a pointer directly into the stack area
+  // we want to reuse during a tail call. Working around this *is* possible
+  // but less efficient and uglier in LowerCall.
+  for (auto &Arg : Outs)
+    if (Arg.Flags.isByVal())
+      return false;
+
+  return true;
+}
+
 // Lower a call to a callseq_start + CALL + callseq_end chain, and add input
 // and output parameter nodes.
 SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
@@ -1082,7 +1154,7 @@
   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   SDValue Chain = CLI.Chain;
   SDValue Callee = CLI.Callee;
-  CLI.IsTailCall = false;
+  bool &IsTailCall = CLI.IsTailCall;
   CallingConv::ID CallConv = CLI.CallConv;
   bool IsVarArg = CLI.IsVarArg;
   EVT PtrVT = getPointerTy(DAG.getDataLayout());
@@ -1095,6 +1167,16 @@
   CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
   analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI);

+  // Check if it's really possible to do a tail call.
+  if (IsTailCall)
+    IsTailCall = IsEligibleForTailCallOptimization(ArgCCInfo, CLI, MF);
+
+  if (IsTailCall)
+    ++NumTailCalls;
+  else if (CLI.CS && CLI.CS.isMustTailCall())
+    report_fatal_error("failed to perform tail call elimination on a call "
+                       "site marked musttail");
+
   // Get a count of how many bytes are to be pushed on the stack.
   unsigned NumBytes = ArgCCInfo.getNextStackOffset();

@@ -1116,17 +1198,20 @@
     Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Align,
                           /*IsVolatile=*/false, /*AlwaysInline=*/false,
-                          /*isTailCall=*/false, MachinePointerInfo(),
+                          IsTailCall, MachinePointerInfo(),
                           MachinePointerInfo());
     ByValArgs.push_back(FIPtr);
   }

-  Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
+  if (!IsTailCall)
+    Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
+
+  SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2,
+                                        getPointerTy(DAG.getDataLayout()));

   // Copy argument values to their designated locations.
   SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
   SmallVector<SDValue, 8> MemOpChains;
-  SDValue StackPtr;
   for (unsigned i = 0, j = 0, e = ArgLocs.size(); i != e; ++i) {
     CCValAssign &VA = ArgLocs[i];
     SDValue ArgValue = OutVals[i];
@@ -1212,13 +1297,23 @@
       // Work out the address of the stack slot.
       if (!StackPtr.getNode())
         StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT);
-      SDValue Address =
-          DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
-                      DAG.getIntPtrConstant(VA.getLocMemOffset(), DL));
-
-      // Emit the store.
-      MemOpChains.push_back(
-          DAG.getStore(Chain, DL, ArgValue, Address, MachinePointerInfo()));
+      if (!IsTailCall) {
+        SDValue Address =
+            DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
+                        DAG.getIntPtrConstant(VA.getLocMemOffset(), DL));
+
+        // Emit the store.
+        MemOpChains.push_back(
+            DAG.getStore(Chain, DL, ArgValue, Address, MachinePointerInfo()));
+      } else {
+        MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
+        int FI = MFI.CreateFixedObject(ArgValue.getValueSizeInBits() / 8,
+                                       VA.getLocMemOffset(), false);
+        SDValue FIN = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
+        MemOpChains.push_back(
+            DAG.getStore(Chain, DL, ArgValue, FIN, MachinePointerInfo(),
+                         /* Alignment = */ 0, MachineMemOperand::MOVolatile));
+      }
     }
   }
@@ -1228,12 +1323,27 @@
   SDValue Glue;

+  // Tail call byval lowering might overwrite argument registers, so in the
+  // case of tail call optimization the copies to registers are lowered to
+  // 'real' stack slots.
+  // Force all the incoming stack arguments to be loaded from the stack
+  // before any new outgoing arguments are stored to the stack, because the
+  // outgoing stack slots may alias the incoming argument stack slots, and
+  // the alias isn't otherwise explicit. This is slightly more conservative
+  // than necessary, because it means that each store effectively depends
+  // on every argument instead of just those arguments it would clobber.
+  if (IsTailCall)
+    Glue = SDValue();
+
   // Build a sequence of copy-to-reg nodes, chained and glued together.
   for (auto &Reg : RegsToPass) {
     Chain = DAG.getCopyToReg(Chain, DL, Reg.first, Reg.second, Glue);
     Glue = Chain.getValue(1);
   }

+  if (IsTailCall)
+    Glue = SDValue();
+
   // If the callee is a GlobalAddress/ExternalSymbol node, turn it into a
   // TargetGlobalAddress/TargetExternalSymbol node so that legalize won't
   // split it and then direct call can be matched by PseudoCALL.
@@ -1253,11 +1363,13 @@
   for (auto &Reg : RegsToPass)
     Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType()));

-  // Add a register mask operand representing the call-preserved registers.
-  const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
-  const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
-  assert(Mask && "Missing call preserved mask for calling convention");
-  Ops.push_back(DAG.getRegisterMask(Mask));
+  if (!IsTailCall) {
+    // Add a register mask operand representing the call-preserved registers.
+    const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
+    const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
+    assert(Mask && "Missing call preserved mask for calling convention");
+    Ops.push_back(DAG.getRegisterMask(Mask));
+  }

   // Glue the call to the argument copies, if any.
   if (Glue.getNode())
@@ -1265,6 +1377,12 @@

   // Emit the call.
   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
+
+  if (IsTailCall) {
+    MF.getFrameInfo().setHasTailCall();
+    return DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
+  }
+
   Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
   Glue = Chain.getValue(1);
@@ -1420,6 +1538,8 @@
     return "RISCVISD::BuildPairF64";
   case RISCVISD::SplitF64:
     return "RISCVISD::SplitF64";
+  case RISCVISD::TAIL:
+    return "RISCVISD::TAIL";
   }
   return nullptr;
 }
Index: lib/Target/RISCV/RISCVInstrInfo.cpp
===================================================================
--- lib/Target/RISCV/RISCVInstrInfo.cpp
+++ lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -436,7 +436,10 @@
   case TargetOpcode::KILL:
   case TargetOpcode::DBG_VALUE:
     return 0;
+  case RISCV::PseudoTAILIndirect:
+    return 4;
   case RISCV::PseudoCALL:
+  case RISCV::PseudoTAIL:
     return 8;
   case TargetOpcode::INLINEASM: {
     const MachineFunction &MF = *MI.getParent()->getParent();
Index: lib/Target/RISCV/RISCVInstrInfo.td
===================================================================
--- lib/Target/RISCV/RISCVInstrInfo.td
+++ lib/Target/RISCV/RISCVInstrInfo.td
@@ -38,6 +38,9 @@
                       [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
 def SelectCC : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC,
                       [SDNPInGlue]>;
+def Tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
+                  [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+                   SDNPVariadic]>;

 //===----------------------------------------------------------------------===//
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
@@ -657,6 +660,25 @@
 def PseudoRET : Pseudo<(outs), (ins), [(RetFlag)]>,
                 PseudoInstExpansion<(JALR X0, X1, 0)>;

+// PseudoTAIL is a pseudo instruction similar to PseudoCALL; it is eventually
+// expanded to an auipc and jalr pair during encoding.
+// Define AsmString so the instruction is printed as "tail" when compiling
+// with -S.
+let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [X2],
+    isCodeGenOnly = 0 in
+def PseudoTAIL : Pseudo<(outs), (ins bare_symbol:$dst), []> {
+  let AsmString = "tail\t$dst";
+}
+
+let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [X2] in
+def PseudoTAILIndirect : Pseudo<(outs), (ins GPRTC:$dst), []>;
+
+def : Pat<(Tail (iPTR tglobaladdr:$dst)),
+          (PseudoTAIL texternalsym:$dst)>;
+def : Pat<(Tail (iPTR texternalsym:$dst)),
+          (PseudoTAIL texternalsym:$dst)>;
+def : Pat<(Tail GPRTC:$dst),
+          (PseudoTAILIndirect GPRTC:$dst)>;
+
 /// Loads

 multiclass LdPat<PatFrag LoadOp, RVInst Inst> {
Index: lib/Target/RISCV/RISCVRegisterInfo.td
===================================================================
--- lib/Target/RISCV/RISCVRegisterInfo.td
+++ lib/Target/RISCV/RISCVRegisterInfo.td
@@ -128,6 +128,16 @@
       [RegInfo<32,32,32>, RegInfo<64,64,64>, RegInfo<32,32,32>]>;
 }

+// For tail calls, we can't use callee-saved registers, as they are restored to
+// the saved value before the tail call, which would clobber a call address.
+def GPRTC : RegisterClass<"RISCV", [XLenVT], 32, (add
+    (sequence "X%u", 5, 7)
+  )> {
+  let RegInfos = RegInfoByHwMode<
+      [RV32, RV64, DefaultMode],
+      [RegInfo<32,32,32>, RegInfo<64,64,64>, RegInfo<32,32,32>]>;
+}
+
 def SP : RegisterClass<"RISCV", [XLenVT], 32, (add X2)> {
   let RegInfos = RegInfoByHwMode<
       [RV32, RV64, DefaultMode],
Index: test/CodeGen/RISCV/disable-tail-calls.ll
===================================================================
--- /dev/null
+++ test/CodeGen/RISCV/disable-tail-calls.ll
@@ -0,0 +1,40 @@
+; Check that the command line option "-disable-tail-calls" overrides the
+; function attribute "disable-tail-calls".
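+;
+; Note: caller1 below carries attribute #0 ("disable-tail-calls"="true"), so it
+; should only be tail-call optimized when -disable-tail-calls=false is given
+; explicitly; caller2 has no such attribute and follows the command-line
+; setting, which permits tail calls by default.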
+
+; RUN: llc < %s -mtriple=riscv32-unknown-elf \
+; RUN:   | FileCheck %s --check-prefixes=CALLER1,NOTAIL
+; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \
+; RUN:   | FileCheck %s --check-prefixes=CALLER1,NOTAIL
+; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \
+; RUN:   | FileCheck %s --check-prefixes=CALLER1,TAIL
+
+; RUN: llc < %s -mtriple=riscv32-unknown-elf \
+; RUN:   | FileCheck %s --check-prefixes=CALLER2,TAIL
+; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \
+; RUN:   | FileCheck %s --check-prefixes=CALLER2,NOTAIL
+; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \
+; RUN:   | FileCheck %s --check-prefixes=CALLER2,TAIL
+
+; CALLER1-LABEL: {{\_?}}caller1
+; CALLER2-LABEL: {{\_?}}caller2
+; NOTAIL-NOT: tail callee
+; NOTAIL: call callee
+; TAIL: tail callee
+; TAIL-NOT: call callee
+
+; Function with attribute #0 = { "disable-tail-calls"="true" }
+define i32 @caller1(i32 %a) #0 {
+entry:
+  %call = tail call i32 @callee(i32 %a)
+  ret i32 %call
+}
+
+define i32 @caller2(i32 %a) {
+entry:
+  %call = tail call i32 @callee(i32 %a)
+  ret i32 %call
+}
+
+declare i32 @callee(i32)
+
+attributes #0 = { "disable-tail-calls"="true" }
Index: test/CodeGen/RISCV/tail-calls.ll
===================================================================
--- /dev/null
+++ test/CodeGen/RISCV/tail-calls.ll
@@ -0,0 +1,79 @@
+; RUN: llc -mtriple riscv32-unknown-linux-gnu -o - %s | FileCheck %s
+; RUN: llc -mtriple riscv32-unknown-elf -o - %s | FileCheck %s
+; RUN: llc -mtriple riscv64-unknown-linux-gnu -o - %s | FileCheck %s
+; RUN: llc -mtriple riscv64-unknown-elf -o - %s | FileCheck %s
+
+; Perform tail call optimization.
+declare i32 @callee(i32 %i)
+define i32 @caller(i32 %i) {
+; CHECK-LABEL: caller
+; CHECK: tail callee
+entry:
+  %r = tail call i32 @callee(i32 %i)
+  ret i32 %r
+}
+
+; Do not do tail call optimization for functions with varargs.
+declare i32 @callee_varargs(i32, ...)
+define void @caller_varargs(i32 %a, i32 %b) {
+; CHECK-LABEL: caller_varargs:
+; CHECK-NOT: tail callee_varargs
+; CHECK: call callee_varargs
+entry:
+  %call = tail call i32 (i32, ...) @callee_varargs(i32 %a, i32 %b, i32 %b, i32 %a)
+  ret void
+}
+
+; RISC-V passes the first eight integer arguments in registers x10-x17 (a0-a7).
+; If a function requires more argument registers than that, the remaining
+; arguments are passed on the stack, so we cannot tail call optimize it.
+declare i32 @callee_args(i32 %a, i32 %b, i32 %c, i32 %dd, i32 %e, i32 %ff, i32 %g, i32 %h, i32 %i)
+define i32 @caller_args(i32 %a, i32 %b, i32 %c, i32 %dd, i32 %e, i32 %ff, i32 %g, i32 %h, i32 %i) {
+; CHECK-LABEL: caller_args
+; CHECK-NOT: tail callee_args
+; CHECK: call callee_args
+entry:
+  %r = tail call i32 @callee_args(i32 %a, i32 %b, i32 %c, i32 %dd, i32 %e, i32 %ff, i32 %g, i32 %h, i32 %i)
+  ret i32 %r
+}
+
+; Externally-defined functions with weak linkage should not be tail-called.
+; The behaviour of branch instructions in this situation (as used for tail
+; calls) is implementation-defined, so we cannot rely on the linker replacing
+; the tail call with a return.
+declare extern_weak void @callee_weak()
+define void @caller_weak() {
+; CHECK-LABEL: caller_weak:
+; CHECK-NOT: tail callee_weak
+; CHECK: call callee_weak
+entry:
+  tail call void @callee_weak()
+  ret void
+}
+
+; Exception-handling functions need a special set of instructions to indicate a
+; return to the hardware. Tail-calling another function would probably break
+; this.
+declare void @callee_irq()
+define void @caller_irq() #0 {
+; CHECK-LABEL: caller_irq:
+; CHECK-NOT: tail callee_irq
+; CHECK: call callee_irq
+entry:
+  tail call void @callee_irq()
+  ret void
+}
+attributes #0 = { "interrupt" }
+
+; Byval parameters hand the function a pointer directly into the stack area
+; we want to reuse during a tail call. Do not tail call optimize functions with
+; byval parameters.
+declare i32 @callee_byval(i32** byval %a)
+define i32 @caller_byval() {
+; CHECK-LABEL: caller_byval:
+; CHECK-NOT: tail callee_byval
+; CHECK: call callee_byval
+entry:
+  %a = alloca i32*
+  %r = tail call i32 @callee_byval(i32** byval %a)
+  ret i32 %r
+}
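
For reference, a minimal sketch of the end-to-end behaviour this patch enables (illustrative only; the function names below are hypothetical and not part of the patch or its tests). A caller that passes the eligibility checks above, such as:

    ; Illustrative example, not part of the patch.
    declare i32 @example_callee(i32)

    define i32 @example_caller(i32 %x) {
    entry:
      %r = tail call i32 @example_callee(i32 %x)
      ret i32 %r
    }

is selected to PseudoTAIL and printed as "tail example_callee". The RISC-V assembly programmer's manual defines the "tail" pseudo as expanding to an auipc/jalr pair through x6 (t1), which is consistent with the 8-byte size reported for PseudoTAIL in the RISCVInstrInfo.cpp hunk above and with restricting indirect tail-call targets to the caller-saved GPRTC registers.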