diff --git a/llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h b/llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
--- a/llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
+++ b/llvm/include/llvm/Transforms/Scalar/SimpleLoopUnswitch.h
@@ -79,6 +79,16 @@
                      function_ref<StringRef(StringRef)> MapClassName2PassName);
 };
 
+/// Function-pass variant of the simple loop unswitcher. It walks every loop of
+/// a function itself and performs non-trivial unswitching only.
+///
+/// TODO: Use this function pass for all non-trivial unswitching.
+class FuncSimpleLoopUnswitchPass
+    : public PassInfoMixin<FuncSimpleLoopUnswitchPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
 /// Create the legacy pass object for the simple loop unswitcher.
 ///
 /// See the documentaion for `SimpleLoopUnswitchPass` for details.
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -359,6 +359,7 @@
 FUNCTION_PASS("scalarizer", ScalarizerPass())
 FUNCTION_PASS("separate-const-offset-from-gep", SeparateConstOffsetFromGEPPass())
 FUNCTION_PASS("sccp", SCCPPass())
+FUNCTION_PASS("simple-loop-unswitch-func", FuncSimpleLoopUnswitchPass())
 FUNCTION_PASS("sink", SinkingPass())
 FUNCTION_PASS("slp-vectorizer", SLPVectorizerPass())
 FUNCTION_PASS("slsr", StraightLineStrengthReducePass())
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -3171,6 +3171,107 @@
   return PA;
 }
 
+/// Run non-trivial loop unswitching over every loop in \p F.
+///
+/// Top-level loops are first put into loop-simplify and LCSSA form; all loops
+/// are then drained from a worklist and handed to unswitchLoop() with trivial
+/// unswitching disabled. Loops cloned by an unswitch are queued as well.
+PreservedAnalyses FuncSimpleLoopUnswitchPass::run(Function &F,
+                                                  FunctionAnalysisManager &AM) {
+  LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
+  // Nothing to unswitch; avoid computing the remaining analyses (MemorySSA in
+  // particular) for loop-free functions.
+  if (LI.empty())
+    return PreservedAnalyses::all();
+
+  // Fetch every analysis exactly once, in a fixed order, rather than inside
+  // the argument list of each unswitchLoop() call.
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto &AC = AM.getResult<AssumptionAnalysis>(F);
+  auto &AA = AM.getResult<AAManager>(F);
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager();
+
+  // MemorySSA is always requested here, so the updater is never absent.
+  MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+  MemorySSAUpdater MSSAU(&MSSA);
+  if (VerifyMemorySSA)
+    MSSA.verifyMemorySSA();
+
+  bool Changed = false;
+  for (Loop *TopLevelLoop : LI) {
+    Changed |= simplifyLoop(TopLevelLoop, &DT, &LI, &SE, /*AC=*/nullptr,
+                            &MSSAU, /*PreserveLCSSA=*/false);
+    Changed |= formLCSSARecursively(*TopLevelLoop, DT, &LI, &SE);
+  }
+
+  auto DestroyLoopCB = [&LAM](Loop &L, StringRef Name) {
+    LAM.clear(L, Name);
+  };
+
+  SmallPriorityWorklist<Loop *, 4> Worklist;
+  appendLoopsToWorklist(LI, Worklist);
+  while (!Worklist.empty()) {
+    Loop &L = *Worklist.pop_back_val();
+    LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L
+                      << "\n");
+
+    // Save the current loop name in a variable so that we can report it even
+    // after it has been deleted.
+    std::string LoopName = std::string(L.getName());
+
+    auto UnswitchCB = [&L, &LAM, &LoopName, &Worklist](
+                          bool CurrentLoopValid, bool PartiallyInvariant,
+                          ArrayRef<Loop *> NewLoops) {
+      // If we did a non-trivial unswitch, we have added new (cloned) loops.
+      if (!NewLoops.empty())
+        appendLoopsToWorklist(NewLoops, Worklist);
+
+      // The current loop was destroyed: drop its cached loop analyses.
+      if (!CurrentLoopValid) {
+        LAM.clear(L, LoopName);
+        return;
+      }
+
+      // The loop remains valid and was not partially unswitched: revisit it to
+      // catch any other unswitch opportunities.
+      if (!PartiallyInvariant) {
+        Worklist.insert(&L);
+        return;
+      }
+
+      // Mark the new loop as partially unswitched, to avoid unswitching on
+      // the same condition again.
+      auto &Context = L.getHeader()->getContext();
+      MDNode *DisableUnswitchMD = MDNode::get(
+          Context,
+          MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
+      MDNode *NewLoopID = makePostTransformationMetadata(
+          Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
+          {DisableUnswitchMD});
+      L.setLoopID(NewLoopID);
+    };
+
+    Changed |= unswitchLoop(L, DT, LI, AC, AA, TTI, /*Trivial=*/false,
+                            /*NonTrivial=*/true, UnswitchCB, &SE, &MSSAU,
+                            DestroyLoopCB);
+  }
+  if (!Changed)
+    return PreservedAnalyses::all();
+
+  if (VerifyMemorySSA)
+    MSSA.verifyMemorySSA();
+
+  // Historically this pass has had issues with the dominator tree so verify it
+  // in asserts builds.
+  assert(DT.verify(DominatorTree::VerificationLevel::Fast));
+
+  auto PA = getLoopPassPreservedAnalyses();
+  PA.preserve<MemorySSAAnalysis>();
+  return PA;
+}
+
 void SimpleLoopUnswitchPass::printPipeline(
     raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
   static_cast<PassInfoMixin<SimpleLoopUnswitchPass> *>(this)->printPipeline(
diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-cost.ll b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-cost.ll
--- a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-cost.ll
+++ b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-cost.ll
@@ -3,6 +3,7 @@
 ; RUN: opt -passes='loop(simple-loop-unswitch),verify' -unswitch-threshold=5 -S < %s | FileCheck %s
 ; RUN: opt -passes='loop-mssa(simple-loop-unswitch),verify' -unswitch-threshold=5 -S < %s | FileCheck %s
 ; RUN: opt -simple-loop-unswitch -enable-nontrivial-unswitch -unswitch-threshold=5 -verify-memoryssa -S < %s | FileCheck %s
+; RUN: opt -passes='simple-loop-unswitch-func,verify' -unswitch-threshold=5 -S < %s | FileCheck %s
 
 declare void @a()
 declare void @b()
diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll
--- a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll
+++ b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-freeze.ll
@@ -2,6 +2,7 @@
 ; RUN: opt -freeze-loop-unswitch-cond -passes='loop(simple-loop-unswitch),verify' -S < %s | FileCheck %s
 ; RUN: opt -freeze-loop-unswitch-cond -passes='loop-mssa(simple-loop-unswitch),verify' -S < %s | FileCheck %s
 ; RUN: opt -freeze-loop-unswitch-cond -simple-loop-unswitch -enable-nontrivial-unswitch -verify-memoryssa -S < %s | FileCheck %s
+; RUN: opt -freeze-loop-unswitch-cond -passes='simple-loop-unswitch-func,verify' -S < %s | FileCheck %s
 
 declare i32 @a()
 declare i32 @b()