Index: lib/Transforms/Scalar/LoopStrengthReduce.cpp
===================================================================
--- lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -261,9 +261,13 @@
   /// canonical representation of a formula is
   /// 1. BaseRegs.size > 1 implies ScaledReg != NULL and
   /// 2. ScaledReg != NULL implies Scale != 1 || !BaseRegs.empty().
+  /// 3. The reg containing a recurrent expr related to the current loop in the
+  /// formula should be put in the ScaledReg.
   /// #1 enforces that the scaled register is always used when at least two
   /// registers are needed by the formula: e.g., reg1 + reg2 is reg1 + 1 * reg2.
   /// #2 enforces that 1 * reg is reg.
+  /// #3 ensures invariant regs with respect to the current loop can be combined
+  /// together in LSR codegen.
   /// This invariant can be temporarly broken while building a formula.
   /// However, every formula inserted into the LSRInstance must be in canonical
   /// form.
@@ -284,9 +288,9 @@
 
   void initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE);
 
-  bool isCanonical() const;
+  bool isCanonical(const Loop &L) const;
 
-  void canonicalize();
+  void canonicalize(const Loop &L);
 
   bool unscale();
 
@@ -376,16 +380,35 @@
       BaseRegs.push_back(Sum);
     HasBaseReg = true;
   }
-  canonicalize();
+  canonicalize(*L);
 }
 
 /// \brief Check whether or not this formula statisfies the canonical
 /// representation.
 /// \see Formula::BaseRegs.
-bool Formula::isCanonical() const {
-  if (ScaledReg)
-    return Scale != 1 || !BaseRegs.empty();
-  return BaseRegs.size() <= 1;
+bool Formula::isCanonical(const Loop &L) const {
+  if (!ScaledReg)
+    return BaseRegs.size() <= 1;
+
+  if (Scale != 1)
+    return true;
+
+  if (BaseRegs.empty())
+    return false;
+
+  const SCEVAddRecExpr *SAR = dyn_cast<SCEVAddRecExpr>(ScaledReg);
+  if (SAR && SAR->getLoop() == &L)
+    return true;
+
+  // Scale == 1 and ScaledReg is not an addrec of L: the formula is canonical
+  // only if no BaseReg is an addrec of L either; otherwise canonicalize() must
+  // swap that BaseReg into ScaledReg.
+  auto I =
+      find_if(make_range(BaseRegs.begin(), BaseRegs.end()), [&](const SCEV *S) {
+        return isa<SCEVAddRecExpr>(S) &&
+               (cast<SCEVAddRecExpr>(S)->getLoop() == &L);
+      });
+  return I == BaseRegs.end();
 }
 
 /// \brief Helper method to morph a formula into its canonical representation.
@@ -394,21 +417,33 @@
 /// field. Otherwise, we would have to do special cases everywhere in LSR
 /// to treat reg1 + reg2 + ... the same way as reg1 + 1*reg2 + ...
 /// On the other hand, 1*reg should be canonicalized into reg.
-void Formula::canonicalize() {
-  if (isCanonical())
+void Formula::canonicalize(const Loop &L) {
+  if (isCanonical(L))
     return;
   // So far we did not need this case. This is easy to implement but it is
   // useless to maintain dead code. Beside it could hurt compile time.
   assert(!BaseRegs.empty() && "1*reg => reg, should not be needed.");
+
   // Keep the invariant sum in BaseRegs and one of the variant sum in ScaledReg.
-  ScaledReg = BaseRegs.back();
-  BaseRegs.pop_back();
-  Scale = 1;
-  size_t BaseRegsSize = BaseRegs.size();
-  size_t Try = 0;
-  // If ScaledReg is an invariant, try to find a variant expression.
-  while (Try < BaseRegsSize && !isa<SCEVAddRecExpr>(ScaledReg))
-    std::swap(ScaledReg, BaseRegs[Try++]);
+  if (!ScaledReg) {
+    ScaledReg = BaseRegs.back();
+    BaseRegs.pop_back();
+    Scale = 1;
+  }
+
+  // If ScaledReg is an invariant with respect to L, find the reg from
+  // BaseRegs containing the recurrent expr related to loop L and swap that
+  // reg with ScaledReg.
+  const SCEVAddRecExpr *SAR = dyn_cast<SCEVAddRecExpr>(ScaledReg);
+  if (!SAR || SAR->getLoop() != &L) {
+    auto I = find_if(make_range(BaseRegs.begin(), BaseRegs.end()),
+                     [&](const SCEV *S) {
+                       return isa<SCEVAddRecExpr>(S) &&
+                              (cast<SCEVAddRecExpr>(S)->getLoop() == &L);
+                     });
+    if (I != BaseRegs.end())
+      std::swap(ScaledReg, *I);
+  }
 }
 
 /// \brief Get rid of the scale in the formula.
@@ -839,7 +874,8 @@
                                  const LSRUse &LU, const Formula &F);
 // Get the cost of the scaling factor used in F for LU.
 static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
-                                     const LSRUse &LU, const Formula &F);
+                                     const LSRUse &LU, const Formula &F,
+                                     const Loop &L);
 
 namespace {
 
@@ -1032,7 +1068,7 @@
   }
 
   bool HasFormulaWithSameRegs(const Formula &F) const;
-  bool InsertFormula(const Formula &F);
+  bool InsertFormula(const Formula &F, const Loop &L);
   void DeleteFormula(Formula &F);
   void RecomputeRegs(size_t LUIdx, RegUseTracker &Reguses);
 
@@ -1114,7 +1150,7 @@
                        ScalarEvolution &SE, DominatorTree &DT,
                        const LSRUse &LU,
                        SmallPtrSetImpl<const SCEV *> *LoserRegs) {
-  assert(F.isCanonical() && "Cost is accurate only for canonical formula");
+  assert(F.isCanonical(*L) && "Cost is accurate only for canonical formula");
   // Tally up the registers.
   if (const SCEV *ScaledReg = F.ScaledReg) {
     if (VisitedRegs.count(ScaledReg)) {
@@ -1145,7 +1181,7 @@
   NumBaseAdds += (F.UnfoldedOffset != 0);
 
   // Accumulate non-free scaling amounts.
-  ScaleCost += getScalingFactorCost(TTI, LU, F);
+  ScaleCost += getScalingFactorCost(TTI, LU, F, *L);
 
   // Tally up the non-zero immediates.
   for (const LSRFixup &Fixup : LU.Fixups) {
@@ -1266,8 +1302,8 @@
 
 /// If the given formula has not yet been inserted, add it to the list, and
 /// return true. Return false otherwise. The formula must be in canonical form.
-bool LSRUse::InsertFormula(const Formula &F) {
-  assert(F.isCanonical() && "Invalid canonical representation");
+bool LSRUse::InsertFormula(const Formula &F, const Loop &L) {
+  assert(F.isCanonical(L) && "Invalid canonical representation");
 
   if (!Formulae.empty() && RigidFormula)
     return false;
@@ -1436,7 +1472,7 @@
 static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  int64_t MinOffset, int64_t MaxOffset,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
-                                 const Formula &F) {
+                                 const Formula &F, const Loop &L) {
   // For the purpose of isAMCompletelyFolded either having a canonical formula
   // or a scale not equal to zero is correct.
   // Problems may arise from non canonical formulae having a scale == 0.
@@ -1444,7 +1480,7 @@
   // However, when we generate the scaled formulae, we first check that the
   // scaling factor is profitable before computing the actual ScaledReg for
   // compile time sake.
-  assert((F.isCanonical() || F.Scale != 0));
+  assert((F.isCanonical(L) || F.Scale != 0));
   return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy,
                               F.BaseGV, F.BaseOffset, F.HasBaseReg, F.Scale);
 }
@@ -1479,14 +1515,15 @@
 }
 
 static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
-                                     const LSRUse &LU, const Formula &F) {
+                                     const LSRUse &LU, const Formula &F,
+                                     const Loop &L) {
   if (!F.Scale)
     return 0;
 
   // If the use is not completely folded in that instruction, we will have to
   // pay an extra cost only for scale != 1.
   if (!isAMCompletelyFolded(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
-                            LU.AccessTy, F))
+                            LU.AccessTy, F, L))
     return F.Scale != 1;
 
   switch (LU.Kind) {
@@ -3070,7 +3107,8 @@
   // Do not insert formula that we will not be able to expand.
   assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F) &&
          "Formula is illegal");
-  if (!LU.InsertFormula(F))
+
+  if (!LU.InsertFormula(F, *L))
     return false;
 
   CountRegisters(F, LUIdx);
@@ -3306,7 +3344,7 @@
       F.BaseRegs.push_back(*J);
     // We may have changed the number of register in base regs, adjust the
     // formula accordingly.
-    F.canonicalize();
+    F.canonicalize(*L);
 
     if (InsertFormula(LU, LUIdx, F))
       // If that formula hadn't been seen before, recurse to find more like
@@ -3318,7 +3356,7 @@
 /// Split out subexpressions from adds and the bases of addrecs.
 void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx,
                                          Formula Base, unsigned Depth) {
-  assert(Base.isCanonical() && "Input must be in the canonical form");
+  assert(Base.isCanonical(*L) && "Input must be in the canonical form");
   // Arbitrarily cap recursion to protect compile time.
   if (Depth >= 3)
     return;
@@ -3359,7 +3397,7 @@
     // rather than proceed with zero in a register.
     if (!Sum->isZero()) {
       F.BaseRegs.push_back(Sum);
-      F.canonicalize();
+      F.canonicalize(*L);
       (void)InsertFormula(LU, LUIdx, F);
     }
   }
@@ -3416,7 +3454,7 @@
           F.ScaledReg = nullptr;
         } else
           F.deleteBaseReg(F.BaseRegs[Idx]);
-        F.canonicalize();
+        F.canonicalize(*L);
       } else if (IsScaledReg)
         F.ScaledReg = NewG;
       else
@@ -3579,10 +3617,10 @@
     if (LU.Kind == LSRUse::ICmpZero &&
         !Base.HasBaseReg && Base.BaseOffset == 0 && !Base.BaseGV)
       continue;
-    // For each addrec base reg, apply the scale, if possible.
-    for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i)
-      if (const SCEVAddRecExpr *AR =
-            dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i])) {
+    // For each addrec base reg, if its loop is the current loop, apply the scale.
+    for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) {
+      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i]);
+      if (AR && AR->getLoop() == L) {
         const SCEV *FactorS = SE.getConstant(IntTy, Factor);
         if (FactorS->isZero())
           continue;
@@ -3601,6 +3639,7 @@
           (void)InsertFormula(LU, LUIdx, F);
         }
       }
+    }
   }
 }
 
@@ -3780,7 +3819,7 @@
             continue;
 
         // OK, looks good.
-        NewF.canonicalize();
+        NewF.canonicalize(*this->L);
         (void)InsertFormula(LU, LUIdx, NewF);
       } else {
         // Use the immediate in a base register.
@@ -3812,7 +3851,7 @@
                 goto skip_formula;
 
           // Ok, looks good.
-          NewF.canonicalize();
+          NewF.canonicalize(*this->L);
           (void)InsertFormula(LU, LUIdx, NewF);
           break;
         skip_formula:;
Index: test/Transforms/LoopStrengthReduce/X86/canonical.ll
===================================================================
--- test/Transforms/LoopStrengthReduce/X86/canonical.ll
+++ test/Transforms/LoopStrengthReduce/X86/canonical.ll
@@ -0,0 +1,65 @@
+; RUN: opt -mtriple=x86_64-unknown-linux-gnu -loop-reduce -S < %s | FileCheck %s
+; Check LSR formula canonicalization will put loop invariant regs before
+; induction variable of current loop, so exprs involving loop invariant regs
+; can be promoted outside of current loop.
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+define void @foo(i32 %size, i32 %nsteps, i8* nocapture %maxarray, i8* nocapture readnone %buffer, i32 %init) local_unnamed_addr #0 {
+entry:
+  %cmp25 = icmp sgt i32 %nsteps, 0
+  br i1 %cmp25, label %for.cond1.preheader.lr.ph, label %for.end12
+
+for.cond1.preheader.lr.ph:                        ; preds = %entry
+  %cmp223 = icmp sgt i32 %size, 1
+  %t0 = sext i32 %init to i64
+  %wide.trip.count = zext i32 %size to i64
+  %wide.trip.count31 = zext i32 %nsteps to i64
+  br label %for.cond1.preheader
+
+for.cond1.preheader:                              ; preds = %for.inc10, %for.cond1.preheader.lr.ph
+  %indvars.iv28 = phi i64 [ 0, %for.cond1.preheader.lr.ph ], [ %indvars.iv.next29, %for.inc10 ]
+  br i1 %cmp223, label %for.body3.lr.ph, label %for.inc10
+
+for.body3.lr.ph:                                  ; preds = %for.cond1.preheader
+  %t1 = add nsw i64 %indvars.iv28, %t0
+  %t2 = trunc i64 %indvars.iv28 to i8
+  br label %for.body3
+
+; Make sure loop invariant items are grouped together so that load address can
+; be represented in one getelementptr.
+; CHECK-LABEL: for.body3:
+; CHECK-NEXT: [[LSR:%[^,]+]] = phi i64 [ 1, %for.body3.lr.ph ], [ {{.*}}, %for.body3 ]
+; CHECK-NOT: = phi i64
+; CHECK-NEXT: [[LOADADDR:%[^,]+]] = getelementptr i8, i8* {{.*}}, i64 [[LSR]]
+; CHECK-NEXT: = load i8, i8* [[LOADADDR]], align 1
+; CHECK: br i1 %exitcond, label %for.inc10.loopexit, label %for.body3
+
+for.body3:                                        ; preds = %for.body3, %for.body3.lr.ph
+  %indvars.iv = phi i64 [ 1, %for.body3.lr.ph ], [ %indvars.iv.next, %for.body3 ]
+  %t5 = trunc i64 %indvars.iv to i8
+  %t3 = add nsw i64 %t1, %indvars.iv
+  %arrayidx = getelementptr inbounds i8, i8* %maxarray, i64 %t3
+  %t4 = load i8, i8* %arrayidx, align 1
+  %add5 = add i8 %t4, %t5
+  %add6 = add i8 %add5, %t2
+  %arrayidx9 = getelementptr inbounds i8, i8* %maxarray, i64 %indvars.iv
+  store i8 %add6, i8* %arrayidx9, align 1
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.inc10.loopexit, label %for.body3
+
+for.inc10.loopexit:                               ; preds = %for.body3
+  br label %for.inc10
+
+for.inc10:                                        ; preds = %for.inc10.loopexit, %for.cond1.preheader
+  %indvars.iv.next29 = add nuw nsw i64 %indvars.iv28, 1
+  %exitcond32 = icmp eq i64 %indvars.iv.next29, %wide.trip.count31
+  br i1 %exitcond32, label %for.end12.loopexit, label %for.cond1.preheader
+
+for.end12.loopexit:                               ; preds = %for.inc10
+  br label %for.end12
+
+for.end12:                                        ; preds = %for.end12.loopexit, %entry
+  ret void
+}