diff --git a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
--- a/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
+++ b/llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
@@ -136,6 +136,7 @@
   bool removeDeadStuffFromFunction(Function *F);
   bool deleteDeadVarargs(Function &F);
   bool removeDeadArgumentsFromCallers(Function &F);
+  void propagateVirtMustcallLiveness(const Module &M);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
--- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -520,8 +520,15 @@
   for (const BasicBlock &BB : F) {
     // If we have any returns of `musttail` results - the signature can't
     // change
-    if (BB.getTerminatingMustTailCall() != nullptr)
+    if (const auto *TC = BB.getTerminatingMustTailCall()) {
       HasMustTailCalls = true;
+      // In addition, if the callee is not a known function (e.g. an indirect
+      // call), it can't be rewritten to match, so F's signature can't change.
+      if (!TC->getCalledFunction()) {
+        markLive(F);
+        return;
+      }
+    }
   }
 
   if (HasMustTailCalls) {
@@ -1081,6 +1088,29 @@
   return true;
 }
 
+void DeadArgumentEliminationPass::propagateVirtMustcallLiveness(
+    const Module &M) {
+  // If a function was marked "live", and it has musttail callers, they in turn
+  // can't change either: a musttail call requires the caller and callee
+  // prototypes to match. Walk musttail callers transitively, enqueueing only
+  // functions that are not live yet, so that recursive musttail chains
+  // (e.g. a function that musttail-calls itself) still terminate.
+  LiveFuncSet NewLiveFuncs(LiveFunctions);
+  while (!NewLiveFuncs.empty()) {
+    LiveFuncSet Temp;
+    for (const auto *F : NewLiveFuncs)
+      for (const auto *U : F->users())
+        if (const auto *CB = dyn_cast<CallBase>(U))
+          if (CB->isMustTailCall())
+            if (!LiveFunctions.count(CB->getParent()->getParent()))
+              Temp.insert(CB->getParent()->getParent());
+    NewLiveFuncs.clear();
+    NewLiveFuncs.insert(Temp.begin(), Temp.end());
+    for (const auto *F : Temp)
+      markLive(*F);
+  }
+}
+
 PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
                                                    ModuleAnalysisManager &) {
   bool Changed = false;
@@ -1101,6 +1131,8 @@
   for (auto &F : M)
     surveyFunction(F);
 
+  propagateVirtMustcallLiveness(M);
+
   // Now, remove all dead arguments and return values from each function in
   // turn. We use make_early_inc_range here because functions will probably get
   // removed (i.e. replaced by new ones).
diff --git a/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll b/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll
@@ -0,0 +1,46 @@
+; RUN: opt -passes=deadargelim,verify -S < %s 2>&1 | FileCheck %s
+define internal i32 @test_caller(ptr %fptr, i32 %a, i32 %b) {
+  %r = musttail call i32 @test(ptr %fptr, i32 %a, i32 %b)
+  ret i32 %r
+}
+
+define internal i32 @test(ptr %fptr, i32 %a, i32 %b) {
+  %r = musttail call i32 %fptr(ptr %fptr, i32 %a, i32 0)
+  ret i32 %r
+}
+
+define internal i32 @direct_test() {
+  %r = musttail call i32 @foo()
+  ret i32 %r
+}
+
+declare i32 @foo()
+
+; @self_recursive is live (it has an indirect musttail call) and is also its
+; own musttail caller. Propagating liveness to musttail callers must not loop
+; forever on such a cycle.
+define internal i32 @self_recursive(ptr %fptr, i32 %a) {
+  %c = icmp eq i32 %a, 0
+  br i1 %c, label %do_indirect, label %do_recurse
+
+do_indirect:
+  %r1 = musttail call i32 %fptr(ptr %fptr, i32 %a)
+  ret i32 %r1
+
+do_recurse:
+  %r2 = musttail call i32 @self_recursive(ptr %fptr, i32 %a)
+  ret i32 %r2
+}
+
+; CHECK-NOT: cannot guarantee tail call due to mismatched parameter counts
+
+; CHECK: define internal i32 @test_caller
+; CHECK-NEXT: %r = musttail call i32 @test
+
+; CHECK: define internal i32 @test
+; CHECK-NEXT: %r = musttail call i32 %fptr
+
+; CHECK: define internal i32 @self_recursive(ptr %fptr, i32 %a)
+; CHECK: %r1 = musttail call i32 %fptr(ptr %fptr, i32 %a)
+; CHECK: %r2 = musttail call i32 @self_recursive(ptr %fptr, i32 %a)
+