diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -122,6 +122,10 @@
 
   void removeLatticeValueFor(Value *V);
 
+  /// replaceLatticeValueFor - For every ValueLatticeElement in ValueState
+  /// whose constant is Old, replace that constant with New.
+  void replaceLatticeValueFor(Constant *Old, Constant *New);
+
   const ValueLatticeElement &getLatticeValueFor(Value *V) const;
 
   /// getTrackedRetVals - Get the inferred return value map.
diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp
--- a/llvm/lib/Transforms/Scalar/SCCP.cpp
+++ b/llvm/lib/Transforms/Scalar/SCCP.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
@@ -519,6 +520,31 @@
   // nodes in executable blocks we found values for. The function's entry
   // block is not part of BlocksToErase, so we have to handle it separately.
   for (BasicBlock *BB : BlocksToErase) {
+
+    // If the BB we're about to erase (because it's unreachable) has its
+    // address taken, every use of that address must be rewritten first:
+    if (BlockAddress *BA = BlockAddress::lookup(BB)) {
+      // Replace it with `inttoptr i32 1 to i8*`
+      LLVMContext &Context = BB->getParent()->getContext();
+      ConstantInt *CI = ConstantInt::get(Type::getInt32Ty(Context), 1);
+      Constant *C =
+          ConstantExpr::getIntToPtr(CI, Type::getInt8PtrTy(Context));
+      // Update the Solver to refer to the updated Constant.
+      for (User *U : BA->users()) {
+        // If the use is a ConstantExpr, construct a new ConstantExpr which
+        // uses `inttoptr i32 1 to i8*` instead of the `blockaddress`.
+        if (auto *CE = dyn_cast<ConstantExpr>(U)) {
+          std::vector<Constant *> Ops;
+          for (Use &Op : CE->operands())
+            Ops.push_back(Op == BA ? C : cast<Constant>(Op));
+          Constant *NewCE = CE->getWithOperands(Ops);
+          Solver.replaceLatticeValueFor(CE, NewCE);
+        } else if (auto *Old = dyn_cast<Constant>(U))
+          Solver.replaceLatticeValueFor(Old, C);
+      }
+      BA->replaceAllUsesWith(C);
+    }
+
     NumInstRemoved += changeToUnreachable(BB->getFirstNonPHI(),
                                           /*PreserveLCSSA=*/false, &DTU);
   }
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -412,6 +412,14 @@
 
   void removeLatticeValueFor(Value *V) { ValueState.erase(V); }
 
+  void replaceLatticeValueFor(Constant *Old, Constant *New) {
+    // Walk the {Key, Val} entries of ValueState; wherever Val is the constant
+    // Old, overwrite it in place with a lattice element for New.
+    for (auto &VLE : ValueState)
+      if (VLE.second.isConstant() && VLE.second.getConstant() == Old)
+        VLE.second = ValueLatticeElement::get(New);
+  }
+
   const ValueLatticeElement &getLatticeValueFor(Value *V) const {
     assert(!V->getType()->isStructTy() &&
            "Should use getStructLatticeValueFor");
@@ -1685,6 +1693,10 @@
   return Visitor->removeLatticeValueFor(V);
 }
 
+void SCCPSolver::replaceLatticeValueFor(Constant *Old, Constant *New) {
+  return Visitor->replaceLatticeValueFor(Old, New);
+}
+
 const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const {
   return Visitor->getLatticeValueFor(V);
 }
diff --git a/llvm/test/Transforms/SCCP/dangling-block-address.ll b/llvm/test/Transforms/SCCP/dangling-block-address.ll
--- a/llvm/test/Transforms/SCCP/dangling-block-address.ll
+++ b/llvm/test/Transforms/SCCP/dangling-block-address.ll
@@ -40,3 +40,21 @@
 entry:
   ret i32 0
 }
+
+; https://github.com/llvm/llvm-project/issues/54238
+; https://github.com/llvm/llvm-project/issues/54251
+; https://github.com/llvm/llvm-project/issues/54328
+define i32 @test1() {
+  %1 = bitcast i8* blockaddress(@test1, %redirected) to i64*
+  call void @set_return_addr(i64* %1)
+  ret i32 0
+
+redirected:
+  ret i32 0
+}
+
+define internal void @set_return_addr(i64* %addr) {
+  %addr.addr = alloca i64*, i32 0, align 8
+  store i64* %addr, i64** %addr.addr, align 8
+  ret void
+}