[FuncSpec] Remove the original function once it has been fully specialized.

Purpose: the original function is now marked unreachable in the solver and
recorded in a new FullySpecialized set when it has no uses left, or when every
user is a call instruction whose parent function is the original function
itself (e.g. its own recursive calls). This replaces the old `!AI.Partial`
test. The call type in the dyn_cast is presumably CallBase; the template
argument is missing in this copy, so confirm.

FunctionSpecializer gains a destructor that calls removeDeadInstructions() and
then the new removeDeadFunctions(). removeDeadFunctions() erases every
function recorded in FullySpecialized from its parent and clears the set. The
explicit FS.removeDeadInstructions() call in front of removeSSACopy(M) is
dropped. Dead-instruction removal therefore happens when FS is destroyed,
i.e. after removeSSACopy(M). This assumes FS is a local whose lifetime ends at
the `return Changed;`; confirm.

Tests:
  * function-specialization.ll, function-specialization3.ll and
    function-specialization4.ll add CHECK-NOT / FORCE-NOT lines. These assert
    that the original @compute / @foo definition no longer appears in the
    output. function-specialization3.ll also turns the blank lines between
    FORCE blocks into `;` lines.
  * function-specialization-constant-expression3.ll drops its @wombat CHECK
    lines.
  * remove-dead-recursive-function.ll (new) covers a self-recursive @compute
    called with @plus and with @minus. It expects @compute.1 and @compute.2 to
    call themselves, and the original @compute definition to be gone.

Review notes:
  * NOTE(review): `AI.Fn->getNumUses() == 0 ||` looks redundant, because
    llvm::all_of over an empty user range already yields true. If the check
    is kept, `use_empty()` avoids walking the whole use list.
  * NOTE(review): the unchanged context comment above the condition still
    ends with "Mark it unreachable." The branch now also schedules the
    function for deletion, so the comment should be updated to match.
  * NOTE(review): FullySpecialized is a SmallPtrSet, whose iteration order is
    not guaranteed to follow insertion order. The order of the "Removing dead
    function" debug lines may therefore vary between runs; a SetVector would
    make it deterministic.
  * NOTE(review): cleanup now happens implicitly in a destructor. Confirm
    that nothing between the old call site and the end of FS's lifetime
    depends on the dead instructions already being erased. Also confirm
    whether cached per-function analyses (reached through
    GetAC/GetTTI/GetTLI) must be invalidated before eraseFromParent().
  * NOTE(review): the @wombat CHECK lines are removed and nothing replaces
    them, so function-specialization-constant-expression3.ll no longer
    checks @wombat at all. If @wombat is now expected to be deleted, a
    CHECK-NOT would pin that behavior.
  * NOTE(review): in remove-dead-recursive-function.ll, [[TMP2]] is captured
    in both blocks but never referenced afterwards.
  * NOTE(review): in this copy the patch's line breaks are collapsed and text
    between angle brackets appears to be missing (e.g. `dyn_cast(U)`,
    `isa(V)`, `SmallPtrSet FullySpecialized;`). It will not apply or build
    as-is; regenerate it from the original commit.

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -275,7 +275,8 @@ std::function GetTTI; std::function GetTLI; - SmallPtrSet SpecializedFuncs; + SmallPtrSet SpecializedFuncs; + SmallPtrSet FullySpecialized; SmallVector ReplacedWithConstant; public: @@ -285,6 +286,12 @@ std::function GetTLI) : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {} + ~FunctionSpecializer() { + // Eliminate dead code. + removeDeadInstructions(); + removeDeadFunctions(); + } + /// Attempt to specialize functions in the module to enable constant /// propagation across function boundaries. /// @@ -331,6 +338,15 @@ ReplacedWithConstant.clear(); } + void removeDeadFunctions() { + for (auto *F : FullySpecialized) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function " + << F->getName() << "\n"); + F->eraseFromParent(); + } + FullySpecialized.clear(); + } + bool tryToReplaceWithConstant(Value *V) { if (!V->getType()->isSingleValueType() || isa(V) || V->user_empty()) @@ -501,8 +517,15 @@ // If the function has been completely specialized, the original function // is no longer needed. Mark it unreachable. - if (!AI.Partial) + if (AI.Fn->getNumUses() == 0 || + all_of(AI.Fn->users(), [&AI](User *U) { + if (auto *CS = dyn_cast(U)) + return CS->getFunction() == AI.Fn; + return false; + })) { Solver.markFunctionUnreachable(AI.Fn); + FullySpecialized.insert(AI.Fn); + } } /// Compute and return the cost of specializing function \p F. @@ -923,8 +946,7 @@ LLVM_DEBUG(dbgs() << "FnSpecialization: Number of specializations = " << NumFuncSpecialized <<"\n"); - // Clean up the IR by removing dead instructions and ssa_copy intrinsics. - FS.removeDeadInstructions(); + // Remove any ssa_copy intrinsics that may have been introduced. 
removeSSACopy(M); return Changed; } diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression3.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression3.ll --- a/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression3.ll +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization-constant-expression3.ll @@ -18,11 +18,6 @@ declare i32 @eggs() define internal void @wombat(i8* %arg, i64 %arg1, i64 %arg2, i32 (i8*, i8*)* %arg3) { -; CHECK-LABEL: @wombat( -; CHECK-NEXT: bb4: -; CHECK-NEXT: [[TMP:%.*]] = tail call i32 [[ARG3:%.*]](i8* undef, i8* undef) -; CHECK-NEXT: ret void -; bb4: %tmp = tail call i32 %arg3(i8* undef, i8* undef) ret void diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization.ll --- a/llvm/test/Transforms/FunctionSpecialization/function-specialization.ll +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization.ll @@ -32,6 +32,8 @@ ret i64 %tmp2 } +; CHECK-NOT: define internal i64 @compute( +; ; CHECK-LABEL: define internal i64 @compute.1(i64 %x, i64 (i64)* %binop) { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[TMP0:%.+]] = call i64 @plus(i64 %x) diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization3.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization3.ll --- a/llvm/test/Transforms/FunctionSpecialization/function-specialization3.ll +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization3.ll @@ -34,20 +34,22 @@ ret i32 %retval.0 } +; FORCE-NOT: define internal i32 @foo( +; ; FORCE: define internal i32 @foo.1(i32 %x, i32* %b) { ; FORCE-NEXT: entry: ; FORCE-NEXT: %0 = load i32, i32* @A, align 4 ; FORCE-NEXT: %add = add nsw i32 %x, %0 ; FORCE-NEXT: ret i32 %add ; FORCE-NEXT: } - +; ; FORCE: define internal i32 @foo.2(i32 %x, i32* %b) { ; FORCE-NEXT: entry: ; 
FORCE-NEXT: %0 = load i32, i32* @B, align 4 ; FORCE-NEXT: %add = add nsw i32 %x, %0 ; FORCE-NEXT: ret i32 %add ; FORCE-NEXT: } - +; define internal i32 @foo(i32 %x, i32* %b) { entry: %0 = load i32, i32* %b, align 4 diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll --- a/llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll +++ b/llvm/test/Transforms/FunctionSpecialization/function-specialization4.ll @@ -29,6 +29,7 @@ ret i32 %retval.0 } +; CHECK-NOT: define internal i32 @foo( define internal i32 @foo(i32 %x, i32* %b, i32* %c) { entry: %0 = load i32, i32* %b, align 4 diff --git a/llvm/test/Transforms/FunctionSpecialization/remove-dead-recursive-function.ll b/llvm/test/Transforms/FunctionSpecialization/remove-dead-recursive-function.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/remove-dead-recursive-function.ll @@ -0,0 +1,59 @@ +; RUN: opt -function-specialization -func-specialization-size-threshold=3 -S < %s | FileCheck %s + +define i64 @main(i64 %x, i1 %flag) { +entry: + br i1 %flag, label %plus, label %minus + +plus: + %tmp0 = call i64 @compute(i64 %x, i64 (i64)* @plus) + br label %merge + +minus: + %tmp1 = call i64 @compute(i64 %x, i64 (i64)* @minus) + br label %merge + +merge: + %tmp2 = phi i64 [ %tmp0, %plus ], [ %tmp1, %minus] + ret i64 %tmp2 +} + +; CHECK-NOT: define internal i64 @compute( +; +; CHECK-LABEL: define internal i64 @compute.1(i64 %n, i64 (i64)* %binop) { +; CHECK: [[TMP0:%.+]] = call i64 @plus(i64 %n) +; CHECK: [[TMP1:%.+]] = call i64 @compute.1(i64 [[TMP2:%.+]], i64 (i64)* @plus) +; CHECK: add nsw i64 [[TMP1]], [[TMP0]] +; +; CHECK-LABEL: define internal i64 @compute.2(i64 %n, i64 (i64)* %binop) { +; CHECK: [[TMP0:%.+]] = call i64 @minus(i64 %n) +; CHECK: [[TMP1:%.+]] = call i64 @compute.2(i64 [[TMP2:%.+]], i64 (i64)* @minus) +; CHECK: add nsw i64 [[TMP1]], [[TMP0]] +; +define 
internal i64 @compute(i64 %n, i64 (i64)* %binop) { +entry: + %cmp = icmp sgt i64 %n, 0 + br i1 %cmp, label %if.then, label %if.end + +if.then: + %call = call i64 %binop(i64 %n) + %sub = add nsw i64 %n, -1 + %call1 = call i64 @compute(i64 %sub, i64 (i64)* %binop) + %add2 = add nsw i64 %call1, %call + br label %if.end + +if.end: + %result.0 = phi i64 [ %add2, %if.then ], [ 0, %entry ] + ret i64 %result.0 +} + +define internal i64 @plus(i64 %x) { +entry: + %tmp0 = add i64 %x, 1 + ret i64 %tmp0 +} + +define internal i64 @minus(i64 %x) { +entry: + %tmp0 = sub i64 %x, 1 + ret i64 %tmp0 +}