Index: unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- unittests/Analysis/ScalarEvolutionTest.cpp
+++ unittests/Analysis/ScalarEvolutionTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
+#include "llvm/IR/DomTreeUpdater.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
@@ -23,6 +24,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "gtest/gtest.h"
 
 namespace llvm {
@@ -1439,5 +1441,172 @@
   EXPECT_EQ(S2S->getExpressionSize(), 5u);
 }
 
+// Checks that an Eager DomTreeUpdater keeps the DominatorTree valid while the
+// CFG below is edited: an edge is split, an edge is added, switch edges are
+// removed and the block that becomes dead is deleted.
+//
+// CFG edges (BBn is abbreviated as n; for a switch the first target listed is
+// the default destination):
+//   BB -> 1      1 -> {17, 2}      2 -> 3       3 -> 4
+//   4 switch {5, 19, 18}
+//   5 switch {16, 15, 14, 13, 12, 11, 8, 10, 9, 7}
+//   8 switch {28, 27, 26, 23, 24, 25, 29, 22, 20, 21}
+//   6 -> 3       11 -> 8           18 -> 6      29 -> 6
+//   17 and 19 return; 7, 9, 10, 12-16 and 20-28 end in `unreachable`.
+//
+// Transforms:
+//   1. Split 2 -> 3 with a new block "Split" and add the edge 2 -> 19,
+//      giving 2 -> {Split, 19} and Split -> 3.
+//   2. Replace 4's switch with an unconditional branch to 5, which removes
+//      the edges 4 -> 19 and 4 -> 18, then delete the now-dead block 18.
+TEST_F(ScalarEvolutionsTest, TestDTUpdater) {
+  Module M("SCEVTestDTUpdater", Context);
+  FunctionType *FTy =
+      FunctionType::get(Type::getVoidTy(Context), {}, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("func", FTy));
+  Value *Cond = UndefValue::get(Type::getInt1Ty(Context));
+  Type *T_int64 = Type::getInt64Ty(Context);
+
+  BasicBlock *BB = BasicBlock::Create(Context, "BB", F);
+  IRBuilder<> Builder(BB);
+
+  // Helpers that append a terminator to a block.
+  auto Unreachable = [&](BasicBlock *Block) {
+    Builder.SetInsertPoint(Block);
+    Builder.CreateUnreachable();
+  };
+
+  auto Br = [&](BasicBlock *From, BasicBlock *To) {
+    Builder.SetInsertPoint(From);
+    Builder.CreateBr(To);
+  };
+
+  auto CondBr = [&](BasicBlock *From, BasicBlock *True, BasicBlock *False) {
+    Builder.SetInsertPoint(From);
+    Builder.CreateCondBr(Cond, True, False);
+  };
+
+  auto Ret = [&](BasicBlock *Block) {
+    Builder.SetInsertPoint(Block);
+    Builder.CreateRetVoid();
+  };
+
+  // Adds "case V -> Dest" to the switch S.
+  auto Case = [&](SwitchInst *S, uint64_t V, BasicBlock *Dest) {
+    S->addCase(cast<ConstantInt>(ConstantInt::get(T_int64, V)), Dest);
+  };
+
+  BasicBlock *BB1 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB2 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB3 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB4 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB5 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB6 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB7 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB8 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB9 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB10 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB11 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB12 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB13 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB14 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB15 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB16 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB17 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB18 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB19 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB20 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB21 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB22 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB23 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB24 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB25 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB26 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB27 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB28 = BasicBlock::Create(Context, "BB", F);
+  BasicBlock *BB29 = BasicBlock::Create(Context, "BB", F);
+
+  // Build the CFG described above.
+  Br(BB, BB1);
+  CondBr(BB1, BB17, BB2);
+  Br(BB2, BB3);
+  Br(BB3, BB4);
+  Builder.SetInsertPoint(BB4);
+  SwitchInst *Switch4 = Builder.CreateSwitch(ConstantInt::get(T_int64, 0), BB5);
+  Case(Switch4, 1, BB19);
+  Case(Switch4, 2, BB18);
+  Builder.SetInsertPoint(BB5);
+  SwitchInst *Switch5 = Builder.CreateSwitch(UndefValue::get(T_int64), BB16);
+  Case(Switch5, 0, BB15);
+  Case(Switch5, 1, BB14);
+  Case(Switch5, 2, BB13);
+  Case(Switch5, 3, BB12);
+  Case(Switch5, 4, BB11);
+  Case(Switch5, 5, BB8);
+  Case(Switch5, 6, BB10);
+  Case(Switch5, 7, BB9);
+  Case(Switch5, 8, BB7);
+  Br(BB6, BB3);
+  Unreachable(BB7);
+  Builder.SetInsertPoint(BB8);
+  SwitchInst *Switch8 = Builder.CreateSwitch(UndefValue::get(T_int64), BB28);
+  Case(Switch8, 0, BB27);
+  Case(Switch8, 1, BB26);
+  Case(Switch8, 2, BB23);
+  Case(Switch8, 3, BB24);
+  Case(Switch8, 4, BB25);
+  Case(Switch8, 5, BB29);
+  Case(Switch8, 6, BB22);
+  Case(Switch8, 7, BB20);
+  Case(Switch8, 8, BB21);
+  Unreachable(BB9);
+  Unreachable(BB10);
+  Br(BB11, BB8);
+  Unreachable(BB12);
+  Unreachable(BB13);
+  Unreachable(BB14);
+  Unreachable(BB15);
+  Unreachable(BB16);
+  Ret(BB17);
+  Br(BB18, BB6);
+  Ret(BB19);
+  Unreachable(BB20);
+  Unreachable(BB21);
+  Unreachable(BB22);
+  Unreachable(BB23);
+  Unreachable(BB24);
+  Unreachable(BB25);
+  Unreachable(BB26);
+  Unreachable(BB27);
+  Unreachable(BB28);
+  Br(BB29, BB6);
+
+  // Build the DT only now that every block has a terminator.
+  DT.reset(new DominatorTree(*F));
+
+  // Transform 1: split BB2 -> BB3 with a new block and add BB2 -> BB19. The
+  // CFG is rewritten first, then each changed edge is reported to the DTU.
+  DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Eager);
+  BasicBlock *Split = BasicBlock::Create(Context, "Split", F);
+  BB2->getTerminator()->eraseFromParent();
+  CondBr(BB2, Split, BB19);
+  Br(Split, BB3);
+  DTU.deleteEdge(BB2, BB3);
+  DTU.insertEdge(BB2, Split);
+  DTU.insertEdge(Split, BB3);
+  DTU.insertEdge(BB2, BB19);
+
+  // Transform 2: replace BB4's switch with "br BB5". BB18 loses its only
+  // predecessor and is deleted.
+  BB4->getTerminator()->eraseFromParent();
+  Br(BB4, BB5);
+  DTU.deleteEdge(BB4, BB19);
+  DTU.deleteEdge(BB4, BB18);
+  DeleteDeadBlock(BB18, &DTU);
+
+  DTU.flush();
+  EXPECT_TRUE(DT->verify());
+}
+
 } // end anonymous namespace
 } // end namespace llvm