diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -272,6 +272,7 @@
                                            IRBuilder<> &Builder);
 
   bool HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly);
+  bool hoistCommonInstsOnSwitch(SwitchInst *SI);
   bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB);
   bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond,
                                   BasicBlock *TrueBB, BasicBlock *FalseBB,
@@ -1748,6 +1749,86 @@
   return Changed;
 }
 
+/// Hoist the instructions shared by all successors of \p SI into the block
+/// containing \p SI.
+///
+/// Similar to HoistThenElseCodeToIf, but simpler. We don't care about
+/// instruction reordering and debugging instructions. The successors are
+/// walked in lock-step from their first instruction; hoisting stops at the
+/// first position where the instructions are not identical, where
+/// shouldHoistCommonInstructions rejects them, or where the first successor
+/// reaches its terminator.
+///
+/// \returns true if at least one instruction was hoisted.
+bool SimplifyCFGOpt::hoistCommonInstsOnSwitch(SwitchInst *SI) {
+  unsigned NumSuccs = SI->getNumSuccessors();
+  if (NumSuccs < 2)
+    return false;
+
+  // The first successor is the reference block: its instructions are the ones
+  // that get moved in front of the switch.
+  BasicBlock *BB1 = SI->getSuccessor(0);
+  if (BB1->hasAddressTaken() || !BB1->getSinglePredecessor())
+    return false;
+
+  BasicBlock::iterator BB1Itr = BB1->begin();
+  Instruction *I1 = &*BB1Itr++;
+  if (isa<PHINode>(I1))
+    return false;
+
+  // One cursor per remaining successor, advanced together with BB1Itr.
+  SmallVector<BasicBlock::iterator> OtherSuccIters;
+  OtherSuccIters.reserve(NumSuccs - 1);
+  for (unsigned Idx = 1; Idx < NumSuccs; ++Idx) {
+    BasicBlock *Succ = SI->getSuccessor(Idx);
+    if (Succ->hasAddressTaken() || !Succ->getSinglePredecessor())
+      return false;
+    OtherSuccIters.push_back(Succ->begin());
+  }
+
+  BasicBlock *BIParent = SI->getParent();
+  bool Changed = false;
+  auto _ = make_scope_exit([&]() {
+    if (Changed)
+      NumHoistCommonCode += NumSuccs;
+  });
+
+  for (;;) {
+    // We expect simplifyUncondBranch to complete the optimization of the
+    // terminator instruction.
+    if (I1->isTerminator())
+      return Changed;
+
+    for (BasicBlock::iterator SuccIter : OtherSuccIters) {
+      Instruction *I2 = &*SuccIter;
+      if (!I1->isIdenticalToWhenDefined(I2) ||
+          !shouldHoistCommonInstructions(I1, I2, TTI))
+        return Changed;
+    }
+
+    // Move I1 in front of the switch exactly once: after this splice it no
+    // longer lives in BB1, so it must not be spliced again per successor.
+    BIParent->splice(SI->getIterator(), BB1, I1->getIterator());
+
+    // Take each cursor by reference so that stepping past I2 is recorded in
+    // OtherSuccIters; a by-value copy would leave the stored iterator pointing
+    // at the instruction erased below.
+    for (BasicBlock::iterator &SuccIter : OtherSuccIters) {
+      Instruction *I2 = &*SuccIter++;
+      if (!I2->use_empty())
+        I2->replaceAllUsesWith(I1);
+      I1->andIRFlags(I2);
+      combineMetadataForCSE(I1, I2, true);
+      I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc());
+      I2->eraseFromParent();
+    }
+
+    I1 = &*BB1Itr++;
+    Changed = true;
+    ++NumHoistCommonInstrs;
+  }
+}
+
 // Check lifetime markers.
 static bool isLifeTimeMarker(const Instruction *I) {
   if (auto II = dyn_cast<IntrinsicInst>(I)) {
@@ -6794,6 +6875,9 @@
   if (ReduceSwitchRange(SI, Builder, DL, TTI))
     return requestResimplify();
 
+  if (HoistCommon && hoistCommonInstsOnSwitch(SI))
+    return requestResimplify();
+
   return false;
 }
 
diff --git a/llvm/test/Transforms/SimplifyCFG/hoist-common-code-on-switch.ll b/llvm/test/Transforms/SimplifyCFG/hoist-common-code-on-switch.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/hoist-common-code-on-switch.ll
@@ -0,0 +1,31 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes='simplifycfg<hoist-common-insts>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s
+
+define i1 @foo(i64 %a, i64 %b, i64 %c) unnamed_addr {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  start:
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i64 [[B:%.*]], [[C:%.*]]
+; CHECK-NEXT:    ret i1 [[TMP0]]
+;
+start:
+  switch i64 %a, label %bb0 [
+    i64 1, label %bb1
+    i64 2, label %bb2
+  ]
+
+bb0:                                              ; preds = %start
+  %0 = icmp eq i64 %b, %c
+  br label %exit
+
+bb1:                                              ; preds = %start
+  %1 = icmp eq i64 %b, %c
+  br label %exit
+
+bb2:                                              ; preds = %start
+  %2 = icmp eq i64 %b, %c
+  br label %exit
+
+exit:                                             ; preds = %bb2, %bb1, %bb0
+  %result = phi i1 [ %0, %bb0 ], [ %1, %bb1 ], [ %2, %bb2 ]
+  ret i1 %result
+}