Index: lib/Transforms/Scalar/LoopRerollPass.cpp =================================================================== --- lib/Transforms/Scalar/LoopRerollPass.cpp +++ lib/Transforms/Scalar/LoopRerollPass.cpp @@ -166,7 +166,9 @@ typedef SmallVector SmallInstructionVector; typedef SmallSet SmallInstructionSet; - // A chain of isomorphic instructions, indentified by a single-use PHI, + DenseMap IVToIncMap; + + // A chain of isomorphic instructions, identified by a single-use PHI, // representing a reduction. Only the last value may be used outside the // loop. struct SimpleLoopReduction { @@ -335,7 +337,7 @@ // x[i*3+1] = y2 // x[i*3+2] = y3 // - // Base instruction -> i*3 + // Base instruction -> i*3 // +---+----+ // / | \ // ST[y1] +1 +2 <-- Roots @@ -366,8 +368,10 @@ struct DAGRootTracker { DAGRootTracker(LoopReroll *Parent, Loop *L, Instruction *IV, ScalarEvolution *SE, AliasAnalysis *AA, - TargetLibraryInfo *TLI) - : Parent(Parent), L(L), SE(SE), AA(AA), TLI(TLI), IV(IV) {} + TargetLibraryInfo *TLI, + DenseMap &IncrMap) + : Parent(Parent), L(L), SE(SE), AA(AA), TLI(TLI), IV(IV), + IVToIncMap(IncrMap) {} /// Stage 1: Find all the DAG roots for the induction variable. bool findRoots(); @@ -417,7 +421,7 @@ // The loop induction variable. Instruction *IV; // Loop step amount. - uint64_t Inc; + int64_t Inc; // Loop reroll count; if Inc == 1, this records the scaling applied // to the indvar: a[i*2+0] = ...; a[i*2+1] = ... ; // If Inc is not 1, Scale = Inc. @@ -430,6 +434,7 @@ // they are used in (or specially, IL_All for instructions // used in the loop increment mechanism). UsesTy Uses; + DenseMap &IVToIncMap; }; void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs); @@ -484,10 +489,18 @@ continue; if (const SCEVConstant *IncSCEV = dyn_cast(PHISCEV->getStepRecurrence(*SE))) { - if (!IncSCEV->getValue()->getValue().isStrictlyPositive()) - continue; - if (IncSCEV->getValue()->uge(MaxInc)) - continue; + const APInt &AInt = IncSCEV->getValue()->getValue(); + if (AInt.isStrictlyPositive()) { + if (IncSCEV->getValue()->uge(MaxInc)) + continue; + IVToIncMap[I] = IncSCEV->getValue()->getZExtValue(); + } else if (AInt.isNegative()) { + const SCEVConstant *PIncSCEV = + dyn_cast(SE->getNegativeSCEV(IncSCEV)); + if (!PIncSCEV || PIncSCEV->getValue()->uge(MaxInc)) + continue; + IVToIncMap[I] = IncSCEV->getValue()->getSExtValue(); + } DEBUG(dbgs() << "LRR: Possible IV: " << *I << " = " << *PHISCEV << "\n"); @@ -656,10 +669,14 @@ static bool isLoopIncrement(User *U, Instruction *IV) { BinaryOperator *BO = dyn_cast(U); - if (!BO || BO->getOpcode() != Instruction::Add) + + if (BO && BO->getOpcode() != Instruction::Add) + return false; + + if (!BO && !isa(U)) return false; - for (auto *UU : BO->users()) { + for (auto *UU : U->users()) { PHINode *PN = dyn_cast(UU); if (PN && PN == IV) return true; @@ -704,13 +721,7 @@ // No duplicates, please. return false; - // FIXME: Add support for negative values. - if (V < 0) { - DEBUG(dbgs() << "LRR: Aborting due to negative value: " << V << "\n"); - return false; - } - - Roots[V] = cast(I); + Roots[std::abs(V)] = cast(I); } if (Roots.empty()) @@ -731,7 +742,7 @@ unsigned NumBaseUses = BaseUsers.size(); if (NumBaseUses == 0) NumBaseUses = Roots.begin()->second->getNumUses(); - + // Check that every node has the same number of users. for (auto &KV : Roots) { if (KV.first == 0) @@ -744,7 +755,7 @@ } } - return true; + return true; } bool LoopReroll::DAGRootTracker:: @@ -787,7 +798,7 @@ if (!collectPossibleRoots(IVU, V)) return false; - // If we didn't get a root for index zero, then IVU must be + // If we didn't get a root for index zero, then IVU must be // subsumed. if (V.find(0) == V.end()) SubsumedInsts.insert(IVU); @@ -818,13 +829,10 @@ } bool LoopReroll::DAGRootTracker::findRoots() { - - const SCEVAddRecExpr *RealIVSCEV = cast(SE->getSCEV(IV)); - Inc = cast(RealIVSCEV->getOperand(1))-> - getValue()->getZExtValue(); + Inc = IVToIncMap[IV]; assert(RootSets.empty() && "Unclean state!"); - if (Inc == 1) { + if (std::abs(Inc) == 1) { for (auto *IVU : IV->users()) { if (isLoopIncrement(IVU, IV)) LoopIncs.push_back(cast(IVU)); @@ -1103,15 +1111,15 @@ " vs. " << *RootInst << "\n"); return false; } - + RootIt = TryIt; RootInst = TryIt->first; } // All instructions between the last root and this root - // may belong to some other iteration. If they belong to a + // may belong to some other iteration. If they belong to a // future iteration, then they're dangerous to alias with. - // + // // Note that because we allow a limited amount of flexibility in the order // that we visit nodes, LastRootIt might be *before* RootIt, in which // case we've already checked this set of instructions so we shouldn't @@ -1267,6 +1275,9 @@ ++J; } + int64_t Inc = IVToIncMap[IV]; + bool Negative = Inc < 0; + ; const DataLayout &DL = Header->getModule()->getDataLayout(); // We need to create a new induction variable for each different BaseInst. @@ -1275,10 +1286,9 @@ const SCEVAddRecExpr *RealIVSCEV = cast(SE->getSCEV(DRS.BaseInst)); const SCEV *Start = RealIVSCEV->getStart(); - const SCEVAddRecExpr *H = cast - (SE->getAddRecExpr(Start, - SE->getConstant(RealIVSCEV->getType(), 1), - L, SCEV::FlagAnyWrap)); + const SCEVAddRecExpr *H = cast(SE->getAddRecExpr( + Start, SE->getConstant(RealIVSCEV->getType(), Negative ? -1 : 1), L, + SCEV::FlagAnyWrap)); { // Limit the lifetime of SCEVExpander. SCEVExpander Expander(*SE, DL, "reroll"); Value *NewIV = Expander.expandCodeFor(H, IV->getType(), Header->begin()); @@ -1294,8 +1304,8 @@ const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE); // Iteration count SCEV minus 1 - const SCEV *ICMinus1SCEV = - SE->getMinusSCEV(ICSCEV, SE->getConstant(ICSCEV->getType(), 1)); + const SCEV *ICMinus1SCEV = SE->getMinusSCEV( + ICSCEV, SE->getConstant(ICSCEV->getType(), Negative ? -1 : 1)); Value *ICMinus1; // Iteration count minus 1 if (isa(ICMinus1SCEV)) { @@ -1444,13 +1454,13 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *IterCount, ReductionTracker &Reductions) { - DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI); + DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, IVToIncMap); if (!DAGRoots.findRoots()) return false; DEBUG(dbgs() << "LRR: Found all root induction increments for: " << *IV << "\n"); - + if (!DAGRoots.validate(Reductions)) return false; if (!Reductions.validateSelected()) @@ -1497,6 +1507,7 @@ // First, we need to find the induction variable with respect to which we can // reroll (there may be several possible options). SmallInstructionVector PossibleIVs; + IVToIncMap.clear(); collectPossibleIVs(L, PossibleIVs); if (PossibleIVs.empty()) { Index: test/Transforms/LoopReroll/negative.ll =================================================================== --- /dev/null +++ test/Transforms/LoopReroll/negative.ll @@ -0,0 +1,56 @@ +; RUN: opt -S -loop-reroll %s | FileCheck %s +target triple = "aarch64--linux-gnu" +@buf = global [16 x i8] c"\0A\0A\0A\0A\0A\0A\0A\0A\0A\0A\0A\0A\0A\0A\0A\0A", align 1 +declare i32 @goo(i32) + +define i32 @test1(i32 %len, i8* nocapture readonly %buf) #0 { +entry: + %cmp.13 = icmp sgt i32 %len, 1 + br i1 %cmp.13, label %while.body.lr.ph, label %while.end + +while.body.lr.ph: ; preds = %entry + br label %while.body + +while.body: +;CHECK-LABEL: while.body: +;CHECK-NEXT: %indvar = phi i32 [ %indvar.next, %while.body ], [ 0, %while.body.lr.ph ] +;CHECK-NEXT: %sum4.015 = phi i64 [ 0, %while.body.lr.ph ], [ %add, %while.body ] +;CHECK-NEXT: %5 = mul i32 %indvar, -1 +;CHECK-NEXT: %6 = add i32 %len, %5 +;CHECK-NEXT: %idxprom = sext i32 %6 to i64 +;CHECK-NEXT: %arrayidx = getelementptr inbounds i8, i8* %buf, i64 %idxprom +;CHECK-NEXT: %7 = load i8, i8* %arrayidx, align 1 +;CHECK-NEXT: %conv = zext i8 %7 to i64 +;CHECK-NEXT: %add = add i64 %conv, %sum4.015 +;CHECK-NEXT: %indvar.next = add i32 %indvar, 1 +;CHECK-NEXT: %exitcond = icmp eq i32 %6, %4 +;CHECK-NEXT: br i1 %exitcond, label %while.cond.while.end_crit_edge, label %while.body + + %sum4.015 = phi i64 [ 0, %while.body.lr.ph ], [ %add4, %while.body ] + %len.addr.014 = phi i32 [ %len, %while.body.lr.ph ], [ %sub5, %while.body ] + %idxprom = sext i32 %len.addr.014 to i64 + %arrayidx = getelementptr inbounds i8, i8* %buf, i64 %idxprom + %0 = load i8, i8* %arrayidx, align 1 + %conv = zext i8 %0 to i64 + %add = add i64 %conv, %sum4.015 + %sub = add nsw i32 %len.addr.014, -1 + %idxprom1 = sext i32 %sub to i64 + %arrayidx2 = getelementptr inbounds i8, i8* %buf, i64 %idxprom1 + %1 = load i8, i8* %arrayidx2, align 1 + %conv3 = zext i8 %1 to i64 + %add4 = add i64 %add, %conv3 + %sub5 = add nsw i32 %len.addr.014, -2 + %cmp = icmp sgt i32 %sub5, 1 + br i1 %cmp, label %while.body, label %while.cond.while.end_crit_edge + +while.cond.while.end_crit_edge: ; preds = %while.body + %add4.lcssa = phi i64 [ %add4, %while.body ] + %phitmp = trunc i64 %add4.lcssa to i32 + br label %while.end + +while.end: ; preds = %while.cond.while.end_crit_edge, %entry + %sum4.0.lcssa = phi i32 [ %phitmp, %while.cond.while.end_crit_edge ], [ 0, %entry ] + %call = tail call i32 @goo(i32 %sum4.0.lcssa) + unreachable +} +