Index: lib/Transforms/Utils/BasicBlockUtils.cpp
===================================================================
--- lib/Transforms/Utils/BasicBlockUtils.cpp
+++ lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -504,6 +504,11 @@
     // Insert dummy values as the incoming value.
     for (BasicBlock::iterator I = BB->begin(); isa<PHINode>(I); ++I)
       cast<PHINode>(I)->addIncoming(UndefValue::get(I->getType()), NewBB);
+
+    // If BB was the dominator tree root (the entry block), NewBB replaces it
+    // as the root; otherwise DT no longer verifies after the split.
+    if (DT && BB == DT->getRootNode()->getBlock())
+      DT->setNewRoot(NewBB);
     return NewBB;
   }
 
Index: unittests/Transforms/Utils/BasicBlockUtils.cpp
===================================================================
--- /dev/null
+++ unittests/Transforms/Utils/BasicBlockUtils.cpp
@@ -0,0 +1,52 @@
+//===- BasicBlockUtils.cpp - Unit tests for BasicBlockUtils ---------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("BasicBlockUtilsTests", errs());
+  return Mod;
+}
+
+TEST(BasicBlockUtils, SplitBlockPredecessors) {
+  LLVMContext C;
+
+  std::unique_ptr<Module> M = parseIR(
+    C,
+    "define void @basic_func(i1 %cond) {\n"
+    "entry:\n"
+    "  br i1 %cond, label %bb0, label %bb1\n"
+    "bb0:\n"
+    "  ret void\n"
+    "bb1:\n"
+    "  ret void\n"
+    "}\n"
+    "\n"
+    );
+  ASSERT_TRUE(M != nullptr);
+
+  auto *F = M->getFunction("basic_func");
+  DominatorTree DT(*F);
+
+  // Make sure the dominator tree is properly updated if calling this on the
+  // entry block.
+  SplitBlockPredecessors(&F->getEntryBlock(), {}, "split.entry", &DT);
+  EXPECT_TRUE(DT.verify());
+}
Index: unittests/Transforms/Utils/CMakeLists.txt
===================================================================
--- unittests/Transforms/Utils/CMakeLists.txt
+++ unittests/Transforms/Utils/CMakeLists.txt
@@ -8,6 +8,7 @@
 
 add_llvm_unittest(UtilsTests
   ASanStackFrameLayoutTest.cpp
+  BasicBlockUtils.cpp
   Cloning.cpp
   CodeExtractor.cpp
   FunctionComparator.cpp
Index: unittests/Transforms/Utils/Local.cpp
===================================================================
--- unittests/Transforms/Utils/Local.cpp
+++ unittests/Transforms/Utils/Local.cpp
@@ -100,7 +100,7 @@
   EXPECT_EQ(3U, BB->size());
 }
 
-std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
   SMDiagnostic Err;
   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
   if (!Mod)