Index: llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
+++ llvm/include/llvm/Transforms/Utils/SimplifyCFGOptions.h
@@ -29,6 +29,9 @@
   bool SinkCommonInsts = false;
   bool SimplifyCondBranch = true;
   bool FoldTwoEntryPHINode = true;
+  // Switches with at most this many cases are expanded into a chain of
+  // compares and branches; 0 (the default) disables the expansion.
+  int SwitchRemovalThreshold = 0;
 
   AssumptionCache *AC = nullptr;
 
@@ -57,6 +60,10 @@
     SinkCommonInsts = B;
     return *this;
   }
+  SimplifyCFGOptions &setSwitchRemovalThreshold(int Threshold) {
+    SwitchRemovalThreshold = Threshold;
+    return *this;
+  }
   SimplifyCFGOptions &setAssumptionCache(AssumptionCache *Cache) {
     AC = Cache;
     return *this;
Index: llvm/lib/Passes/PassBuilderPipelines.cpp
===================================================================
--- llvm/lib/Passes/PassBuilderPipelines.cpp
+++ llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -182,6 +182,19 @@
     "enable-merge-functions", cl::init(false), cl::Hidden,
     cl::desc("Enable function merging as part of the optimization pipeline"));
 
+static cl::opt<bool>
+    RemoveSwitchBlocks("remove-switch-blocks", cl::init(true), cl::Hidden,
+                       cl::desc("Convert switch blocks into a branch sequence "
+                                "prior to vectorization."));
+
+// This value determines the point at which we stop removing switch statements
+// before the vectorizer pass. Removing switch blocks and replacing them with
+// compares and branches allows architectures that support predication to
+// vectorize. This value was chosen initially because it was needed to
+// vectorize a TSVC loop, however this value can be tweaked over time if higher
+// numbers are found to improve performance.
+static const int RemoveSwitchCaseThreshold = 4;
+
 PipelineTuningOptions::PipelineTuningOptions() {
   LoopInterleaving = true;
   LoopVectorization = true;
@@ -965,6 +978,13 @@
 /// TODO: Should LTO cause any differences to this set of passes?
 void PassBuilder::addVectorPasses(OptimizationLevel Level,
                                   FunctionPassManager &FPM, bool IsFullLTO) {
+  // Removing switch blocks and replacing them with compares and branches
+  // allows architectures that support predication to vectorize.
+  if (RemoveSwitchBlocks)
+    FPM.addPass(SimplifyCFGPass(SimplifyCFGOptions()
+                                    .setSwitchRemovalThreshold(
+                                        RemoveSwitchCaseThreshold)));
+
   FPM.addPass(LoopVectorizePass(
       LoopVectorizeOptions(!PTO.LoopInterleaving, !PTO.LoopVectorization)));
 
Index: llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
+++ llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -55,6 +55,11 @@
     "bonus-inst-threshold", cl::Hidden, cl::init(1),
     cl::desc("Control the number of bonus instructions (default = 1)"));
 
+static cl::opt<int> UserSwitchRemovalThreshold(
+    "switch-removal-threshold", cl::Hidden, cl::init(0),
+    cl::desc("Set the threshold for the number of switch cases where we "
+             "convert switch blocks to branches and compares (default = 0)"));
+
 static cl::opt<bool> UserKeepLoops(
     "keep-loops", cl::Hidden, cl::init(true),
     cl::desc("Preserve canonical loop structure (default = true)"));
@@ -311,6 +316,8 @@
     Options.HoistCommonInsts = UserHoistCommonInsts;
   if (UserSinkCommonInsts.getNumOccurrences())
     Options.SinkCommonInsts = UserSinkCommonInsts;
+  if (UserSwitchRemovalThreshold.getNumOccurrences())
+    Options.SwitchRemovalThreshold = UserSwitchRemovalThreshold;
 }
 
 SimplifyCFGPass::SimplifyCFGPass() : Options() {
Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6180,6 +6180,108 @@
   return true;
 }
 
+/// Expand a switch with a small number of cases into a chain of
+/// compare-and-branch blocks, one compare per case:
+///
+///   BB:     br (Cond == C0), Dest0, Next0
+///   Next0:  br (Cond == C1), Dest1, Next1
+///   ...
+///   NextK:  br (Cond == CLast), DestLast, DefaultDest
+///
+/// Removing switch blocks this way allows architectures that support
+/// predication to vectorize. Only switches whose case destinations are
+/// pairwise distinct and distinct from the default destination are handled:
+/// every successor then has exactly one incoming edge from BB, so its PHIs can
+/// simply be retargeted to the new predecessor. Branch weights, if present,
+/// are distributed over the chain, and the dominator tree is kept up to date
+/// through \p DTU when one is provided. Returns true if the switch was
+/// replaced.
+static bool turnSmallSwitchIntoICmps(SwitchInst *SI, IRBuilder<> &Builder,
+                                     DomTreeUpdater *DTU) {
+  assert(SI->getNumCases() > 1 && "Degenerate switch?");
+
+  BasicBlock *BB = SI->getParent();
+  BasicBlock *DefaultDest = SI->getDefaultDest();
+
+  // We don't attempt to deal with ranges here, nor with several edges from
+  // the switch into the same block (each such edge would need its own PHI
+  // entry, coming from a different new predecessor).
+  SmallPtrSet<BasicBlock *, 8> Dests;
+  Dests.insert(DefaultDest);
+  for (auto Case : SI->cases())
+    if (!Dests.insert(Case.getCaseSuccessor()).second)
+      return false;
+
+  // Profile weights are indexed by successor: the default first, then one per
+  // case. Malformed metadata is ignored.
+  SmallVector<uint64_t, 8> Weights;
+  uint64_t TotalWeight = 0;
+  if (HasBranchWeights(SI)) {
+    GetBranchWeights(SI, Weights);
+    if (Weights.size() == SI->getNumSuccessors())
+      for (uint64_t W : Weights)
+        TotalWeight += W;
+  }
+
+  SmallVector<DominatorTree::UpdateType, 8> Updates;
+  Value *Cond = SI->getCondition();
+  unsigned NumCases = SI->getNumCases();
+  // Combined weight of the successors not yet peeled off the chain.
+  uint64_t RemainingWeight = TotalWeight;
+  // Block receiving the next compare. The first compare goes into BB itself,
+  // right before the switch.
+  BasicBlock *CurBB = BB;
+  Builder.SetInsertPoint(SI);
+  for (auto Case : SI->cases()) {
+    BasicBlock *CaseDest = Case.getCaseSuccessor();
+    bool IsLast = Case.getCaseIndex() + 1 == NumCases;
+
+    // Past the first compare, the edge into CaseDest leaves CurBB, not BB.
+    if (CurBB != BB) {
+      CaseDest->replacePhiUsesWith(BB, CurBB);
+      Updates.push_back({DominatorTree::Delete, BB, CaseDest});
+      Updates.push_back({DominatorTree::Insert, CurBB, CaseDest});
+    }
+
+    // The last compare falls through to the original default destination
+    // (possibly an unreachable block); every other compare falls through to
+    // a fresh block that holds the next compare.
+    BasicBlock *NextBB = DefaultDest;
+    if (!IsLast) {
+      NextBB = BasicBlock::Create(BB->getContext(), BB->getName() + ".switch",
+                                  BB->getParent(), CurBB->getNextNode());
+      Updates.push_back({DominatorTree::Insert, CurBB, NextBB});
+    }
+
+    Value *Cmp = Builder.CreateICmpEQ(Cond, Case.getCaseValue(), "switch");
+    BranchInst *Br = Builder.CreateCondBr(Cmp, CaseDest, NextBB);
+    if (TotalWeight) {
+      uint64_t TrueWeight = Weights[Case.getSuccessorIndex()];
+      RemainingWeight -= TrueWeight;
+      // setBranchWeights takes 32-bit weights; scale the pair down if needed.
+      uint64_t BrWeights[2] = {TrueWeight, RemainingWeight};
+      FitWeights(BrWeights);
+      setBranchWeights(Br, BrWeights[0], BrWeights[1]);
+    }
+
+    if (IsLast) {
+      // The default edge now leaves the last compare block instead of BB.
+      DefaultDest->replacePhiUsesWith(BB, CurBB);
+      Updates.push_back({DominatorTree::Delete, BB, DefaultDest});
+      Updates.push_back({DominatorTree::Insert, CurBB, DefaultDest});
+    } else {
+      CurBB = NextBB;
+      Builder.SetInsertPoint(CurBB);
+    }
+  }
+
+  // Drop the switch, then bring the dominator tree in line with the new CFG.
+  SI->eraseFromParent();
+  if (DTU)
+    DTU->applyUpdates(Updates);
+  return true;
+}
+
 bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   BasicBlock *BB = SI->getParent();
 
@@ -6202,8 +6304,20 @@
         return requestResimplify();
   }
 
+  // Removing switch blocks and replacing them with compares and branches
+  // allows architectures that support predication to vectorize. A threshold
+  // of zero or less disables the expansion.
+  unsigned NumCases = SI->getNumCases();
+  int Threshold = Options.SwitchRemovalThreshold;
+  bool RemoveSwitches =
+      Threshold > 0 && NumCases <= static_cast<unsigned>(Threshold);
+
+  if (RemoveSwitches && NumCases > 1 &&
+      turnSmallSwitchIntoICmps(SI, Builder, DTU))
+    return requestResimplify();
+
   // Try to transform the switch into an icmp and a branch.
-  if (TurnSwitchRangeIntoICmp(SI, Builder))
+  if (!RemoveSwitches && TurnSwitchRangeIntoICmp(SI, Builder))
     return requestResimplify();
 
   // Remove unreachable cases.
@@ -6448,16 +6562,22 @@
       if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder))
         return requestResimplify();
 
-    // This block must be empty, except for the setcond inst, if it exists.
-    // Ignore dbg and pseudo intrinsics.
-    auto I = BB->instructionsWithoutDebug(true).begin();
-    if (&*I == BI) {
-      if (FoldValueComparisonIntoPredecessors(BI, Builder))
-        return requestResimplify();
-    } else if (&*I == cast<Instruction>(BI->getCondition())) {
-      ++I;
-      if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder))
-        return requestResimplify();
+    // When switch removal is enabled, don't merge value comparisons into the
+    // predecessors: that would fold the compare/branch chains produced by
+    // turnSmallSwitchIntoICmps back into a switch.
+    bool RemoveSwitches = Options.SwitchRemovalThreshold > 0;
+    if (!RemoveSwitches) {
+      // This block must be empty, except for the setcond inst, if it exists.
+      // Ignore dbg and pseudo intrinsics.
+      auto I = BB->instructionsWithoutDebug(true).begin();
+      if (&*I == BI) {
+        if (FoldValueComparisonIntoPredecessors(BI, Builder))
+          return requestResimplify();
+      } else if (&*I == cast<Instruction>(BI->getCondition())) {
+        ++I;
+        if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder))
+          return requestResimplify();
+      }
     }
   }
 
Index: llvm/test/Other/new-pm-lto-defaults.ll
===================================================================
--- llvm/test/Other/new-pm-lto-defaults.ll
+++ llvm/test/Other/new-pm-lto-defaults.ll
@@ -108,6 +108,7 @@
 ; CHECK-O23SZ-NEXT: Running pass: LoopDeletionPass on Loop
 ; CHECK-O23SZ-NEXT: Running pass: LoopFullUnrollPass on Loop
 ; CHECK-O23SZ-NEXT: Running pass: LoopDistributePass on foo
+; CHECK-O23SZ-NEXT: Running pass: SimplifyCFGPass on foo
 ; CHECK-O23SZ-NEXT: Running pass: LoopVectorizePass on foo
 ; CHECK-O23SZ-NEXT: Running analysis: BlockFrequencyAnalysis on foo
 ; CHECK-O23SZ-NEXT: Running analysis: BranchProbabilityAnalysis on foo
Index: llvm/test/Transforms/LoopVectorize/switch_vectorization.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopVectorize/switch_vectorization.ll
@@ -0,0 +1,85 @@
+; RUN: opt -passes=simplifycfg,loop-vectorize -scalable-vectorization=on -switch-removal-threshold=3 -pass-remarks='loop-vectorize' %s -S 2>&1 | FileCheck %s --check-prefix=CHECK-REMARKS
+; RUN: opt -passes=simplifycfg -switch-removal-threshold=3 %s -S | FileCheck %s --check-prefix=CHECK-SIMPLIFY
+
+; Convert switch blocks into a branch sequence, which allows architectures that
+; support predication to vectorize.
+
+; CHECK-REMARKS: remark: {{.*}} vectorized loop
+
+; CHECK-SIMPLIFY-LABEL: @s442(
+; CHECK-SIMPLIFY-NOT: switch i32
+; CHECK-SIMPLIFY: icmp eq i32 %{{.*}}, 1
+; CHECK-SIMPLIFY: icmp eq i32 %{{.*}}, 2
+; CHECK-SIMPLIFY-NOT: switch i32
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-unknown-linux-gnu"
+
+@indx = dso_local global [8000 x i32] zeroinitializer, align 64
+@b = dso_local global [8000 x float] zeroinitializer, align 64
+@a = dso_local global [8000 x float] zeroinitializer, align 64
+@c = dso_local global [8000 x float] zeroinitializer, align 64
+
+; Function Attrs: nofree norecurse nosync nounwind uwtable vscale_range(2,2)
+define void @s442() #0 {
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.inc
+  ret void
+
+for.body:                                         ; preds = %for.inc, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.inc ]
+  %arrayidx = getelementptr inbounds [8000 x i32], [8000 x i32]* @indx, i64 0, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4, !tbaa !10
+  switch i32 %0, label %for.inc [
+    i32 1, label %if.then1
+    i32 2, label %if.then2
+  ]
+
+if.then1:                                         ; preds = %for.body
+  %arrayidx3 = getelementptr inbounds [8000 x float], [8000 x float]* @b, i64 0, i64 %indvars.iv
+  %1 = load float, float* %arrayidx3, align 4, !tbaa !14
+  %mul = fmul fast float %1, %1
+  %arrayidx7 = getelementptr inbounds [8000 x float], [8000 x float]* @a, i64 0, i64 %indvars.iv
+  %2 = load float, float* %arrayidx7, align 4, !tbaa !14
+  %add = fadd fast float %2, %mul
+  store float %add, float* %arrayidx7, align 4, !tbaa !14
+  br label %for.inc
+
+if.then2:                                         ; preds = %for.body
+  %arrayidx13 = getelementptr inbounds [8000 x float], [8000 x float]* @c, i64 0, i64 %indvars.iv
+  %3 = load float, float* %arrayidx13, align 4, !tbaa !14
+  %mul16 = fmul fast float %3, %3
+  %arrayidx18 = getelementptr inbounds [8000 x float], [8000 x float]* @a, i64 0, i64 %indvars.iv
+  %4 = load float, float* %arrayidx18, align 4, !tbaa !14
+  %add19 = fadd fast float %4, %mul16
+  store float %add19, float* %arrayidx18, align 4, !tbaa !14
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.then2, %if.then1, %for.body
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 8000
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !llvm.loop !16
+}
+
+attributes #0 = { nofree norecurse nosync nounwind uwtable vscale_range(2,2) "approx-func-fp-math"="true" "frame-pointer"="non-leaf" "min-legal-vector-width"="0" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="generic" "target-features"="+neon,+sve,+v8a" "unsafe-fp-math"="true" }
+
+!llvm.module.flags = !{!2, !3, !4, !5, !6, !7, !8, !9}
+
+!2 = !{i32 2, !"Debug Info Version", i32 3}
+!3 = !{i32 1, !"wchar_size", i32 4}
+!4 = !{i32 1, !"branch-target-enforcement", i32 0}
+!5 = !{i32 1, !"sign-return-address", i32 0}
+!6 = !{i32 1, !"sign-return-address-all", i32 0}
+!7 = !{i32 1, !"sign-return-address-with-bkey", i32 0}
+!8 = !{i32 7, !"uwtable", i32 1}
+!9 = !{i32 7, !"frame-pointer", i32 1}
+!10 = !{!11, !11, i64 0}
+!11 = !{!"int", !12, i64 0}
+!12 = !{!"omnipotent char", !13, i64 0}
+!13 = !{!"Simple C/C++ TBAA"}
+!14 = !{!15, !15, i64 0}
+!15 = !{!"float", !12, i64 0}
+!16 = distinct !{!16, !17}
+!17 = !{!"llvm.loop.mustprogress"}