FunctionSpecialization: defer erasing instructions replaced with constants
(regression test: bug52821-use-after-free.ll)

Before this change, tryToReplaceWithConstant() erased a dead instruction
immediately (I->eraseFromParent()) and then still passed the erased I to
Solver.removeLatticeValueFor(I). It now only records I in the new
ReplacedWithConstant vector. The new FS.removeDeadInstructions() erases all
recorded instructions in one loop and then clears the vector; it is called
right before removeSSACopy(M) / return Changed.

NOTE(review): the test name points at a use-after-free; presumably callers of
tryToReplaceWithConstant() may also still be iterating over the instructions
of I's block when it returns. Confirm against those callers, which this diff
does not show.
NOTE(review): template argument lists appear to have been stripped from this
copy (e.g. `SmallVector ReplacedWithConstant;`, `isa(V)` and `dyn_cast(V)` are
not valid C++ as written); restore them from the upstream sources before
applying the patch.

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -276,6 +276,7 @@
   std::function GetTLI;
 
   SmallPtrSet SpecializedFuncs;
+  SmallVector ReplacedWithConstant;
 
 public:
   FunctionSpecializer(SCCPSolver &Solver,
@@ -320,6 +321,15 @@
     return Changed;
   }
 
+  void removeDeadInstructions() {
+    for (auto *I : ReplacedWithConstant) {
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead instruction "
+                        << *I << "\n");
+      I->eraseFromParent();
+    }
+    ReplacedWithConstant.clear();
+  }
+
   bool tryToReplaceWithConstant(Value *V) {
     if (!V->getType()->isSingleValueType() || isa(V) ||
         V->user_empty())
@@ -330,6 +340,10 @@
       return false;
     auto *Const = isConstant(IV) ? Solver.getConstant(IV)
                                  : UndefValue::get(V->getType());
+
+    LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing " << *V
+                      << "\nFnSpecialization: with " << *Const << "\n");
+
     V->replaceAllUsesWith(Const);
 
     for (auto *U : Const->users())
@@ -340,7 +354,7 @@
     // Remove the instruction from Block and Solver.
     if (auto *I = dyn_cast(V)) {
       if (I->isSafeToRemove()) {
-        I->eraseFromParent();
+        ReplacedWithConstant.push_back(I);
         Solver.removeLatticeValueFor(I);
       }
     }
@@ -886,7 +900,8 @@
     Changed = true;
   }
 
-  // Clean up the IR by removing ssa_copy intrinsics.
+  // Clean up the IR by removing dead instructions and ssa_copy intrinsics.
+  FS.removeDeadInstructions();
   removeSSACopy(M);
   return Changed;
 }
diff --git a/llvm/test/Transforms/FunctionSpecialization/bug52821-use-after-free.ll b/llvm/test/Transforms/FunctionSpecialization/bug52821-use-after-free.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/bug52821-use-after-free.ll
@@ -0,0 +1,58 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -function-specialization -S < %s | FileCheck %s
+
+%mystruct = type { i32, [2 x i64] }
+
+define internal %mystruct* @myfunc(%mystruct* %arg) {
+; CHECK-LABEL: @myfunc(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[FOR_COND:%.*]]
+; CHECK:       for.cond:
+; CHECK-NEXT:    br i1 true, label [[FOR_COND2:%.*]], label [[FOR_BODY:%.*]]
+; CHECK:       for.body:
+; CHECK-NEXT:    call void @callee(%mystruct* nonnull null)
+; CHECK-NEXT:    br label [[FOR_COND]]
+; CHECK:       for.cond2:
+; CHECK-NEXT:    br i1 false, label [[FOR_END:%.*]], label [[FOR_BODY2:%.*]]
+; CHECK:       for.body2:
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [[MYSTRUCT:%.*]], %mystruct* null, i64 0, i32 1, i64 3
+; CHECK-NEXT:    br label [[FOR_COND2]]
+; CHECK:       for.end:
+; CHECK-NEXT:    ret %mystruct* [[ARG:%.*]]
+;
+entry:
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.body, %entry
+  %phi = phi %mystruct* [ undef, %for.body ], [ null, %entry ]
+  %cond = icmp eq %mystruct* %phi, null
+  br i1 %cond, label %for.cond2, label %for.body
+
+for.body:                                         ; preds = %for.cond
+  call void @callee(%mystruct* nonnull %phi)
+  br label %for.cond
+
+for.cond2:                                        ; preds = %for.body2, %for.cond
+  %phi2 = phi %mystruct* [ undef, %for.body2 ], [ null, %for.cond ]
+  br i1 undef, label %for.end, label %for.body2
+
+for.body2:                                        ; preds = %for.cond2
+  %arrayidx = getelementptr inbounds %mystruct, %mystruct* %phi2, i64 0, i32 1, i64 3
+  br label %for.cond2
+
+for.end:                                          ; preds = %for.cond2
+  ret %mystruct* %arg
+}
+
+define %mystruct* @caller() {
+; CHECK-LABEL: @caller(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CALL:%.*]] = call %mystruct* @myfunc(%mystruct* undef)
+; CHECK-NEXT:    ret %mystruct* [[CALL]]
+;
+entry:
+  %call = call %mystruct* @myfunc(%mystruct* undef)
+  ret %mystruct* %call
+}
+
+declare void @callee(%mystruct*)