Index: llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp
+++ llvm/trunk/lib/Transforms/Utils/InlineFunction.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/CallSite.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfo.h"
@@ -478,6 +479,33 @@
   }
 }
 
+/// Returns a musttail call instruction if one immediately precedes the given
+/// return instruction with an optional bitcast instruction between them.
+static CallInst *getPrecedingMustTailCall(ReturnInst *RI) {
+  Instruction *Prev = RI->getPrevNode();
+  if (!Prev)
+    return nullptr;
+
+  if (Value *RV = RI->getReturnValue()) {
+    if (RV != Prev)
+      return nullptr;
+
+    // Look through the optional bitcast.
+    if (auto *BI = dyn_cast<BitCastInst>(Prev)) {
+      RV = BI->getOperand(0);
+      Prev = BI->getPrevNode();
+      if (!Prev || RV != Prev)
+        return nullptr;
+    }
+  }
+
+  if (auto *CI = dyn_cast<CallInst>(Prev)) {
+    if (CI->isMustTailCall())
+      return CI;
+  }
+  return nullptr;
+}
+
 /// InlineFunction - This function inlines the called function into the basic
 /// block of the caller.  This returns false if it is not possible to inline
 /// this call.  The program is still in a well defined state if this occurs
@@ -503,8 +531,10 @@
 
   // If the call to the callee is not a tail call, we must clear the 'tail'
   // flags on any calls that we inline.
-  bool MustClearTailCallFlags =
-    !(isa<CallInst>(TheCall) && cast<CallInst>(TheCall)->isTailCall());
+  CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None;
+  if (CallInst *CI = dyn_cast<CallInst>(TheCall))
+    CallSiteTailKind = CI->getTailCallKind();
+  bool MustClearTailCallFlags = false;
 
   // If the call to the callee cannot throw, set the 'nounwind' flag on any
   // calls that we inline.
@@ -661,6 +691,41 @@
     }
   }
 
+  bool InlinedMustTailCalls = false;
+  if (InlinedFunctionInfo.ContainsCalls) {
+    for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E;
+         ++BB) {
+      for (Instruction &I : *BB) {
+        CallInst *CI = dyn_cast<CallInst>(&I);
+        if (!CI)
+          continue;
+
+        // We need to reduce the strength of any inlined tail calls.  For
+        // musttail, we have to avoid introducing potential unbounded stack
+        // growth.  For example, if functions 'f' and 'g' are mutually recursive
+        // with musttail, we can inline 'g' into 'f' so long as we preserve
+        // musttail on the cloned call to 'f'.  If either the inlined call site
+        // or the cloned call site is *not* musttail, the program already has
+        // one frame of stack growth, so it's safe to remove musttail.  Here is
+        // a table of example transformations:
+        //
+        // f -> musttail g -> musttail f  ==>  f -> musttail f
+        // f -> musttail g ->     tail f  ==>  f ->     tail f
+        // f ->          g -> musttail f  ==>  f ->          f
+        // f ->          g ->     tail f  ==>  f ->          f
+        CallInst::TailCallKind ChildTCK = CI->getTailCallKind();
+        ChildTCK = std::min(CallSiteTailKind, ChildTCK);
+        CI->setTailCallKind(ChildTCK);
+        InlinedMustTailCalls |= CI->isMustTailCall();
+
+        // Calls inlined through a 'nounwind' call site should be marked
+        // 'nounwind'.
+        if (MarkNoUnwind)
+          CI->setDoesNotThrow();
+      }
+    }
+  }
+
   // Leave lifetime markers for the static alloca's, scoping them to the
   // function we just inlined.
   if (InsertLifetime && !IFI.StaticAllocas.empty()) {
@@ -693,10 +758,8 @@
       }
 
       builder.CreateLifetimeStart(AI, AllocaSize);
-      for (unsigned ri = 0, re = Returns.size(); ri != re; ++ri) {
-        IRBuilder<> builder(Returns[ri]);
-        builder.CreateLifetimeEnd(AI, AllocaSize);
-      }
+      for (ReturnInst *RI : Returns)
+        IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize);
     }
   }
 
@@ -714,26 +777,8 @@
 
     // Insert a call to llvm.stackrestore before any return instructions in the
    // inlined function.
-    for (unsigned i = 0, e = Returns.size(); i != e; ++i) {
-      IRBuilder<>(Returns[i]).CreateCall(StackRestore, SavedPtr);
-    }
-  }
-
-  // If we are inlining tail call instruction through a call site that isn't
-  // marked 'tail', we must remove the tail marker for any calls in the inlined
-  // code.  Also, calls inlined through a 'nounwind' call site should be marked
-  // 'nounwind'.
-  if (InlinedFunctionInfo.ContainsCalls &&
-      (MustClearTailCallFlags || MarkNoUnwind)) {
-    for (Function::iterator BB = FirstNewBlock, E = Caller->end();
-         BB != E; ++BB)
-      for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
-        if (CallInst *CI = dyn_cast<CallInst>(I)) {
-          if (MustClearTailCallFlags)
-            CI->setTailCall(false);
-          if (MarkNoUnwind)
-            CI->setDoesNotThrow();
-        }
+    for (ReturnInst *RI : Returns)
+      IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
   }
 
   // If we are inlining for an invoke instruction, we must make sure to rewrite
@@ -741,6 +786,42 @@
   if (InvokeInst *II = dyn_cast<InvokeInst>(TheCall))
     HandleInlinedInvoke(II, FirstNewBlock, InlinedFunctionInfo);
 
+  // Handle any inlined musttail call sites.  In order for a new call site to be
+  // musttail, the source of the clone and the inlined call site must have been
+  // musttail.  Therefore it's safe to return without merging control into the
+  // phi below.
+  if (InlinedMustTailCalls) {
+    // Check if we need to bitcast the result of any musttail calls.
+    Type *NewRetTy = Caller->getReturnType();
+    bool NeedBitCast = !TheCall->use_empty() && TheCall->getType() != NewRetTy;
+
+    // Handle the returns preceded by musttail calls separately.
+    SmallVector<ReturnInst *, 8> NormalReturns;
+    for (ReturnInst *RI : Returns) {
+      CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI);
+      if (!ReturnedMustTail) {
+        NormalReturns.push_back(RI);
+        continue;
+      }
+      if (!NeedBitCast)
+        continue;
+
+      // Delete the old return and any preceding bitcast.
+      BasicBlock *CurBB = RI->getParent();
+      auto *OldCast = dyn_cast_or_null<BitCastInst>(RI->getReturnValue());
+      RI->eraseFromParent();
+      if (OldCast)
+        OldCast->eraseFromParent();
+
+      // Insert a new bitcast and return with the right type.
+      IRBuilder<> Builder(CurBB);
+      Builder.CreateRet(Builder.CreateBitCast(ReturnedMustTail, NewRetTy));
+    }
+
+    // Leave behind the normal returns so we can merge control flow.
+    std::swap(Returns, NormalReturns);
+  }
+
   // If we cloned in _exactly one_ basic block, and if that block ends in a
   // return instruction, we splice the body of the inlined callee directly into
   // the calling basic block.
@@ -896,6 +977,11 @@
   // Since we are now done with the Call/Invoke, we can delete it.
   TheCall->eraseFromParent();
 
+  // If we inlined any musttail calls and the original return is now
+  // unreachable, delete it.  It can only contain a bitcast and ret.
+  if (InlinedMustTailCalls && pred_begin(AfterCallBB) == pred_end(AfterCallBB))
+    AfterCallBB->eraseFromParent();
+
   // We should always be able to fold the entry block of the function into the
   // single predecessor of the block...
   assert(cast<BranchInst>(Br)->isUnconditional() && "splitBasicBlock broken!");
Index: llvm/trunk/test/Transforms/Inline/inline-tail.ll
===================================================================
--- llvm/trunk/test/Transforms/Inline/inline-tail.ll
+++ llvm/trunk/test/Transforms/Inline/inline-tail.ll
@@ -1,15 +1,146 @@
-; RUN: opt < %s -inline -S | not grep tail
+; RUN: opt < %s -inline -S | FileCheck %s
 
-declare void @bar(i32*)
+; We have to apply the less restrictive TailCallKind of the call site being
+; inlined and any call sites cloned into the caller.
 
-define internal void @foo(i32* %P) {
-  tail call void @bar( i32* %P )
-  ret void
+; No tail marker after inlining, since test_capture_c captures an alloca.
+; CHECK: define void @test_capture_a(
+; CHECK-NOT: tail
+; CHECK: call void @test_capture_c(
+
+declare void @test_capture_c(i32*)
+define internal void @test_capture_b(i32* %P) {
+  tail call void @test_capture_c(i32* %P)
+  ret void
+}
+define void @test_capture_a() {
+  %A = alloca i32  ; captured by test_capture_b
+  call void @test_capture_b(i32* %A)
+  ret void
+}
+
+; No musttail marker after inlining, since the prototypes don't match.
+; CHECK: define void @test_proto_mismatch_a(
+; CHECK-NOT: musttail
+; CHECK: call void @test_proto_mismatch_c(
+
+declare void @test_proto_mismatch_c(i32*)
+define internal void @test_proto_mismatch_b(i32* %p) {
+  musttail call void @test_proto_mismatch_c(i32* %p)
+  ret void
+}
+define void @test_proto_mismatch_a() {
+  call void @test_proto_mismatch_b(i32* null)
+  ret void
 }
 
-define void @caller() {
-  %A = alloca i32  ; <i32*> [#uses=1]
-  call void @foo( i32* %A )
-  ret void
+; After inlining through a musttail call site, we need to keep musttail markers
+; to prevent unbounded stack growth.
+; CHECK: define void @test_musttail_basic_a(
+; CHECK: musttail call void @test_musttail_basic_c(
+
+declare void @test_musttail_basic_c(i32* %p)
+define internal void @test_musttail_basic_b(i32* %p) {
+  musttail call void @test_musttail_basic_c(i32* %p)
+  ret void
+}
+define void @test_musttail_basic_a(i32* %p) {
+  musttail call void @test_musttail_basic_b(i32* %p)
+  ret void
+}
+
+; We can't merge the returns.
+; CHECK: define void @test_multiret_a(
+; CHECK: musttail call void @test_multiret_c(
+; CHECK-NEXT: ret void
+; CHECK: musttail call void @test_multiret_d(
+; CHECK-NEXT: ret void
+
+declare void @test_multiret_c(i1 zeroext %b)
+declare void @test_multiret_d(i1 zeroext %b)
+define internal void @test_multiret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  musttail call void @test_multiret_c(i1 zeroext %b)
+  ret void
+d:
+  musttail call void @test_multiret_d(i1 zeroext %b)
+  ret void
+}
+define void @test_multiret_a(i1 zeroext %b) {
+  musttail call void @test_multiret_b(i1 zeroext %b)
+  ret void
 }
 
+; We have to avoid bitcast chains.
+; CHECK: define i32* @test_retptr_a(
+; CHECK: musttail call i8* @test_retptr_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_retptr_c()
+define internal i16* @test_retptr_b() {
+  %rv = musttail call i8* @test_retptr_c()
+  %v = bitcast i8* %rv to i16*
+  ret i16* %v
+}
+define i32* @test_retptr_a() {
+  %rv = musttail call i16* @test_retptr_b()
+  %v = bitcast i16* %rv to i32*
+  ret i32* %v
+}
+
+; Combine the last two cases: multiple returns with pointer bitcasts.
+; CHECK: define i32* @test_multiptrret_a(
+; CHECK: musttail call i8* @test_multiptrret_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+; CHECK: musttail call i8* @test_multiptrret_d(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_multiptrret_c(i1 zeroext %b)
+declare i8* @test_multiptrret_d(i1 zeroext %b)
+define internal i16* @test_multiptrret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  %c_rv = musttail call i8* @test_multiptrret_c(i1 zeroext %b)
+  %c_v = bitcast i8* %c_rv to i16*
+  ret i16* %c_v
+d:
+  %d_rv = musttail call i8* @test_multiptrret_d(i1 zeroext %b)
+  %d_v = bitcast i8* %d_rv to i16*
+  ret i16* %d_v
+}
+define i32* @test_multiptrret_a(i1 zeroext %b) {
+  %rv = musttail call i16* @test_multiptrret_b(i1 zeroext %b)
+  %v = bitcast i16* %rv to i32*
+  ret i32* %v
+}
+
+; Inline a musttail call site which contains a normal return and a musttail call.
+; CHECK: define i32 @test_mixedret_a(
+; CHECK: br i1 %b
+; CHECK: musttail call i32 @test_mixedret_c(
+; CHECK-NEXT: ret i32
+; CHECK: call i32 @test_mixedret_d(i1 zeroext %b)
+; CHECK: add i32 1,
+; CHECK-NOT: br
+; CHECK: ret i32
+
+declare i32 @test_mixedret_c(i1 zeroext %b)
+declare i32 @test_mixedret_d(i1 zeroext %b)
+define internal i32 @test_mixedret_b(i1 zeroext %b) {
+  br i1 %b, label %c, label %d
+c:
+  %c_rv = musttail call i32 @test_mixedret_c(i1 zeroext %b)
+  ret i32 %c_rv
+d:
+  %d_rv = call i32 @test_mixedret_d(i1 zeroext %b)
+  %d_rv1 = add i32 1, %d_rv
+  ret i32 %d_rv1
+}
+define i32 @test_mixedret_a(i1 zeroext %b) {
+  %rv = musttail call i32 @test_mixedret_b(i1 zeroext %b)
+  ret i32 %rv
+}
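
Note (not part of the patch): a minimal sketch of the IR the inliner is expected to produce for the basic musttail case above, assuming the CHECK lines of @test_musttail_basic_a hold. Because both the inlined call site and the cloned call site are musttail, the cloned call keeps its musttail marker, matching the first row of the transformation table in the C++ comment.

; Illustrative post-inlining output only; the exact text printed by
; 'opt -inline -S' may differ in details such as attributes or value names.
define void @test_musttail_basic_a(i32* %p) {
  musttail call void @test_musttail_basic_c(i32* %p)
  ret void
}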