Index: llvm/trunk/lib/Transforms/Scalar/LoopStrengthReduce.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ llvm/trunk/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -256,9 +256,22 @@ /// live in an add immediate field rather than a register. int64_t UnfoldedOffset; + /// ZeroExtendScaledReg - This formula zero extends the scale register to + /// ZeroExtendType before its use. + bool ZeroExtendScaledReg; + + /// ZeroExtendBaseReg - This formula zero extends all the base registers to + /// ZeroExtendType before their use. + bool ZeroExtendBaseReg; + + /// ZeroExtendType - The destination type of the zero extension implied by + /// the above two booleans. + Type *ZeroExtendType; + Formula() : BaseGV(nullptr), BaseOffset(0), HasBaseReg(false), Scale(0), - ScaledReg(nullptr), UnfoldedOffset(0) {} + ScaledReg(nullptr), UnfoldedOffset(0), ZeroExtendScaledReg(false), + ZeroExtendBaseReg(false), ZeroExtendType(nullptr) {} void InitialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE); @@ -413,10 +426,12 @@ /// getType - Return the type of this formula, if it has one, or null /// otherwise. This type is meaningless except for the bit size. Type *Formula::getType() const { - return !BaseRegs.empty() ? BaseRegs.front()->getType() : - ScaledReg ? ScaledReg->getType() : - BaseGV ? BaseGV->getType() : - nullptr; + return ZeroExtendType + ? ZeroExtendType + : !BaseRegs.empty() + ? BaseRegs.front()->getType() + : ScaledReg ? ScaledReg->getType() + : BaseGV ? BaseGV->getType() : nullptr; } /// DeleteBaseReg - Delete the given base reg from the BaseRegs list. 
@@ -457,7 +472,10 @@ } for (const SCEV *BaseReg : BaseRegs) { if (!First) OS << " + "; else First = false; - OS << "reg(" << *BaseReg << ')'; + if (ZeroExtendBaseReg) + OS << "reg(zext " << *BaseReg << " to " << *ZeroExtendType << ')'; + else + OS << "reg(" << *BaseReg << ')'; } if (HasBaseReg && BaseRegs.empty()) { if (!First) OS << " + "; else First = false; @@ -469,9 +487,12 @@ if (Scale != 0) { if (!First) OS << " + "; else First = false; OS << Scale << "*reg("; - if (ScaledReg) - OS << *ScaledReg; - else + if (ScaledReg) { + if (ZeroExtendScaledReg) + OS << "(zext " << *ScaledReg << " to " << *ZeroExtendType << ')'; + else + OS << *ScaledReg; + } else OS << ""; OS << ')'; } @@ -1732,6 +1753,7 @@ void GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, Formula Base); void GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base); void GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base); + void GenerateZExts(LSRUse &LU, unsigned LUIdx, Formula Base); void GenerateCrossUseConstantOffsets(); void GenerateAllReuseFormulae(); @@ -3627,6 +3649,64 @@ } } +/// GenerateZExts - If a scale or a base register can be rewritten as +/// "Zext({A,+,1})" then consider a formula of that form. +void LSRInstance::GenerateZExts(LSRUse &LU, unsigned LUIdx, Formula Base) { + // Don't bother with symbolic values. + if (Base.BaseGV) + return; + + auto CanBeNarrowed = [&](const SCEV *Reg) -> const SCEV * { + // Check if the register is an increment can be rewritten as zext(R) where + // the zext is free. 
+
+    const auto *RegAR = dyn_cast_or_null<SCEVAddRecExpr>(Reg);
+    if (!RegAR)
+      return nullptr;
+
+    const auto *ZExtStart = dyn_cast<SCEVZeroExtendExpr>(RegAR->getStart());
+    const auto *ConstStep =
+        dyn_cast<SCEVConstant>(RegAR->getStepRecurrence(SE));
+    if (!ZExtStart || !ConstStep || ConstStep->getValue()->getValue() != 1)
+      return nullptr;
+
+    const SCEV *NarrowStart = ZExtStart->getOperand();
+    if (!TTI.isZExtFree(NarrowStart->getType(), ZExtStart->getType()))
+      return nullptr;
+
+    const auto *NarrowAR = dyn_cast<SCEVAddRecExpr>(
+        SE.getAddRecExpr(NarrowStart, SE.getConstant(NarrowStart->getType(), 1),
+                         RegAR->getLoop(), RegAR->getNoWrapFlags()));
+
+    if (!NarrowAR || !NarrowAR->getNoWrapFlags(SCEV::FlagNUW))
+      return nullptr;
+
+    return NarrowAR;
+  };
+
+  if (Base.ScaledReg && !Base.ZeroExtendType)
+    if (const SCEV *S = CanBeNarrowed(Base.ScaledReg)) {
+      Formula F = Base;
+      F.ZeroExtendType = Base.ScaledReg->getType();
+      F.ZeroExtendScaledReg = true;
+      F.ScaledReg = S;
+
+      if (isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F))
+        InsertFormula(LU, LUIdx, F);
+    }
+
+  if (Base.BaseRegs.size() == 1 && !Base.ZeroExtendType)
+    if (const SCEV *S = CanBeNarrowed(Base.BaseRegs[0])) {
+      Formula F = Base;
+      F.ZeroExtendType = Base.BaseRegs[0]->getType();
+      F.ZeroExtendBaseReg = true;
+      F.BaseRegs[0] = S;
+
+      if (isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F))
+        InsertFormula(LU, LUIdx, F);
+    }
+}
+
 namespace {
 
 /// WorkItem - Helper class for GenerateCrossUseConstantOffsets. It's used to
@@ -3846,6 +3926,8 @@
       LSRUse &LU = Uses[LUIdx];
       for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i)
         GenerateTruncates(LU, LUIdx, LU.Formulae[i]);
+      for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i)
+        GenerateZExts(LU, LUIdx, LU.Formulae[i]);
     }
 
   GenerateCrossUseConstantOffsets();
@@ -4483,13 +4565,28 @@
     // If we're expanding for a post-inc user, make the post-inc adjustment.
     PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops);
-    Reg = TransformForPostIncUse(Denormalize, Reg,
-                                 LF.UserInst, LF.OperandValToReplace,
-                                 Loops, SE, DT);
+    const SCEV *ExtendedReg =
+        F.ZeroExtendBaseReg ? SE.getZeroExtendExpr(Reg, F.ZeroExtendType) : Reg;
 
-    Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, nullptr, IP)));
+    const SCEV *PostIncReg =
+        TransformForPostIncUse(Denormalize, ExtendedReg, LF.UserInst,
+                               LF.OperandValToReplace, Loops, SE, DT);
+    if (PostIncReg == ExtendedReg) {
+      Value *Expanded = Rewriter.expandCodeFor(Reg, nullptr, IP);
+      if (F.ZeroExtendBaseReg)
+        Expanded = new ZExtInst(Expanded, F.ZeroExtendType, "", IP);
+      Ops.push_back(SE.getUnknown(Expanded));
+    } else {
+      Ops.push_back(
+          SE.getUnknown(Rewriter.expandCodeFor(PostIncReg, nullptr, IP)));
+    }
   }
 
+  // Note on post-inc uses and zero extends -- since the no-wrap behavior for
+  // the post-inc SCEV can be different from the no-wrap behavior of the pre-inc
+  // SCEV, if a post-inc transform is required we do the zero extension on the
+  // pre-inc expression before doing the post-inc transform.
+
   // Expand the ScaledReg portion.
   Value *ICmpScaledV = nullptr;
   if (F.Scale != 0) {
@@ -4497,22 +4594,33 @@
     // If we're expanding for a post-inc user, make the post-inc adjustment.
     PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops);
-    ScaledS = TransformForPostIncUse(Denormalize, ScaledS,
-                                     LF.UserInst, LF.OperandValToReplace,
-                                     Loops, SE, DT);
+    const SCEV *ExtendedScaleS =
+        F.ZeroExtendScaledReg ? SE.getZeroExtendExpr(ScaledS, F.ZeroExtendType)
+                              : ScaledS;
+    const SCEV *PostIncScaleS =
+        TransformForPostIncUse(Denormalize, ExtendedScaleS, LF.UserInst,
+                               LF.OperandValToReplace, Loops, SE, DT);
     if (LU.Kind == LSRUse::ICmpZero) {
       // Expand ScaleReg as if it was part of the base regs.
+ Value *Expanded = nullptr; + if (PostIncScaleS == ExtendedScaleS) { + Expanded = Rewriter.expandCodeFor(ScaledS, nullptr, IP); + if (F.ZeroExtendScaledReg) + Expanded = new ZExtInst(Expanded, F.ZeroExtendType, "", IP); + } else { + Expanded = Rewriter.expandCodeFor(PostIncScaleS, nullptr, IP); + } + if (F.Scale == 1) - Ops.push_back( - SE.getUnknown(Rewriter.expandCodeFor(ScaledS, nullptr, IP))); + Ops.push_back(SE.getUnknown(Expanded)); else { // An interesting way of "folding" with an icmp is to use a negated // scale, which we'll implement by inserting it into the other operand // of the icmp. assert(F.Scale == -1 && "The only scale supported by ICmpZero uses is -1!"); - ICmpScaledV = Rewriter.expandCodeFor(ScaledS, nullptr, IP); + ICmpScaledV = Expanded; } } else { // Otherwise just expand the scaled register and an explicit scale, @@ -4526,7 +4634,17 @@ Ops.clear(); Ops.push_back(SE.getUnknown(FullV)); } - ScaledS = SE.getUnknown(Rewriter.expandCodeFor(ScaledS, nullptr, IP)); + + Value *Expanded = nullptr; + if (PostIncScaleS == ExtendedScaleS) { + Expanded = Rewriter.expandCodeFor(ScaledS, nullptr, IP); + if (F.ZeroExtendScaledReg) + Expanded = new ZExtInst(Expanded, F.ZeroExtendType, "", IP); + } else { + Expanded = Rewriter.expandCodeFor(PostIncScaleS, nullptr, IP); + } + + ScaledS = SE.getUnknown(Expanded); if (F.Scale != 1) ScaledS = SE.getMulExpr(ScaledS, SE.getConstant(ScaledS->getType(), F.Scale)); Index: llvm/trunk/test/Transforms/LoopStrengthReduce/zext-of-scale.ll =================================================================== --- llvm/trunk/test/Transforms/LoopStrengthReduce/zext-of-scale.ll +++ llvm/trunk/test/Transforms/LoopStrengthReduce/zext-of-scale.ll @@ -0,0 +1,70 @@ +; RUN: opt < %s -S -loop-reduce | FileCheck %s + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +%struct = type { [8 x i8] } + +declare void @use_32(i32) +declare void @use_64(i64) + +define void @f(i32 %tmp156, 
i32* %length_buf_1, i32* %length_buf_0, %struct* %b, + %struct* %c, %struct* %d, %struct* %e, i32* %length_buf_2, + i32 %tmp160) { +; CHECK-LABEL: @f( +entry: + %begin151 = getelementptr inbounds %struct, %struct* %b, i64 0, i32 0, i64 12 + %tmp21 = bitcast i8* %begin151 to i32* + %begin157 = getelementptr inbounds %struct, %struct* %c, i64 0, i32 0, i64 16 + %tmp23 = bitcast i8* %begin157 to double* + %begin163 = getelementptr inbounds %struct, %struct* %d, i64 0, i32 0, i64 16 + %tmp25 = bitcast i8* %begin163 to double* + %length.i820 = load i32, i32* %length_buf_1, align 4, !range !0 + %enter = icmp ne i32 %tmp156, -1 + br i1 %enter, label %ok_146, label %block_81_2 + +ok_146: + %var_13 = phi double [ %tmp186, %ok_161 ], [ 0.000000e+00, %entry ] + %var_17 = phi i32 [ %tmp187, %ok_161 ], [ %tmp156, %entry ] + %tmp174 = zext i32 %var_17 to i64 + %tmp175 = icmp ult i32 %var_17, %length.i820 + br i1 %tmp175, label %ok_152, label %block_81_2 + +ok_152: + %tmp176 = getelementptr inbounds i32, i32* %tmp21, i64 %tmp174 + %tmp177 = load i32, i32* %tmp176, align 4 + %tmp178 = zext i32 %tmp177 to i64 + %length.i836 = load i32, i32* %length_buf_2, align 4, !range !0 + %tmp179 = icmp ult i32 %tmp177, %length.i836 + br i1 %tmp179, label %ok_158, label %block_81_2 + +ok_158: + %tmp180 = getelementptr inbounds double, double* %tmp23, i64 %tmp178 + %tmp181 = load double, double* %tmp180, align 8 + %length.i = load i32, i32* %length_buf_0, align 4, !range !0 + %tmp182 = icmp slt i32 %var_17, %length.i + br i1 %tmp182, label %ok_161, label %block_81_2 + +ok_161: +; CHECK-LABEL: ok_161: +; CHECK: add +; CHECK-NOT: add + %tmp183 = getelementptr inbounds double, double* %tmp25, i64 %tmp174 + %tmp184 = load double, double* %tmp183, align 8 + %tmp185 = fmul double %tmp181, %tmp184 + %tmp186 = fadd double %var_13, %tmp185 + %tmp187 = add nsw i32 %var_17, 1 + %tmp188 = icmp slt i32 %tmp187, %tmp160 +; CHECK: br + br i1 %tmp188, label %ok_146, label %block_81 + +block_81: + call void 
@use_64(i64 %tmp174) ;; pre-inc use + call void @use_32(i32 %tmp187) ;; post-inc use + ret void + +block_81_2: + ret void +} + +!0 = !{i32 0, i32 2147483647}