diff --git a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
--- a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
+++ b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
@@ -136,6 +136,7 @@
   bool removeDeadStuffFromFunction(Function *F);
   bool deleteDeadVarargs(Function &F);
   bool removeDeadArgumentsFromCallers(Function &F);
+  void propagateVirtMustcallLiveness(const Module &M);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
--- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -85,6 +85,15 @@
   virtual bool shouldHackArguments() const { return false; }
 };
 
+/// Returns true if the musttail call \p CB directly calls a function that is
+/// defined in this module (i.e. it is not an indirect call and the callee is
+/// not a mere declaration).
+bool isMustTailCalleeAnalyzable(const CallBase &CB) {
+  assert(CB.isMustTailCall() && "expected a musttail call");
+  const Function *Callee = CB.getCalledFunction();
+  return Callee && !Callee->isDeclaration();
+}
+
 } // end anonymous namespace
 
 char DAE::ID = 0;
@@ -520,8 +529,16 @@
   for (const BasicBlock &BB : F) {
     // If we have any returns of `musttail` results - the signature can't
     // change
-    if (BB.getTerminatingMustTailCall() != nullptr)
+    if (const auto *TC = BB.getTerminatingMustTailCall()) {
       HasMustTailCalls = true;
+      // In addition, if the called function is not locally defined (or unknown,
+      // if this is an indirect call), we can't change the callsite and thus
+      // can't change this function's signature either.
+      if (!isMustTailCalleeAnalyzable(*TC)) {
+        markLive(F);
+        return;
+      }
+    }
   }
 
   if (HasMustTailCalls) {
@@ -1081,6 +1098,28 @@
   return true;
 }
 
+void DeadArgumentEliminationPass::propagateVirtMustcallLiveness(
+    const Module &M) {
+  // If a function was marked "live", and it has musttail callers, they in turn
+  // can't change either. Only uses as the callee of a musttail call matter: a
+  // function merely passed as a call argument does not constrain the caller.
+  // Iterate to a fixed point so liveness propagates along chains (and cycles)
+  // of musttail calls.
+  LiveFuncSet NewLiveFuncs(LiveFunctions);
+  while (!NewLiveFuncs.empty()) {
+    LiveFuncSet Temp;
+    for (const auto *F : NewLiveFuncs)
+      for (const auto *U : F->users())
+        if (const auto *CB = dyn_cast<CallBase>(U))
+          if (CB->isMustTailCall() && CB->getCalledOperand() == F)
+            if (!LiveFunctions.count(CB->getFunction()))
+              Temp.insert(CB->getFunction());
+    for (const auto *F : Temp)
+      markLive(*F);
+    NewLiveFuncs = std::move(Temp);
+  }
+}
+
 PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
                                                    ModuleAnalysisManager &) {
   bool Changed = false;
@@ -1101,6 +1140,8 @@
   for (auto &F : M)
     surveyFunction(F);
 
+  propagateVirtMustcallLiveness(M);
+
   // Now, remove all dead arguments and return values from each function in
   // turn. We use make_early_inc_range here because functions will probably get
   // removed (i.e. replaced by new ones).
diff --git a/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll b/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: -p --function-signature
+; RUN: opt -passes=deadargelim -S < %s | FileCheck %s
+
+define internal i32 @test_caller(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-LABEL: define {{[^@]+}}@test_caller(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-NEXT: %r = musttail call i32 @test(ptr %fptr, i32 %a, i32 poison)
+; CHECK-NEXT: ret i32 %r
+;
+  %r = musttail call i32 @test(ptr %fptr, i32 %a, i32 %b)
+  ret i32 %r
+}
+
+define internal i32 @test(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-LABEL: define {{[^@]+}}@test(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-NEXT: %r = musttail call i32 %fptr(ptr %fptr, i32 %a, i32 0)
+; CHECK-NEXT: ret i32 %r
+;
+  %r = musttail call i32 %fptr(ptr %fptr, i32 %a, i32 0)
+  ret i32 %r
+}
+
+define internal i32 @direct_test() {
+; CHECK-LABEL: define {{[^@]+}}@direct_test() {
+; CHECK-NEXT: %r = musttail call i32 @foo()
+; CHECK-NEXT: ret i32 %r
+;
+  %r = musttail call i32 @foo()
+  ret i32 %r
+}
+
+declare i32 @foo()
+
+define internal i32 @ping(i32 %x) {
+; CHECK-LABEL: define {{[^@]+}}@ping(i32 %x) {
+; CHECK-NEXT: %r = musttail call i32 @pong(i32 %x)
+; CHECK-NEXT: ret i32 %r
+;
+  %r = musttail call i32 @pong(i32 %x)
+  ret i32 %r
+}
+
+define internal i32 @pong(i32 %x) {
+; CHECK-LABEL: define {{[^@]+}}@pong(i32 %x) {
+; CHECK-NEXT: %cond = icmp eq i32 %x, 2
+; CHECK-NEXT: br i1 %cond, label %yes, label %no
+; CHECK: yes:
+; CHECK-NEXT: %r1 = musttail call i32 @ping(i32 %x)
+; CHECK-NEXT: ret i32 %r1
+; CHECK: no:
+; CHECK-NEXT: %r2 = musttail call i32 @bar(i32 %x)
+; CHECK-NEXT: ret i32 %r2
+;
+  %cond = icmp eq i32 %x, 2
+  br i1 %cond, label %yes, label %no
+
+yes:
+  %r1 = musttail call i32 @ping(i32 %x)
+  ret i32 %r1
+no:
+  %r2 = musttail call i32 @bar(i32 %x)
+  ret i32 %r2
+}
+
+declare i32 @bar(i32 %x)
+