Index: lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp =================================================================== --- lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp +++ lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp @@ -53,9 +53,9 @@ SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const override; - void expandFunctionCall(const MCInst &MI, raw_ostream &OS, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const; + unsigned expandFunctionCall(const MCInst &MI, raw_ostream &OS, + SmallVectorImpl<MCFixup> &Fixups, + const MCSubtargetInfo &STI) const; /// TableGen'erated function for getting the binary encoding for an /// instruction. @@ -92,14 +92,34 @@ // We let linker relaxation deal with it. When linker relaxation enabled, // AUIPC and JALR have chance relax to JAL. If C extension is enabled, // JAL has chance relax to C_JAL. -void RISCVMCCodeEmitter::expandFunctionCall(const MCInst &MI, raw_ostream &OS, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const { +unsigned +RISCVMCCodeEmitter::expandFunctionCall(const MCInst &MI, raw_ostream &OS, + SmallVectorImpl<MCFixup> &Fixups, + const MCSubtargetInfo &STI) const { MCInst TmpInst; MCOperand Func = MI.getOperand(0); - unsigned Ra = (MI.getOpcode() == RISCV::PseudoTAIL) ? RISCV::X6 : RISCV::X1; uint32_t Binary; + // For an indirect tail call, emit JALR X0, 0(Reg). + // TODO: We had to custom lower this here since for tail calls we cannot + // allow JALR with any GPR; we need to restrict the register class to only + // caller-saved regs. Hence we had to create a new regclass GPRTC which is a + // subclass of GPR. However, currently tablegen does not support the use of + // the registers of a subclass of the allowed regclass for an instruction. + // Once this is supported, we can replace the custom lowering here with a + // PseudoInstExpansion rule in RISCVInstrInfo.td. 
+ if (MI.getOpcode() == RISCV::PseudoTAILIndirect) { + TmpInst = MCInstBuilder(RISCV::JALR) + .addReg(RISCV::X0) + .addReg(Func.getReg()) + .addImm(0); + Binary = getBinaryCodeForInstr(TmpInst, Fixups, STI); + support::endian::Writer<support::little>(OS).write(Binary); + return 1 /* MCNumEmitted */; + } + + unsigned Ra = (MI.getOpcode() == RISCV::PseudoTAIL) ? RISCV::X6 : RISCV::X1; + assert(Func.isExpr() && "Expected expression"); const MCExpr *Expr = Func.getExpr(); @@ -119,6 +139,7 @@ TmpInst = MCInstBuilder(RISCV::JALR).addReg(Ra).addReg(Ra).addImm(0); Binary = getBinaryCodeForInstr(TmpInst, Fixups, STI); support::endian::Writer<support::little>(OS).write(Binary); + return 2 /* MCNumEmitted */; } void RISCVMCCodeEmitter::encodeInstruction(const MCInst &MI, raw_ostream &OS, @@ -129,9 +150,9 @@ unsigned Size = Desc.getSize(); if (MI.getOpcode() == RISCV::PseudoCALL || - MI.getOpcode() == RISCV::PseudoTAIL) { - expandFunctionCall(MI, OS, Fixups, STI); - MCNumEmitted += 2; + MI.getOpcode() == RISCV::PseudoTAIL || + MI.getOpcode() == RISCV::PseudoTAILIndirect) { + MCNumEmitted += expandFunctionCall(MI, OS, Fixups, STI); return; } Index: lib/Target/RISCV/RISCVISelLowering.h =================================================================== --- lib/Target/RISCV/RISCVISelLowering.h +++ lib/Target/RISCV/RISCVISelLowering.h @@ -28,7 +28,8 @@ CALL, SELECT_CC, BuildPairF64, - SplitF64 + SplitF64, + TAIL }; } @@ -100,6 +101,10 @@ SDValue lowerVASTART(SDValue Op, SelectionDAG &DAG) const; SDValue LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const; + + bool IsEligibleForTailCallOptimization(CCState &CCInfo, + CallLoweringInfo &CLI, MachineFunction &MF, + const SmallVector<CCValAssign, 16> &ArgLocs) const; }; } Index: lib/Target/RISCV/RISCVISelLowering.cpp =================================================================== --- lib/Target/RISCV/RISCVISelLowering.cpp +++ lib/Target/RISCV/RISCVISelLowering.cpp @@ -18,6 +18,7 @@ #include "RISCVRegisterInfo.h" #include 
"RISCVSubtarget.h" #include "RISCVTargetMachine.h" +#include "llvm/ADT/Statistic.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" @@ -36,6 +37,8 @@ #define DEBUG_TYPE "riscv-lower" +STATISTIC(NumTailCalls, "Number of tail calls"); + RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, const RISCVSubtarget &STI) : TargetLowering(TM), Subtarget(STI) { @@ -1071,6 +1074,88 @@ return Chain; } +/// IsEligibleForTailCallOptimization - Check whether the call is eligible +/// for tail call optimization. +/// Note: This is modelled after ARM's IsEligibleForTailCallOptimization. +bool RISCVTargetLowering::IsEligibleForTailCallOptimization( + CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF, + const SmallVector<CCValAssign, 16> &ArgLocs) const { + + auto &Callee = CLI.Callee; + auto CalleeCC = CLI.CallConv; + auto IsVarArg = CLI.IsVarArg; + auto &Outs = CLI.Outs; + auto &Caller = MF.getFunction(); + auto CallerCC = Caller.getCallingConv(); + + // Do not tail call opt functions with "disable-tail-calls" attribute. + if (Caller.getFnAttribute("disable-tail-calls").getValueAsString() == "true") + return false; + + // Exception-handling functions need a special set of instructions to + // indicate a return to the hardware. Tail-calling another function would + // probably break this. + // TODO: The "interrupt" attribute isn't currently defined by RISC-V. This + // should be expanded as new function attributes are introduced. + if (Caller.hasFnAttribute("interrupt")) + return false; + + // Do not tail call opt functions with varargs. + if (IsVarArg) + return false; + + // Do not tail call opt if the stack is used to pass parameters. + if (CCInfo.getNextStackOffset() != 0) + return false; + + // Do not tail call opt if any parameters need to be passed indirectly. + // Since long doubles (fp128) and i128 are larger than 2*XLEN, they are + // passed indirectly. 
So the address of the value will be passed in a + // register, or if not available, then the address is put on the stack. In + // order to pass indirectly, space on the stack often needs to be allocated + // in order to store the value. In this case the CCInfo.getNextStackOffset() + // != 0 check is not enough and we need to check if any CCValAssign ArgsLocs + // are passed CCValAssign::Indirect. + for (auto &VA : ArgLocs) + if (VA.getLocInfo() == CCValAssign::Indirect) + return false; + + // Do not tail call opt if either caller or callee uses struct return + // semantics. + auto IsCallerStructRet = Caller.hasStructRetAttr(); + auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet(); + if (IsCallerStructRet || IsCalleeStructRet) + return false; + + // Externally-defined functions with weak linkage should not be + // tail-called. The behaviour of branch instructions in this situation (as + // used for tail calls) is implementation-defined, so we cannot rely on the + // linker replacing the tail call with a return. + if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) { + const GlobalValue *GV = G->getGlobal(); + if (GV->hasExternalWeakLinkage()) + return false; + } + + // The callee has to preserve all registers the caller needs to preserve. + const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo(); + const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC); + if (CalleeCC != CallerCC) { + const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC); + if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved)) + return false; + } + + // Byval parameters hand the function a pointer directly into the stack area + // we want to reuse during a tail call. Working around this *is* possible + // but less efficient and uglier in LowerCall. 
+ for (auto &Arg : Outs) + if (Arg.Flags.isByVal()) + return false; + + return true; +} + // Lower a call to a callseq_start + CALL + callseq_end chain, and add input // and output parameter nodes. SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI, @@ -1082,7 +1167,7 @@ SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; SDValue Chain = CLI.Chain; SDValue Callee = CLI.Callee; - CLI.IsTailCall = false; + bool &IsTailCall = CLI.IsTailCall; CallingConv::ID CallConv = CLI.CallConv; bool IsVarArg = CLI.IsVarArg; EVT PtrVT = getPointerTy(DAG.getDataLayout()); @@ -1095,6 +1180,17 @@ CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext()); analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI); + // Check if it's really possible to do a tail call. + if (IsTailCall) + IsTailCall = IsEligibleForTailCallOptimization(ArgCCInfo, CLI, MF, + ArgLocs); + + if (IsTailCall) + ++NumTailCalls; + else if (CLI.CS && CLI.CS.isMustTailCall()) + report_fatal_error("failed to perform tail call elimination on a call " + "site marked musttail"); + // Get a count of how many bytes are to be pushed on the stack. unsigned NumBytes = ArgCCInfo.getNextStackOffset(); @@ -1116,12 +1212,13 @@ Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Align, /*IsVolatile=*/false, /*AlwaysInline=*/false, - /*isTailCall=*/false, MachinePointerInfo(), + IsTailCall, MachinePointerInfo(), MachinePointerInfo()); ByValArgs.push_back(FIPtr); } - Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL); + if (!IsTailCall) + Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL); // Copy argument values to their designated locations. SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass; @@ -1208,6 +1305,8 @@ RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue)); } else { assert(VA.isMemLoc() && "Argument not register or memory"); + assert(!IsTailCall && "Tail call not allowed if stack is used " + "for passing parameters"); // Work out the address of the stack slot. 
if (!StackPtr.getNode()) @@ -1253,11 +1352,13 @@ for (auto &Reg : RegsToPass) Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType())); - // Add a register mask operand representing the call-preserved registers. - const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); - const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv); - assert(Mask && "Missing call preserved mask for calling convention"); - Ops.push_back(DAG.getRegisterMask(Mask)); + if (!IsTailCall) { + // Add a register mask operand representing the call-preserved registers. + const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); + const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv); + assert(Mask && "Missing call preserved mask for calling convention"); + Ops.push_back(DAG.getRegisterMask(Mask)); + } // Glue the call to the argument copies, if any. if (Glue.getNode()) @@ -1265,6 +1366,12 @@ // Emit the call. SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue); + + if (IsTailCall) { + MF.getFrameInfo().setHasTailCall(); + return DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops); + } + Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops); Glue = Chain.getValue(1); @@ -1420,6 +1527,8 @@ return "RISCVISD::BuildPairF64"; case RISCVISD::SplitF64: return "RISCVISD::SplitF64"; + case RISCVISD::TAIL: + return "RISCVISD::TAIL"; } return nullptr; } Index: lib/Target/RISCV/RISCVInstrInfo.cpp =================================================================== --- lib/Target/RISCV/RISCVInstrInfo.cpp +++ lib/Target/RISCV/RISCVInstrInfo.cpp @@ -436,7 +436,10 @@ case TargetOpcode::KILL: case TargetOpcode::DBG_VALUE: return 0; + case RISCV::PseudoTAILIndirect: + return 4; case RISCV::PseudoCALL: + case RISCV::PseudoTAIL: return 8; case TargetOpcode::INLINEASM: { const MachineFunction &MF = *MI.getParent()->getParent(); Index: lib/Target/RISCV/RISCVInstrInfo.td =================================================================== --- lib/Target/RISCV/RISCVInstrInfo.td +++ 
lib/Target/RISCV/RISCVInstrInfo.td @@ -38,6 +38,9 @@ [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>; def SelectCC : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC, [SDNPInGlue]>; +def Tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall, + [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, + SDNPVariadic]>; //===----------------------------------------------------------------------===// // Operand and SDNode transformation definitions. @@ -661,11 +664,19 @@ // expand to auipc and jalr while encoding. // Define AsmString to print "tail" when compile with -S flag. let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [X2], - hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCodeGenOnly = 0 in + isCodeGenOnly = 0 in def PseudoTAIL : Pseudo<(outs), (ins bare_symbol:$dst), []> { let AsmString = "tail\t$dst"; } +let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [X2] in +def PseudoTAILIndirect : Pseudo<(outs), (ins GPRTC:$dst), []>; + +def : Pat<(Tail (iPTR tglobaladdr:$dst)), + (PseudoTAIL texternalsym:$dst)>; +def : Pat<(Tail GPRTC:$dst), + (PseudoTAILIndirect GPRTC:$dst)>; + /// Loads multiclass LdPat<PatFrag LoadOp, RVInst Inst> { Index: lib/Target/RISCV/RISCVRegisterInfo.td =================================================================== --- lib/Target/RISCV/RISCVRegisterInfo.td +++ lib/Target/RISCV/RISCVRegisterInfo.td @@ -128,6 +128,18 @@ [RegInfo<32,32,32>, RegInfo<64,64,64>, RegInfo<32,32,32>]>; } +// For tail calls, we can't use callee-saved registers, as they are restored to +// the saved value before the tail call, which would clobber a call address. 
+def GPRTC : RegisterClass<"RISCV", [XLenVT], 32, (add + (sequence "X%u", 5, 7), + (sequence "X%u", 10, 17), + (sequence "X%u", 28, 31) + )> { + let RegInfos = RegInfoByHwMode< + [RV32, RV64, DefaultMode], + [RegInfo<32,32,32>, RegInfo<64,64,64>, RegInfo<32,32,32>]>; +} + def SP : RegisterClass<"RISCV", [XLenVT], 32, (add X2)> { let RegInfos = RegInfoByHwMode< [RV32, RV64, DefaultMode], Index: test/CodeGen/RISCV/disable-tail-calls.ll =================================================================== --- /dev/null +++ test/CodeGen/RISCV/disable-tail-calls.ll @@ -0,0 +1,56 @@ +; Check that command line option "-disable-tail-calls" overrides function +; attribute "disable-tail-calls". + +; RUN: llc < %s -mtriple=riscv32-unknown-elf \ +; RUN: | FileCheck %s --check-prefixes=CALLER1,NOTAIL +; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \ +; RUN: | FileCheck %s --check-prefixes=CALLER1,NOTAIL +; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \ +; RUN: | FileCheck %s --check-prefixes=CALLER1,TAIL + +; RUN: llc < %s -mtriple=riscv32-unknown-elf \ +; RUN: | FileCheck %s --check-prefixes=CALLER2,TAIL +; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \ +; RUN: | FileCheck %s --check-prefixes=CALLER2,NOTAIL +; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \ +; RUN: | FileCheck %s --check-prefixes=CALLER2,TAIL + +; RUN: llc < %s -mtriple=riscv32-unknown-elf \ +; RUN: | FileCheck %s --check-prefixes=CALLER3,TAIL +; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls \ +; RUN: | FileCheck %s --check-prefixes=CALLER3,NOTAIL +; RUN: llc < %s -mtriple=riscv32-unknown-elf -disable-tail-calls=false \ +; RUN: | FileCheck %s --check-prefixes=CALLER3,TAIL + +; CALLER1-LABEL: {{\_?}}caller1 +; CALLER2-LABEL: {{\_?}}caller2 +; CALLER3-LABEL: {{\_?}}caller3 +; NOTAIL-NOT: tail callee +; NOTAIL: call callee +; TAIL: tail callee +; TAIL-NOT: call callee + +; Function with attribute #0 = { 
"disable-tail-calls"="true" } +define i32 @caller1(i32 %a) #0 { +entry: + %call = tail call i32 @callee(i32 %a) + ret i32 %call +} + +; Function with attribute #1 = { "disable-tail-calls"="false" } +define i32 @caller2(i32 %a) #1 { +entry: + %call = tail call i32 @callee(i32 %a) + ret i32 %call +} + +define i32 @caller3(i32 %a) { +entry: + %call = tail call i32 @callee(i32 %a) + ret i32 %call +} + +declare i32 @callee(i32) + +attributes #0 = { "disable-tail-calls"="true" } +attributes #1 = { "disable-tail-calls"="false" } Index: test/CodeGen/RISCV/musttail-call.ll =================================================================== --- /dev/null +++ test/CodeGen/RISCV/musttail-call.ll @@ -0,0 +1,20 @@ +; Check that we error out if tail is not possible but call is marked as musttail. + +; RUN: not llc -mtriple riscv32-unknown-linux-gnu -o - %s \ +; RUN: 2>&1 | FileCheck %s +; RUN: not llc -mtriple riscv32-unknown-elf -o - %s \ +; RUN: 2>&1 | FileCheck %s +; RUN: not llc -mtriple riscv64-unknown-linux-gnu -o - %s \ +; RUN: 2>&1 | FileCheck %s +; RUN: not llc -mtriple riscv64-unknown-elf -o - %s \ +; RUN: 2>&1 | FileCheck %s + +%struct.A = type { i32 } + +declare void @callee_musttail(%struct.A* sret %a) +define void @caller_musttail(%struct.A* sret %a) { +; CHECK: LLVM ERROR: failed to perform tail call elimination on a call site marked musttail +entry: + musttail call void @callee_musttail(%struct.A* sret %a) + ret void +} Index: test/CodeGen/RISCV/tail-calls.ll =================================================================== --- /dev/null +++ test/CodeGen/RISCV/tail-calls.ll @@ -0,0 +1,112 @@ +; RUN: llc -mtriple riscv32-unknown-linux-gnu -o - %s | FileCheck %s +; RUN: llc -mtriple riscv32-unknown-elf -o - %s | FileCheck %s + +; Perform tail call optimization. 
+declare i32 @callee(i32 %i) +define i32 @caller(i32 %i) { +; CHECK-LABEL: caller +; CHECK: tail callee +entry: + %r = tail call i32 @callee(i32 %i) + ret i32 %r +} + +; Do not tail call optimize functions with varargs. +declare i32 @callee_varargs(i32, ...) +define void @caller_varargs(i32 %a, i32 %b) { +; CHECK-LABEL: caller_varargs +; CHECK-NOT: tail callee_varargs +; CHECK: call callee_varargs +entry: + %call = tail call i32 (i32, ...) @callee_varargs(i32 %a, i32 %b, i32 %b, i32 %a) + ret void +} + +; Do not tail call optimize if stack is used to pass parameters. +declare i32 @callee_args(i32 %a, i32 %b, i32 %c, i32 %dd, i32 %e, i32 %ff, i32 %g, i32 %h, i32 %i, i32 %j, i32 %k, i32 %l, i32 %m, i32 %n) +define i32 @caller_args(i32 %a, i32 %b, i32 %c, i32 %dd, i32 %e, i32 %ff, i32 %g, i32 %h, i32 %i, i32 %j, i32 %k, i32 %l, i32 %m, i32 %n) { +; CHECK-LABEL: caller_args +; CHECK-NOT: tail callee_args +; CHECK: call callee_args +entry: + %r = tail call i32 @callee_args(i32 %a, i32 %b, i32 %c, i32 %dd, i32 %e, i32 %ff, i32 %g, i32 %h, i32 %i, i32 %j, i32 %k, i32 %l, i32 %m, i32 %n) + ret i32 %r +} + +; Do not tail call optimize if parameters need to be passed indirectly. +declare i32 @callee_indirect(fp128 %a) +define void @caller_indirect() { +; CHECK-LABEL: caller_indirect +; CHECK-NOT: tail callee_indirect +; CHECK: call callee_indirect +entry: + %call = tail call i32 @callee_indirect(fp128 0xL00000000000000003FFF000000000000) + ret void +} + +; Externally-defined functions with weak linkage should not be tail-called. +; The behaviour of branch instructions in this situation (as used for tail +; calls) is implementation-defined, so we cannot rely on the linker replacing +; the tail call with a return. 
+declare extern_weak void @callee_weak() +define void @caller_weak() { +; CHECK-LABEL: caller_weak +; CHECK-NOT: tail callee_weak +; CHECK: call callee_weak +entry: + tail call void @callee_weak() + ret void +} + +; Exception-handling functions need a special set of instructions to indicate a +; return to the hardware. Tail-calling another function would probably break +; this. +declare void @callee_irq() +define void @caller_irq() #0 { +; CHECK-LABEL: caller_irq +; CHECK-NOT: tail callee_irq +; CHECK: call callee_irq +entry: + tail call void @callee_irq() + ret void +} +attributes #0 = { "interrupt" } + +; Byval parameters hand the function a pointer directly into the stack area +; we want to reuse during a tail call. Do not tail call optimize functions with +; byval parameters. +declare i32 @callee_byval(i32** byval %a) +define i32 @caller_byval() { +; CHECK-LABEL: caller_byval +; CHECK-NOT: tail callee_byval +; CHECK: call callee_byval +entry: + %a = alloca i32* + %r = tail call i32 @callee_byval(i32** byval %a) + ret i32 %r +} + +; Do not tail call optimize if callee uses structret semantics. +%struct.A = type { i32 } +@a = global %struct.A zeroinitializer + +declare void @callee_struct(%struct.A* sret %a) +define void @caller_nostruct() { +; CHECK-LABEL: caller_nostruct +; CHECK-NOT: tail callee_struct +; CHECK: call callee_struct +entry: + tail call void @callee_struct(%struct.A* sret @a) + ret void +} + +; Do not tail call optimize if caller uses structret semantics. +declare void @callee_nostruct() +define void @caller_struct(%struct.A* sret %a) { +; CHECK-LABEL: caller_struct +; CHECK-NOT: tail callee_nostruct +; CHECK: call callee_nostruct +entry: + tail call void @callee_nostruct() + ret void +}