diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -136,6 +136,13 @@
                  MemorySSAUpdater *, ScalarEvolution *, ICFLoopSafetyInfo *,
                  SinkAndHoistLICMFlags &, OptimizationRemarkEmitter *);
 
+/// This function deletes a dead loop nest, innermost loops first. Each loop in
+/// the nest, including \p L, must meet the requirements of deleteDeadLoop();
+/// unlike deleteDeadLoop(), \p L itself may have dead inner loops. This also
+/// updates the analysis information in \p DT, \p SE, and \p LI if provided.
+void deleteDeadLoopNest(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
+                        LoopInfo *LI);
+
 /// This function deletes dead loops. The caller of this function needs to
 /// guarantee that the loop is infact dead.
 /// The function requires a bunch or prerequisites to be present:
@@ -146,7 +153,6 @@
 /// This also updates the relevant analysis information in \p DT, \p SE, and \p
 /// LI if pointers to those are provided.
 /// It also updates the loop PM if an updater struct is provided.
-
 void deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
                     LoopInfo *LI);
 
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -504,6 +504,20 @@
   return Worklist;
 }
 
+void llvm::deleteDeadLoopNest(Loop *L, DominatorTree *DT = nullptr,
+                              ScalarEvolution *SE = nullptr,
+                              LoopInfo *LI = nullptr) {
+  // Deleting a sub-loop is expected to unlink it from L (the assert below
+  // relies on that), which would invalidate iterators into L's sub-loop
+  // list; iterate over a snapshot of the list instead.
+  SmallVector<Loop *, 4> SubLoops(L->begin(), L->end());
+  for (Loop *SubLoop : SubLoops)
+    deleteDeadLoopNest(SubLoop, DT, SE, LI);
+
+  assert(L->getSubLoops().empty() && "Expecting to be innermost loop");
+  deleteDeadLoop(L, DT, SE, LI);
+}
+
 void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr,
                           ScalarEvolution *SE = nullptr,
                           LoopInfo *LI = nullptr) {
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -15,6 +15,7 @@
   FunctionComparatorTest.cpp
   IntegerDivisionTest.cpp
   LocalTest.cpp
+  LoopUtilsTest.cpp
   SizeOptsTest.cpp
   SSAUpdaterBulkTest.cpp
   UnrollLoopTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp b/llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp
@@ -0,0 +1,98 @@
+//===- LoopUtilsTest.cpp - Unit tests for LoopUtils -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// Parse \p IR into a module, printing the diagnostic if parsing fails.
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("LoopUtilsTests", errs());
+  return Mod;
+}
+
+/// Build DT, LI and SE for function \p FuncName in \p M and hand them,
+/// together with the function, to \p Test.
+static void run(Module &M, StringRef FuncName,
+                function_ref<void(Function &F, DominatorTree &DT,
+                                  ScalarEvolution &SE, LoopInfo &LI)>
+                    Test) {
+  auto *F = M.getFunction(FuncName);
+  ASSERT_TRUE(F != nullptr) << "Expecting function in module";
+  DominatorTree DT(*F);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  AssumptionCache AC(*F);
+  LoopInfo LI(DT);
+  ScalarEvolution SE(*F, TLI, AC, DT, LI);
+  Test(*F, DT, SE, LI);
+}
+
+// Deleting the outermost loop of a nest with two sibling inner loops (for.j
+// and for.k) must remove every loop, leaving entry branching to for.end.
+// Checks use gtest macros, not assert(), so they also run in NDEBUG builds.
+TEST(LoopUtils, DeleteDeadLoopNest) {
+  LLVMContext C;
+  std::unique_ptr<Module> M =
+      parseIR(C, "define void @foo() {\n"
+                 "entry:\n"
+                 "  br label %for.i\n"
+                 "for.i:\n"
+                 "  %i = phi i64 [ 0, %entry ], [ %inc.i, %for.i.latch ]\n"
+                 "  br label %for.j\n"
+                 "for.j:\n"
+                 "  %j = phi i64 [ 0, %for.i ], [ %inc.j, %for.j ]\n"
+                 "  %inc.j = add nsw i64 %j, 1\n"
+                 "  %cmp.j = icmp slt i64 %inc.j, 100\n"
+                 "  br i1 %cmp.j, label %for.j, label %for.k.preheader\n"
+                 "for.k.preheader:\n"
+                 "  br label %for.k\n"
+                 "for.k:\n"
+                 "  %k = phi i64 [ %inc.k, %for.k ], [ 0, %for.k.preheader ]\n"
+                 "  %inc.k = add nsw i64 %k, 1\n"
+                 "  %cmp.k = icmp slt i64 %inc.k, 100\n"
+                 "  br i1 %cmp.k, label %for.k, label %for.i.latch\n"
+                 "for.i.latch:\n"
+                 "  %inc.i = add nsw i64 %i, 1\n"
+                 "  %cmp.i = icmp slt i64 %inc.i, 100\n"
+                 "  br i1 %cmp.i, label %for.i, label %for.end\n"
+                 "for.end:\n"
+                 "  ret void\n"
+                 "}\n");
+  ASSERT_TRUE(M != nullptr) << "Expecting IR to parse";
+
+  run(*M, "foo",
+      [&](Function &F, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI) {
+        ASSERT_TRUE(LI.begin() != LI.end()) << "Expecting loops in function F";
+        Loop *L = *LI.begin();
+        ASSERT_TRUE(L && L->getName() == "for.i") << "Expecting loop for.i";
+
+        deleteDeadLoopNest(L, &DT, &SE, &LI);
+
+        EXPECT_TRUE(DT.verify(DominatorTree::VerificationLevel::Fast))
+            << "Expecting valid dominator tree";
+        LI.verify(DT);
+        EXPECT_TRUE(LI.begin() == LI.end())
+            << "Expecting no loops left in function F";
+        SE.verify();
+
+        Function::iterator FI = F.begin();
+        BasicBlock *Entry = &*(FI++);
+        EXPECT_EQ(Entry->getName(), "entry") << "Expecting BasicBlock entry";
+        const BranchInst *BI = dyn_cast<BranchInst>(Entry->getTerminator());
+        ASSERT_TRUE(BI != nullptr) << "Expecting valid branch instruction";
+        EXPECT_EQ(BI->getNumSuccessors(), (unsigned)1);
+        EXPECT_EQ(BI->getSuccessor(0)->getName(), "for.end");
+      });
+}