diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1975,6 +1975,10 @@
   /// SmallDenseSet.
   SetVector<int64_t, SmallVector<int64_t, 8>, SmallSet<int64_t, 8>> Factors;
 
+  /// Baseline cost for currently used SCEV. Drop the best solution by LSR if
+  /// the solution is not profitable.
+  Cost BaselineCost;
+
   /// Interesting use types, to facilitate truncation reuse.
   SmallSetVector<Type *, 4> Types;
 
@@ -3294,6 +3298,11 @@
   BranchInst *ExitBranch = nullptr;
   bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &TLI);
 
+  // For calculating InitialSolutionCost
+  SmallPtrSet<const SCEV *, 16> Regs;
+  DenseSet<const SCEV *> VisitedRegs;
+  DenseSet<size_t> VisitedLSRUse;
+
   for (const IVStrideUse &U : IU) {
     Instruction *UserInst = U.getUser();
     // Skip IV users that are part of profitable IV Chains.
@@ -3387,6 +3396,15 @@
     LF.Offset = Offset;
     LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L);
 
+    // Create SCEV as Formula for calculating baseline cost
+    if (!VisitedLSRUse.count(LUIdx) && !LF.isUseFullyOutsideLoop(L)) {
+      Formula F;
+      F.initialMatch(S, L, SE);
+      if (!BaselineCost.isLoser())
+        BaselineCost.RateFormula(F, Regs, VisitedRegs, LU);
+      VisitedLSRUse.insert(LUIdx);
+    }
+
     if (!LU.WidestFixupType ||
         SE.getTypeSizeInBits(LU.WidestFixupType) <
             SE.getTypeSizeInBits(LF.OperandValToReplace->getType()))
@@ -5162,6 +5180,17 @@
   });
 
   assert(Solution.size() == Uses.size() && "Malformed solution!");
+
+  if (!SolutionCost.isLess(const_cast<Cost &>(BaselineCost))) {
+    LLVM_DEBUG(dbgs() << "The baseline solution requires ";
+               BaselineCost.print(dbgs()); dbgs() << "\n";);
+
+    LLVM_DEBUG(
+        dbgs()
+        << "Baseline solution is more profitable than chosen solution.\n");
+    LLVM_DEBUG(dbgs() << "Dropping LSR chosen solution.\n");
+    Solution.clear();
+  }
 }
 
 /// Helper for AdjustInsertPositionForExpand. Climb up the dominator tree far as
@@ -5706,7 +5735,8 @@
       MSSAU(MSSAU), AMK(PreferredAddresingMode.getNumOccurrences() > 0
                             ? PreferredAddresingMode
                             : TTI.getPreferredAddressingMode(L, &SE)),
-      Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr", false) {
+      Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr", false),
+      BaselineCost(L, SE, TTI, AMK) {
   // If LoopSimplify form is not available, stay out of trouble.
   if (!L->isLoopSimplifyForm())
     return;
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-drop-solution.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-drop-solution.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-drop-solution.ll
@@ -0,0 +1,45 @@
+; RUN: opt < %s -passes="loop-reduce" -S | FileCheck %s
+
+target datalayout = "e-m:e-p:64:64-i64:64-i128:128-n64-S128"
+target triple = "riscv64-unknown-linux-gnu"
+
+define ptr @foo(ptr %a0, ptr %a1, i64 %a2) {
+entry:
+  %0 = ptrtoint ptr %a0 to i64
+  %1 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %a2, i64 0, i64 3)
+  %cmp.not = icmp eq i64 %1, %a2
+  br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:                                          ; preds = %entry
+  %add = add i64 %0, %a2
+  %sub = sub i64 %add, %1
+  br label %do.body
+
+do.body:                                          ; preds = %do.body, %if.then
+  %lsr.iv = phi i64 [ %lsr.iv.next, %do.body ], [ 0, %if.then ]
+  %2 = add i64 %0, %lsr.iv
+  %uglygep29 = getelementptr i8, ptr %a1, i64 %lsr.iv
+  %3 = tail call <vscale x 64 x i8> @llvm.riscv.vle.nxv64i8.i64(<vscale x 64 x i8> undef, ptr %uglygep29, i64 %1)
+  %4 = inttoptr i64 %2 to ptr
+  tail call void @llvm.riscv.vse.nxv64i8.i64(<vscale x 64 x i8> %3, ptr %4, i64 %1)
+  %lsr.iv.next = add i64 %lsr.iv, %1
+  %5 = add i64 %0, %lsr.iv.next
+  %cmp2 = icmp ugt i64 %sub, %5
+  br i1 %cmp2, label %do.body, label %do.end
+
+do.end:                                           ; preds = %do.body ; ExitBlock
+  %6 = sub i64 %a2, %lsr.iv.next
+  %7 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %6, i64 0, i64 3)
+  %8 = add i64 %0, %lsr.iv.next
+  %uglygep = getelementptr i8, ptr %a1, i64 %lsr.iv.next
+  br label %if.end
+
+if.end:                                           ; preds = %do.end, %entry
+  %a3.1 = phi i64 [ %8, %do.end ], [ %0, %entry ]
+  %t0.0 = phi i64 [ %7, %do.end ], [ %a2, %entry ]
+  %a1.addr.1 = phi ptr [ %uglygep, %do.end ], [ %a1, %entry ]
+  %9 = tail call <vscale x 64 x i8> @llvm.riscv.vle.nxv64i8.i64(<vscale x 64 x i8> undef, ptr %a1.addr.1, i64 %t0.0)
+  %10 = inttoptr i64 %a3.1 to ptr
+  tail call void @llvm.riscv.vse.nxv64i8.i64(<vscale x 64 x i8> %9, ptr %10, i64 %t0.0)
+  ret ptr %a0
+}