diff --git a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
--- a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
+++ b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
@@ -136,6 +136,9 @@
   bool removeDeadStuffFromFunction(Function *F);
   bool deleteDeadVarargs(Function &F);
   bool removeDeadArgumentsFromCallers(Function &F);
+  /// Mark as live every function containing a musttail call that uses an
+  /// already-live function, repeating until no new functions are marked.
+  void propagateVirtMustcallLiveness(const Module &M);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
--- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -85,6 +85,12 @@
   virtual bool shouldHackArguments() const { return false; }
 };
 
+/// Returns true if \p CB is a direct call to a function whose body is defined
+/// in this module, i.e. neither an indirect call nor a call to a declaration.
+bool isCalleeAnalyzable(const CallBase &CB) {
+  return CB.getCalledFunction() && !CB.getCalledFunction()->isDeclaration();
+}
+
 } // end anonymous namespace
 
 char DAE::ID = 0;
@@ -520,8 +526,16 @@
   for (const BasicBlock &BB : F) {
     // If we have any returns of `musttail` results - the signature can't
     // change
-    if (BB.getTerminatingMustTailCall() != nullptr)
+    if (const auto *TC = BB.getTerminatingMustTailCall()) {
       HasMustTailCalls = true;
+      // In addition, if the called function is not locally defined (or unknown,
+      // if this is an indirect call), we can't change the callsite and thus
+      // can't change this function's signature either.
+      if (!isCalleeAnalyzable(*TC)) {
+        markLive(F);
+        return;
+      }
+    }
   }
 
   if (HasMustTailCalls) {
@@ -1081,6 +1095,27 @@
   return true;
 }
 
+void DeadArgumentEliminationPass::propagateVirtMustcallLiveness(
+    const Module &M) {
+  // If a function was marked "live", and it has musttail callers, they in turn
+  // can't change either. Iterate as a worklist so that whole chains of
+  // musttail callers are marked live, not only the immediate ones.
+  LiveFuncSet NewLiveFuncs(LiveFunctions);
+  while (!NewLiveFuncs.empty()) {
+    LiveFuncSet Temp;
+    for (const auto *F : NewLiveFuncs)
+      for (const auto *U : F->users())
+        if (const auto *CB = dyn_cast<CallBase>(U))
+          if (CB->isMustTailCall())
+            if (!LiveFunctions.count(CB->getFunction()))
+              Temp.insert(CB->getFunction());
+    NewLiveFuncs.clear();
+    NewLiveFuncs.insert(Temp.begin(), Temp.end());
+    for (const auto *F : Temp)
+      markLive(*F);
+  }
+}
+
 PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
                                                    ModuleAnalysisManager &) {
   bool Changed = false;
@@ -1101,6 +1136,8 @@
   for (auto &F : M)
     surveyFunction(F);
 
+  propagateVirtMustcallLiveness(M);
+
   // Now, remove all dead arguments and return values from each function in
   // turn. We use make_early_inc_range here because functions will probably get
   // removed (i.e. replaced by new ones).
diff --git a/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll b/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll
@@ -0,0 +1,56 @@
+; RUN: opt -passes=deadargelim,verify -S < %s 2>&1 | FileCheck %s
+define internal i32 @test_caller(ptr %fptr, i32 %a, i32 %b) {
+  %r = musttail call i32 @test(ptr %fptr, i32 %a, i32 %b)
+  ret i32 %r
+}
+
+define internal i32 @test(ptr %fptr, i32 %a, i32 %b) {
+  %r = musttail call i32 %fptr(ptr %fptr, i32 %a, i32 0)
+  ret i32 %r
+}
+
+define internal i32 @direct_test() {
+  %r = musttail call i32 @foo()
+  ret i32 %r
+}
+
+declare i32 @foo()
+
+define internal i32 @ping(i32 %x) {
+  %r = musttail call i32 @pong(i32 %x)
+  ret i32 %r
+}
+
+define internal i32 @pong(i32 %x) {
+  %cond = icmp eq i32 %x, 2
+  br i1 %cond, label %yes, label %no
+
+yes:
+  %r1 = musttail call i32 @ping(i32 %x)
+  ret i32 %r1
+no:
+  %r2 = musttail call i32 @bar(i32 %x)
+  ret i32 %r2
+}
+
+declare i32 @bar(i32 %x)
+
+; This is the main check - that we produce valid IR.
+; CHECK-NOT: cannot guarantee tail call due to mismatched parameter counts
+
+; CHECK: define internal i32 @test_caller
+; CHECK-NEXT: %r = musttail call i32 @test
+
+; CHECK: define internal i32 @test
+; CHECK-NEXT: %r = musttail call i32 %fptr
+
+; A musttail call to a declaration pins the caller's signature (@direct_test,
+; @pong), and that liveness must propagate to musttail callers (@ping).
+; CHECK: define internal i32 @direct_test()
+; CHECK-NEXT: %r = musttail call i32 @foo()
+
+; CHECK: define internal i32 @ping(i32 %x)
+; CHECK-NEXT: %r = musttail call i32 @pong(i32 %x)
+
+; CHECK: define internal i32 @pong(i32 %x)
+; CHECK-NEXT: %cond = icmp eq i32 %x, 2