Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -2207,6 +2207,36 @@
   return r;
 }
 
+/// Determine if any of the operands in this SCEV are a constant or if
+/// any of the add or multiply expressions in this SCEV contain a constant.
+///
+/// Only the operands of add and multiply expressions are searched; every
+/// other kind of SCEV is treated as a leaf. Sub-expressions that are shared
+/// within the SCEV DAG are examined at most once, so the walk cannot blow up
+/// on expressions that reuse the same sub-expression through many paths.
+static bool containsConstantSomewhere(const SCEV *StartExpr) {
+  SmallVector<const SCEV *, 8> Worklist;
+  SmallPtrSet<const SCEV *, 8> Visited;
+  Worklist.push_back(StartExpr);
+  while (!Worklist.empty()) {
+    const SCEV *CurrentExpr = Worklist.pop_back_val();
+    // Skip nodes already examined through another path in the DAG.
+    if (Visited.count(CurrentExpr))
+      continue;
+    Visited.insert(CurrentExpr);
+
+    if (isa<SCEVConstant>(CurrentExpr))
+      return true;
+
+    if (isa<SCEVAddExpr>(CurrentExpr) || isa<SCEVMulExpr>(CurrentExpr)) {
+      const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
+      for (const SCEV *Operand : CurrentNAry->operands())
+        Worklist.push_back(Operand);
+    }
+  }
+  return false;
+}
+
 /// getMulExpr - Get a canonical multiply expression, or something simpler if
 /// possible.
 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
@@ -2247,10 +2277,12 @@
     // C1*(C2+V) -> C1*C2 + C1*V
     if (Ops.size() == 2)
       if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
-        if (Add->getNumOperands() == 2 &&
-            isa<SCEVConstant>(Add->getOperand(0)))
-          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
-                            getMulExpr(LHSC, Add->getOperand(1)));
+        // Distribute C1 over a two-operand add whenever the add has a
+        // constant anywhere in its add/mul operand chain (not only as its
+        // first operand), so that C1 can fold with that constant.
+        if (Add->getNumOperands() == 2 && containsConstantSomewhere(Add))
+          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
+                            getMulExpr(LHSC, Add->getOperand(1)));
 
     ++Idx;
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
Index: test/Analysis/ScalarEvolution/scev-mul-expr-fold.ll
===================================================================
--- /dev/null
+++ test/Analysis/ScalarEvolution/scev-mul-expr-fold.ll
@@ -0,0 +1,62 @@
+; RUN: opt -analyze -scalar-evolution < %s | FileCheck %s
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-n32-S64"
+
+%ParamStruct = type { i8**, i32*, i8*, i32, i32, i32, i8*, i32, i32, i32, i32 }
+
+; Function Attrs: nounwind
+define void @root.expand(%ParamStruct* nocapture readonly %p, i32 %x1, i32 %x2, i32 %outstep) #0 {
+Begin:
+  %0 = getelementptr inbounds %ParamStruct* %p, i32 0, i32 2
+  %1 = bitcast i8** %0 to float**
+  %2 = load float** %1, align 4, !tbaa !8, !alias.scope !12
+  %3 = bitcast %ParamStruct* %p to float***
+  %inputs_base2 = load float*** %3, align 4
+  %input_base1 = load float** %inputs_base2, align 4, !tbaa !8, !alias.scope !12
+  %4 = icmp ugt i32 %x2, %x1
+  br i1 %4, label %Loop.preheader, label %Exit
+
+Loop.preheader:                                   ; preds = %Begin
+  %5 = sub i32 %x2, %x1
+  %xtraiter = and i32 %5, 7
+  %lcmp.mod = icmp ne i32 %xtraiter, 0
+  %lcmp.overflow = icmp eq i32 %x2, %x1
+  %lcmp.or = or i1 %lcmp.overflow, %lcmp.mod
+  br i1 %lcmp.or, label %Loop.prol, label %Loop.preheader.split
+
+Loop.prol:                                        ; preds = %Loop.prol, %Loop.preheader
+  %X.prol = phi i32 [ %7, %Loop.prol ], [ %x1, %Loop.preheader ]
+  %prol.iter = phi i32 [ %prol.iter.sub, %Loop.prol ], [ %xtraiter, %Loop.preheader ]
+  %6 = sub i32 %X.prol, %x1
+  %7 = add nuw i32 %X.prol, 1
+  %prol.iter.sub = add i32 %prol.iter, -1
+  %prol.iter.cmp = icmp eq i32 %prol.iter.sub, 0
+  br i1 %prol.iter.cmp, label %Loop.preheader.split, label %Loop.prol, !llvm.loop !16
+
+Loop.preheader.split:                             ; preds = %Loop.prol, %Loop.preheader
+  %X.unr = phi i32 [ %x1, %Loop.preheader ], [ %7, %Loop.prol ]
+  %8 = icmp ult i32 %5, 8
+  br i1 %8, label %Exit, label %Loop
+
+Exit:                                             ; preds = %Loop, %Loop.preheader.split, %Begin
+  ret void
+
+Loop:                                             ; preds = %Loop, %Loop.preheader.split
+  %X = phi i32 [ %X.unr, %Loop.preheader.split ]
+  %9 = sub i32 %X, %x1
+  %testSCEV = getelementptr float* %2, i32 %9
+; CHECK: ((4 * %X.unr) + (-4 * %x1) + %2)
+  br label %Exit
+}
+
+attributes #0 = { nounwind }
+
+!8 = metadata !{metadata !9, metadata !9, i64 0}
+!9 = metadata !{metadata !"pointer", metadata !10, i64 0}
+!10 = metadata !{metadata !"", metadata !11}
+!11 = metadata !{metadata !""}
+!12 = metadata !{metadata !12, metadata !13, metadata !"argument_scope"}
+!13 = metadata !{metadata !13}
+!14 = metadata !{metadata !15, metadata !15, i64 0}
+!15 = metadata !{metadata !"", metadata !10, i64 0}
+!16 = metadata !{metadata !16, metadata !17}
+!17 = metadata !{metadata !"llvm.loop.unroll.disable"}