Index: lib/IR/Instructions.cpp
===================================================================
--- lib/IR/Instructions.cpp
+++ lib/IR/Instructions.cpp
@@ -324,7 +324,7 @@
                 OperandTraits<CallInst>::op_end(this) - CI.getNumOperands(),
                 CI.getNumOperands()) {
   setAttributes(CI.getAttributes());
-  setTailCall(CI.isTailCall());
+  setTailCallKind(CI.getTailCallKind());
   setCallingConv(CI.getCallingConv());
 
   std::copy(CI.op_begin(), CI.op_end(), op_begin());
Index: lib/Transforms/Utils/InlineFunction.cpp
===================================================================
--- lib/Transforms/Utils/InlineFunction.cpp
+++ lib/Transforms/Utils/InlineFunction.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/CallSite.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfo.h"
@@ -29,6 +30,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/raw_ostream.h"
 using namespace llvm;
 
 bool llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI,
@@ -478,6 +480,17 @@
   }
 }
 
+static bool isPrecededByMustTailCall(const Instruction *Inst) {
+  // Look back past an optional bitcast of the musttail call's result. Both
+  // dyn_casts tolerate null in case 'Inst' is the first instruction.
+  const Instruction *Prev = Inst->getPrevNode();
+  if (const auto *BI = dyn_cast_or_null<BitCastInst>(Prev))
+    Prev = BI->getPrevNode();
+  if (const auto *CI = dyn_cast_or_null<CallInst>(Prev))
+    return CI->isMustTailCall();
+  return false;
+}
+
 /// InlineFunction - This function inlines the called function into the basic
 /// block of the caller.  This returns false if it is not possible to inline
 /// this call.  The program is still in a well defined state if this occurs
@@ -503,8 +516,9 @@
 
   // If the call to the callee is not a tail call, we must clear the 'tail'
   // flags on any calls that we inline.
-  bool MustClearTailCallFlags =
-    !(isa<CallInst>(TheCall) && cast<CallInst>(TheCall)->isTailCall());
+  CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None;
+  if (CallInst *CI = dyn_cast<CallInst>(TheCall))
+    CallSiteTailKind = CI->getTailCallKind();
 
   // If the call to the callee cannot throw, set the 'nounwind' flag on any
   // calls that we inline.
@@ -661,6 +676,38 @@
     }
   }
 
+  // We need to reduce the strength of any inlined tail calls. For musttail,
+  // we have to avoid introducing potential unbounded stack growth. For
+  // example, if functions 'f' and 'g' are mutually recursive with musttail,
+  // we can inline 'g' into 'f' so long as we preserve musttail on the cloned
+  // call to 'f'. If either the inlined call site or the cloned call site is
+  // *not* musttail, the program already has one frame of stack growth, so
+  // it's safe to remove musttail. Here is a table of example transformations:
+  //
+  //   f -> musttail g -> musttail h  ==>  f -> musttail g
+  //   f -> musttail g ->     tail h  ==>  f ->     tail g
+  //   f ->          g -> musttail h  ==>  f ->          g
+  //   f ->          g ->     tail h  ==>  f ->          g
+  //
+  // Also, calls inlined through a 'nounwind' call site should be marked
+  // 'nounwind'.
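+  //
+  // As an illustration of the third row (a sketch only; 'f', 'g', and 'h'
+  // are hypothetical, not functions from this patch), inlining the plain
+  // call to 'g' into 'f' demotes the cloned musttail call so that 'f' ends
+  // in an ordinary call:
+  //
+  //   define void @f() {                define void @f() {
+  //     call void @g()            ==>     call void @h()
+  //     ret void                          ret void
+  //   }                                 }
+  //   define internal void @g() {
+  //     musttail call void @h()
+  //     ret void
+  //   }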
+  bool InlinedMustTailCalls = false;
+  if (InlinedFunctionInfo.ContainsCalls) {
+    for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E;
+         ++BB) {
+      for (Instruction &I : *BB) {
+        if (CallInst *CI = dyn_cast<CallInst>(&I)) {
+          // Demote the cloned call to the less restrictive of its own kind
+          // and the kind of the call site being inlined.
+          CallInst::TailCallKind ChildTCK = CI->getTailCallKind();
+          ChildTCK = std::min(CallSiteTailKind, ChildTCK);
+          CI->setTailCallKind(ChildTCK);
+          InlinedMustTailCalls |= CI->isMustTailCall();
+          if (MarkNoUnwind)
+            CI->setDoesNotThrow();
+        }
+      }
+    }
+  }
+
   // Leave lifetime markers for the static alloca's, scoping them to the
   // function we just inlined.
   if (InsertLifetime && !IFI.StaticAllocas.empty()) {
@@ -693,9 +740,12 @@
       }
 
       builder.CreateLifetimeStart(AI, AllocaSize);
-      for (unsigned ri = 0, re = Returns.size(); ri != re; ++ri) {
-        IRBuilder<> builder(Returns[ri]);
-        builder.CreateLifetimeEnd(AI, AllocaSize);
+      for (ReturnInst *RI : Returns) {
+        // Don't insert llvm.lifetime.end calls after a musttail call; the
+        // allocas are trivially dead at that point.
+        if (InlinedMustTailCalls && isPrecededByMustTailCall(RI))
+          continue;
+        IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize);
       }
     }
   }
@@ -714,33 +764,60 @@
 
     // Insert a call to llvm.stackrestore before any return instructions in the
     // inlined function.
-    for (unsigned i = 0, e = Returns.size(); i != e; ++i) {
-      IRBuilder<>(Returns[i]).CreateCall(StackRestore, SavedPtr);
+    for (ReturnInst *RI : Returns) {
+      // Don't insert llvm.stackrestore calls after a musttail call; the
+      // frame is about to be destroyed by the return anyway.
+      if (InlinedMustTailCalls && isPrecededByMustTailCall(RI))
+        continue;
+      IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
     }
   }
 
-  // If we are inlining tail call instruction through a call site that isn't
-  // marked 'tail', we must remove the tail marker for any calls in the inlined
-  // code.  Also, calls inlined through a 'nounwind' call site should be marked
-  // 'nounwind'.
-  if (InlinedFunctionInfo.ContainsCalls &&
-      (MustClearTailCallFlags || MarkNoUnwind)) {
-    for (Function::iterator BB = FirstNewBlock, E = Caller->end();
-         BB != E; ++BB)
-      for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
-        if (CallInst *CI = dyn_cast<CallInst>(I)) {
-          if (MustClearTailCallFlags)
-            CI->setTailCall(false);
-          if (MarkNoUnwind)
-            CI->setDoesNotThrow();
-        }
-  }
-
   // If we are inlining for an invoke instruction, we must make sure to rewrite
   // any call instructions into invoke instructions.
   if (InvokeInst *II = dyn_cast<InvokeInst>(TheCall))
     HandleInlinedInvoke(II, FirstNewBlock, InlinedFunctionInfo);
 
+  // Handle any inlined musttail call sites. For a cloned call to remain
+  // musttail, both the call it was cloned from and the call site being
+  // inlined must have been musttail. Therefore it's safe to return directly
+  // without merging control into the phi below.
+  if (InlinedMustTailCalls) {
+    Type *NewRetTy = Caller->getReturnType();
+    // Partition the returns into normal returns and musttail returns.
+    auto B = Returns.begin(), E = Returns.end();
+    auto M = std::partition(B, E, [](const ReturnInst *Inst) {
+      return !isPrecededByMustTailCall(Inst);
+    });
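+    // After the partition, [B, M) holds the normal returns, which get merged
+    // into the caller below as usual, and [M, E) holds the returns preceded
+    // by a musttail call, which become early returns from the caller and are
+    // erased from 'Returns' at the end of this block.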
+
+    if (!TheCall->use_empty() && TheCall->getType() != NewRetTy) {
+      // Rewrite these returns to use the caller's return type. For a musttail
+      // call site, the only real user of the call is a return instruction
+      // with an optional bitcast; avoid leaving chains of bitcasts behind.
+      for (auto I = M; I != E; ++I) {
+        // Strip any inlined bitcast and grab the musttail call itself.
+        ReturnInst *RI = *I;
+        Value *RV = RI->getReturnValue();
+        BitCastInst *OldCast = dyn_cast<BitCastInst>(RV);
+        if (OldCast)
+          RV = OldCast->getOperand(0);
+        assert(cast<CallInst>(RV)->isMustTailCall());
+
+        // Insert a new bitcast and return with the right type.
+        IRBuilder<> Builder(RI);
+        Builder.CreateRet(Builder.CreateBitCast(RV, NewRetTy));
+
+        // Delete the old return and any bitcast.
+        RI->eraseFromParent();
+        if (OldCast)
+          OldCast->eraseFromParent();
+      }
+    }
+
+    // Remove the returns preceded by musttail calls, since they are early
+    // returns out of the caller.
+    Returns.erase(M, E);
+  }
+
   // If we cloned in _exactly one_ basic block, and if that block ends in a
   // return instruction, we splice the body of the inlined callee directly into
   // the calling basic block.
@@ -779,6 +856,7 @@
 
   // Otherwise, we have the normal case, of more than one block to inline or
   // multiple return sites.
+  Type *RTy = CalledFunc->getReturnType();
 
   // We want to clone the entire callee function into the hole between the
   // "starter" and "ender" blocks.  How we accomplish this depends on whether
@@ -821,7 +899,6 @@
 
   // Handle all of the return instructions that we just cloned in, and eliminate
   // any users of the original call/invoke instruction.
-  Type *RTy = CalledFunc->getReturnType();
   PHINode *PHI = 0;
   if (Returns.size() > 1) {
@@ -896,6 +973,11 @@
   // Since we are now done with the Call/Invoke, we can delete it.
   TheCall->eraseFromParent();
 
+  // If we inlined any musttail calls and the original return is now
+  // unreachable, delete it. It can only contain a bitcast and a ret.
+  if (InlinedMustTailCalls && pred_begin(AfterCallBB) == pred_end(AfterCallBB))
+    AfterCallBB->eraseFromParent();
+
   // We should always be able to fold the entry block of the function into the
   // single predecessor of the block...
   assert(cast<BranchInst>(Br)->isUnconditional() && "splitBasicBlock broken!");
Index: test/Transforms/Inline/inline-tail.ll
===================================================================
--- test/Transforms/Inline/inline-tail.ll
+++ test/Transforms/Inline/inline-tail.ll
@@ -1,15 +1,182 @@
-; RUN: opt < %s -inline -S | not grep tail
-
-declare void @bar(i32*)
-
-define internal void @foo(i32* %P) {
-  tail call void @bar( i32* %P )
-  ret void
-}
-
-define void @caller() {
-  %A = alloca i32  ; [#uses=1]
-  call void @foo( i32* %A )
-  ret void
-}
+; RUN: opt < %s -inline -S | FileCheck %s
+
+; We have to apply the less restrictive TailCallKind of the call site being
+; inlined and any call sites cloned into the caller.
+
+; No tail marker after inlining, since test_capture_c captures an alloca.
+; CHECK: define void @test_capture_a(
+; CHECK-NOT: tail
+; CHECK: call void @test_capture_c(
+
+declare void @test_capture_c(i32*)
+define internal void @test_capture_b(i32* %P) {
+  tail call void @test_capture_c(i32* %P)
+  ret void
+}
+define void @test_capture_a() {
+  %A = alloca i32                      ; captured by test_capture_b
+  call void @test_capture_b(i32* %A)
+  ret void
+}
+
+; No musttail marker after inlining, since the prototypes don't match.
+; CHECK: define void @test_proto_mismatch_a(
+; CHECK-NOT: musttail
+; CHECK: call void @test_proto_mismatch_c(
+
+declare void @test_proto_mismatch_c(i32*)
+define internal void @test_proto_mismatch_b(i32* %p) {
+  musttail call void @test_proto_mismatch_c(i32* %p)
+  ret void
+}
+define void @test_proto_mismatch_a() {
+  call void @test_proto_mismatch_b(i32* null)
+  ret void
+}
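+
+; (Background, not asserted by FileCheck: the IR verifier requires a musttail
+; call's prototype to match the enclosing function's, and
+; @test_proto_mismatch_a takes no arguments while @test_proto_mismatch_c takes
+; an i32*, so the cloned call cannot legally stay musttail here.)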
+
+; After inlining through a musttail call site, we need to keep musttail markers
+; to prevent unbounded stack growth.
+; CHECK: define void @test_musttail_basic_a(
+; CHECK: musttail call void @test_musttail_basic_c(
+
+declare void @test_musttail_basic_c(i32* %p)
+define internal void @test_musttail_basic_b(i32* %p) {
+  musttail call void @test_musttail_basic_c(i32* %p)
+  ret void
+}
+define void @test_musttail_basic_a(i32* %p) {
+  musttail call void @test_musttail_basic_b(i32* %p)
+  ret void
+}
+
+; Don't insert lifetime end markers here; the lifetime is trivially over due
+; to the return.
+; CHECK: define void @test_byval_a(
+; CHECK: musttail call void @test_byval_c(
+; CHECK-NEXT: ret void
+
+declare void @test_byval_c(i32* byval %p)
+define internal void @test_byval_b(i32* byval %p) {
+  musttail call void @test_byval_c(i32* byval %p)
+  ret void
+}
+define void @test_byval_a(i32* byval %p) {
+  musttail call void @test_byval_b(i32* byval %p)
+  ret void
+}
+
+; Don't insert a stack restore here; we're about to return.
+; CHECK: define void @test_dynalloca_a(
+; CHECK: call i8* @llvm.stacksave(
+; CHECK: alloca i8, i32 %n
+; CHECK: musttail call void @test_dynalloca_c(
+; CHECK-NEXT: ret void
+
+declare void @escape(i8* %buf)
+declare void @test_dynalloca_c(i32* byval %p, i32 %n)
+define internal void @test_dynalloca_b(i32* byval %p, i32 %n) alwaysinline {
+  %buf = alloca i8, i32 %n             ; dynamic alloca
+  call void @escape(i8* %buf)          ; escape it
+  musttail call void @test_dynalloca_c(i32* byval %p, i32 %n)
+  ret void
+}
+define void @test_dynalloca_a(i32* byval %p, i32 %n) {
+  musttail call void @test_dynalloca_b(i32* byval %p, i32 %n)
+  ret void
+}
+
+; We can't merge the returns.
+; CHECK: define void @test_multiret_a(
+; CHECK: musttail call void @test_multiret_c(
+; CHECK-NEXT: ret void
+; CHECK: musttail call void @test_multiret_d(
+; CHECK-NEXT: ret void
+
+declare void @test_multiret_c(i1 zeroext %b)
+declare void @test_multiret_d(i1 zeroext %b)
+define internal void @test_multiret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  musttail call void @test_multiret_c(i1 zeroext %b)
+  ret void
+d:
+  musttail call void @test_multiret_d(i1 zeroext %b)
+  ret void
+}
+define void @test_multiret_a(i1 zeroext %b) {
+  musttail call void @test_multiret_b(i1 zeroext %b)
+  ret void
+}
+
+; We have to avoid bitcast chains.
+; CHECK: define i32* @test_retptr_a(
+; CHECK: musttail call i8* @test_retptr_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_retptr_c()
+define internal i16* @test_retptr_b() {
+  %rv = musttail call i8* @test_retptr_c()
+  %v = bitcast i8* %rv to i16*
+  ret i16* %v
+}
+define i32* @test_retptr_a() {
+  %rv = musttail call i16* @test_retptr_b()
+  %v = bitcast i16* %rv to i32*
+  ret i32* %v
+}
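+
+; (Background, not asserted by FileCheck: musttail requires the pattern
+; "call, optional bitcast, ret", so the inliner folds the callee's i8*-to-i16*
+; cast and the caller's i16*-to-i32* cast into the single i8*-to-i32* cast
+; checked above, rather than leaving a chain of two bitcasts.)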
+
+; Combine the last two cases: multiple returns with pointer bitcasts.
+; CHECK: define i32* @test_multiptrret_a(
+; CHECK: musttail call i8* @test_multiptrret_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+; CHECK: musttail call i8* @test_multiptrret_d(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_multiptrret_c(i1 zeroext %b)
+declare i8* @test_multiptrret_d(i1 zeroext %b)
+define internal i16* @test_multiptrret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  %c_rv = musttail call i8* @test_multiptrret_c(i1 zeroext %b)
+  %c_v = bitcast i8* %c_rv to i16*
+  ret i16* %c_v
+d:
+  %d_rv = musttail call i8* @test_multiptrret_d(i1 zeroext %b)
+  %d_v = bitcast i8* %d_rv to i16*
+  ret i16* %d_v
+}
+define i32* @test_multiptrret_a(i1 zeroext %b) {
+  %rv = musttail call i16* @test_multiptrret_b(i1 zeroext %b)
+  %v = bitcast i16* %rv to i32*
+  ret i32* %v
+}
+
+; Inline a musttail call site whose callee contains both a normal return and a
+; musttail call.
+; CHECK: define i32 @test_mixedret_a(
+; CHECK: br i1 %b
+; CHECK: musttail call i32 @test_mixedret_c(
+; CHECK-NEXT: ret i32
+; CHECK: call i32 @test_mixedret_d(i1 zeroext %b)
+; CHECK: add i32 1,
+; CHECK-NOT: br
+; CHECK: ret i32
+
+declare i32 @test_mixedret_c(i1 zeroext %b)
+declare i32 @test_mixedret_d(i1 zeroext %b)
+define internal i32 @test_mixedret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  %c_rv = musttail call i32 @test_mixedret_c(i1 zeroext %b)
+  ret i32 %c_rv
+d:
+  %d_rv = call i32 @test_mixedret_d(i1 zeroext %b)
+  %d_rv1 = add i32 1, %d_rv
+  ret i32 %d_rv1
+}
+define i32 @test_mixedret_a(i1 zeroext %b) {
+  %rv = musttail call i32 @test_mixedret_b(i1 zeroext %b)
+  ret i32 %rv
+}