diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -3942,6 +3942,9 @@
             ArrayRef<BasicBlock *> IndirectDests, ArrayRef<Value *> Args,
             ArrayRef<OperandBundleDef> Bundles, const Twine &NameStr);
 
+  /// Retarget blockaddress args that name indirect dest \p i to \p B.
+  void updateArgBlockAddresses(unsigned i, BasicBlock *B);
+
   /// Compute the number of operands to allocate.
   static int ComputeNumOperands(int NumArgs, int NumIndirectDests,
                                 int NumBundleInputs = 0) {
@@ -4079,7 +4082,7 @@
     return cast<BasicBlock>(*(&Op<-1>() - getNumIndirectDests() - 1));
   }
   BasicBlock *getIndirectDest(unsigned i) const {
-    return cast<BasicBlock>(*(&Op<-1>() - getNumIndirectDests() + i));
+    return cast_or_null<BasicBlock>(*(&Op<-1>() - getNumIndirectDests() + i));
   }
   SmallVector<BasicBlock *, 16> getIndirectDests() const {
     SmallVector<BasicBlock *, 16> IndirectDests;
@@ -4091,6 +4094,7 @@
     *(&Op<-1>() - getNumIndirectDests() - 1) = reinterpret_cast<Value *>(B);
   }
   void setIndirectDest(unsigned i, BasicBlock *B) {
+    updateArgBlockAddresses(i, B);
     *(&Op<-1>() - getNumIndirectDests() + i) = reinterpret_cast<Value *>(B);
   }
 
@@ -4100,11 +4104,13 @@
     return i == 0 ? getDefaultDest() : getIndirectDest(i - 1);
   }
 
-  void setSuccessor(unsigned idx, BasicBlock *NewSucc) {
-    assert(idx < getNumIndirectDests() + 1 &&
+  void setSuccessor(unsigned i, BasicBlock *NewSucc) {
+    assert(i < getNumIndirectDests() + 1 &&
            "Successor # out of range for callbr!");
-    *(&Op<-1>() - getNumIndirectDests() -1 + idx) =
-        reinterpret_cast<Value *>(NewSucc);
+    if (i == 0)
+      setDefaultDest(NewSucc);
+    else
+      setIndirectDest(i - 1, NewSucc);
   }
 
   unsigned getNumSuccessors() const { return getNumIndirectDests() + 1; }
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -822,6 +822,32 @@
   setName(NameStr);
 }
 
+// Keep the blockaddress arguments of this callbr in sync with its indirect
+// destinations: when indirect destination i changes to B, every argument that
+// is the blockaddress of the old destination is rewritten to refer to B.
+void CallBrInst::updateArgBlockAddresses(unsigned i, BasicBlock *B) {
+  assert(getNumIndirectDests() > i && "IndirectDest # out of range for callbr");
+  // The slot may not have been populated yet (getIndirectDest() tolerates a
+  // null operand), in which case no argument can refer to it.
+  BasicBlock *OldBB = getIndirectDest(i);
+  if (!OldBB || OldBB == B)
+    return;
+
+  // Match on the existing blockaddress arguments instead of calling
+  // BlockAddress::get(OldBB) up front: get() creates a new constant when none
+  // exists, which would mark the block address-taken as a side effect. The
+  // replacement constant is likewise only created once a match is found.
+  BlockAddress *New = nullptr;
+  for (unsigned ArgNo = 0, e = getNumArgOperands(); ArgNo != e; ++ArgNo) {
+    auto *BA = dyn_cast<BlockAddress>(getArgOperand(ArgNo));
+    if (!BA || BA->getBasicBlock() != OldBB)
+      continue;
+    if (!New)
+      New = BlockAddress::get(B);
+    setArgOperand(ArgNo, New);
+  }
+}
+
 CallBrInst::CallBrInst(const CallBrInst &CBI)
     : CallBase(CBI.Attrs, CBI.FTy, CBI.getType(), Instruction::CallBr,
                OperandTraits<CallBase>::op_end(this) - CBI.getNumOperands(),
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -1061,5 +1061,56 @@
   FNeg->deleteValue();
 }
 
+TEST(InstructionsTest, CallBrInstruction) {
+  LLVMContext Context;
+  std::unique_ptr<Module> M = parseIR(Context, R"(
+define void @foo() {
+entry:
+  callbr void asm sideeffect "// XXX: ${0:l}", "X"(i8* blockaddress(@foo, %branch_test.exit))
+          to label %land.rhs.i [label %branch_test.exit]
+
+land.rhs.i:
+  br label %branch_test.exit
+
+branch_test.exit:
+  %0 = phi i1 [ true, %entry ], [ false, %land.rhs.i ]
+  br i1 %0, label %if.end, label %if.then
+
+if.then:
+  ret void
+
+if.end:
+  ret void
+}
+)");
+  Function *Foo = M->getFunction("foo");
+  auto BBs = Foo->getBasicBlockList().begin();
+  CallBrInst &CBI = cast<CallBrInst>(BBs->front());
+  ++BBs;
+  ++BBs;
+  BasicBlock &BranchTestExit = *BBs;
+  ++BBs;
+  BasicBlock &IfThen = *BBs;
+
+  // Test that setting the first indirect destination of callbr updates the dest.
+  EXPECT_EQ(&BranchTestExit, CBI.getIndirectDest(0));
+  CBI.setIndirectDest(0, &IfThen);
+  EXPECT_EQ(&IfThen, CBI.getIndirectDest(0));
+
+  // Further, test that changing the indirect destination updates the arg
+  // operand to use the block address of the new indirect destination basic
+  // block. This is a critical invariant of CallBrInst.
+  BlockAddress *IndirectBA = BlockAddress::get(CBI.getIndirectDest(0));
+  BlockAddress *ArgBA = cast<BlockAddress>(CBI.getArgOperand(0));
+  EXPECT_EQ(IndirectBA, ArgBA)
+      << "After setting the indirect destination, callbr had an indirect "
+         "destination of '"
+      << CBI.getIndirectDest(0)->getName() << "', but an argument of '"
+      << ArgBA->getBasicBlock()->getName() << "'. These should always match:\n"
+      << CBI;
+  EXPECT_EQ(IndirectBA->getBasicBlock(), &IfThen);
+  EXPECT_EQ(ArgBA->getBasicBlock(), &IfThen);
+}
+
 } // end anonymous namespace
 } // end namespace llvm