Index: lib/Transforms/Scalar/LoopRerollPass.cpp
===================================================================
--- lib/Transforms/Scalar/LoopRerollPass.cpp
+++ lib/Transforms/Scalar/LoopRerollPass.cpp
@@ -394,6 +394,7 @@
       bool instrDependsOn(Instruction *I,
                           UsesTy::iterator Start,
                           UsesTy::iterator End);
+      void replaceIV(Instruction *Inst, Instruction *IV, const SCEV *IterCount);
 
       LoopReroll *Parent;
 
@@ -455,6 +456,58 @@
   return false;
 }
 
+// Express the step recurrence \p SCEVExpr of the pointer induction variable
+// \p IV as a number of pointee-sized elements. Returns that element count as
+// a SCEVConstant, or nullptr when the step has neither of the two shapes
+// handled below.
+// NOTE(review): the constant case divides with getUDivExpr, so a step that is
+// not an exact multiple of the element size is truncated -- confirm callers
+// cannot reach this with such a step.
+static const SCEVConstant *getIncremntFactorSCEV(ScalarEvolution *SE,
+                                                 const SCEV *SCEVExpr,
+                                                 Instruction *IV) {
+  const SCEVMulExpr *MulSCEV = dyn_cast<SCEVMulExpr>(SCEVExpr);
+
+  // If StepRecurrence of a SCEVExpr is a constant (c1 * c2, c2 = sizeof(ptr)),
+  // Return c1.
+  if (!MulSCEV && IV->getType()->isPointerTy())
+    if (const SCEVConstant *IncSCEV = dyn_cast<SCEVConstant>(SCEVExpr)) {
+      const PointerType *PTy = cast<PointerType>(IV->getType());
+      Type *ElTy = PTy->getElementType();
+      const SCEV *SizeOfExpr =
+          SE->getSizeOfExpr(SE->getEffectiveSCEVType(IV->getType()), ElTy);
+      if (IncSCEV->getValue()->getValue().isNegative()) {
+        // UDiv is unsigned: divide the magnitude, then restore the sign.
+        const SCEV *NewSCEV =
+            SE->getUDivExpr(SE->getNegativeSCEV(SCEVExpr), SizeOfExpr);
+        return dyn_cast<SCEVConstant>(SE->getNegativeSCEV(NewSCEV));
+      } else {
+        return dyn_cast<SCEVConstant>(SE->getUDivExpr(SCEVExpr, SizeOfExpr));
+      }
+    }
+
+  if (!MulSCEV)
+    return nullptr;
+
+  // If StepRecurrence of a SCEVExpr is a c * sizeof(x), where c is constant,
+  // Return c.
+  const SCEVConstant *CIncSCEV = nullptr;
+  for (const SCEV *Operand : MulSCEV->operands()) {
+    if (const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Operand)) {
+      CIncSCEV = Constant;
+    } else if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Operand)) {
+      Type *AllocTy;
+      // Any unknown factor other than sizeof(x) means the step is not
+      // c * sizeof(x); reject it rather than returning a constant seen so far.
+      if (!Unknown->isSizeOf(AllocTy))
+        return nullptr;
+    } else {
+      return nullptr;
+    }
+  }
+  return CIncSCEV;
+}
+
 // Collect the list of loop induction variables with respect to which it might
 // be possible to reroll the loop.
void LoopReroll::collectPossibleIVs(Loop *L,
@@ -464,7 +507,7 @@
        IE = Header->getFirstInsertionPt(); I != IE; ++I) {
     if (!isa<PHINode>(I))
       continue;
-    if (!I->getType()->isIntegerTy())
+    if (!I->getType()->isIntegerTy() && !I->getType()->isPointerTy())
       continue;
 
     if (const SCEVAddRecExpr *PHISCEV =
@@ -473,8 +516,13 @@
         continue;
       if (!PHISCEV->isAffine())
         continue;
-      if (const SCEVConstant *IncSCEV =
-          dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE))) {
+      // A pointer IV's step is first converted to a count of pointee elements.
+      const SCEVConstant *IncSCEV = nullptr;
+      if (I->getType()->isPointerTy())
+        IncSCEV = getIncremntFactorSCEV(SE, PHISCEV->getStepRecurrence(*SE), I);
+      else
+        IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE));
+      if (IncSCEV) {
         const APInt &AInt = IncSCEV->getValue()->getValue().abs();
         if (IncSCEV->getValue()->isZero() || AInt.uge(MaxInc))
           continue;
@@ -646,10 +694,12 @@
 
 static bool isLoopIncrement(User *U, Instruction *IV) {
   BinaryOperator *BO = dyn_cast<BinaryOperator>(U);
-  if (!BO || BO->getOpcode() != Instruction::Add)
+  // An increment is either an Add or a GEP (how pointer IVs advance).
+  if ((BO && BO->getOpcode() != Instruction::Add) ||
+      (!BO && !isa<GetElementPtrInst>(U)))
     return false;
 
-  for (auto *UU : BO->users()) {
+  for (auto *UU : U->users()) {
     PHINode *PN = dyn_cast<PHINode>(UU);
     if (PN && PN == IV)
       return true;
@@ -1267,61 +1317,92 @@
 
     ++J;
   }
-  bool Negative = IVToIncMap[IV] < 0;
-  const DataLayout &DL = Header->getModule()->getDataLayout();
 
   // We need to create a new induction variable for each different BaseInst.
-  for (auto &DRS : RootSets) {
+  for (auto &DRS : RootSets)
     // Insert the new induction variable.
-    const SCEVAddRecExpr *RealIVSCEV =
-        cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst));
-    const SCEV *Start = RealIVSCEV->getStart();
-    const SCEVAddRecExpr *H = cast<SCEVAddRecExpr>(SE->getAddRecExpr(
-        Start, SE->getConstant(RealIVSCEV->getType(), Negative ? -1 : 1), L,
-        SCEV::FlagAnyWrap));
-    { // Limit the lifetime of SCEVExpander.
-      SCEVExpander Expander(*SE, DL, "reroll");
-      Value *NewIV = Expander.expandCodeFor(H, IV->getType(), &Header->front());
-
-      for (auto &KV : Uses) {
-        if (KV.second.find_first() == 0)
-          KV.first->replaceUsesOfWith(DRS.BaseInst, NewIV);
-      }
-
-      if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) {
-        // FIXME: Why do we need this check?
-        if (Uses[BI].find_first() == IL_All) {
-          const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE);
+    replaceIV(DRS.BaseInst, IV, IterCount);
 
-          // Iteration count SCEV minus 1
-          const SCEV *ICMinus1SCEV = SE->getMinusSCEV(
-              ICSCEV, SE->getConstant(ICSCEV->getType(), Negative ? -1 : 1));
+  SimplifyInstructionsInBlock(Header, TLI);
+  DeleteDeadPHIs(Header, TLI);
+}
 
-          Value *ICMinus1; // Iteration count minus 1
-          if (isa<SCEVConstant>(ICMinus1SCEV)) {
-            ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(), BI);
-          } else {
-            BasicBlock *Preheader = L->getLoopPreheader();
-            if (!Preheader)
-              Preheader = InsertPreheaderForLoop(L, Parent);
+// Replace \p Inst, a base instruction of the rerolled loop, with a freshly
+// expanded induction variable that starts at Inst's SCEV start value and
+// steps by +1 or -1 per iteration (the sign of IVToIncMap[InstIV]); when
+// \p Inst is a pointer the step is multiplied by the pointee size. Users whose
+// Uses bit-set has bit 0 as its first set bit are redirected to the new IV.
+// If the header terminator is a branch whose first set Uses bit is IL_All, its
+// condition becomes an equality test of the new IV against Inst's SCEV
+// evaluated at \p IterCount, less one step.
+void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst,
+                                           Instruction *InstIV,
+                                           const SCEV *IterCount) {
+  BasicBlock *Header = L->getHeader();
+  int64_t Inc = IVToIncMap[InstIV];
+  bool Negative = Inc < 0;
+
+  const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(Inst));
+  const SCEV *Start = RealIVSCEV->getStart();
+
+  const SCEV *SizeOfExpr = nullptr;
+  const SCEV *IncrExpr =
+      SE->getConstant(RealIVSCEV->getType(), Negative ? -1 : 1);
+  if (auto *PTy = dyn_cast<PointerType>(Inst->getType())) {
+    Type *ElTy = PTy->getElementType();
+    SizeOfExpr =
+        SE->getSizeOfExpr(SE->getEffectiveSCEVType(Inst->getType()), ElTy);
+    IncrExpr = SE->getMulExpr(IncrExpr, SizeOfExpr);
+  }
+  const SCEV *NewIVSCEV =
+      SE->getAddRecExpr(Start, IncrExpr, L, SCEV::FlagAnyWrap);
+
+  { // Limit the lifetime of SCEVExpander.
+    const DataLayout &DL = Header->getModule()->getDataLayout();
+    SCEVExpander Expander(*SE, DL, "reroll");
+    Value *NewIV =
+        Expander.expandCodeFor(NewIVSCEV, InstIV->getType(), &Header->front());
+
+    for (auto &KV : Uses)
+      if (KV.second.find_first() == 0)
+        KV.first->replaceUsesOfWith(Inst, NewIV);
+
+    if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) {
+      // FIXME: Why do we need this check?
+      if (Uses[BI].find_first() == IL_All) {
+        const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE);
+
+        // Iteration count SCEV minus or plus 1
+        const SCEV *MinusPlus1SCEV =
+            SE->getConstant(ICSCEV->getType(), Negative ? -1 : 1);
+        if (Inst->getType()->isPointerTy()) {
+          assert(SizeOfExpr && " SizeOfExpr is not initialized");
+          MinusPlus1SCEV = SE->getMulExpr(MinusPlus1SCEV, SizeOfExpr);
+        }
 
-            ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(),
-                                              Preheader->getTerminator());
-          }
+        const SCEV *ICMinusPlus1SCEV = SE->getMinusSCEV(ICSCEV, MinusPlus1SCEV);
+        // Iteration count minus or plus 1
+        Value *ICMinusPlus1 = nullptr;
+        if (isa<SCEVConstant>(ICMinusPlus1SCEV)) {
+          ICMinusPlus1 =
+              Expander.expandCodeFor(ICMinusPlus1SCEV, NewIV->getType(), BI);
+        } else {
+          BasicBlock *Preheader = L->getLoopPreheader();
+          if (!Preheader)
+            Preheader = InsertPreheaderForLoop(L, Parent);
+          ICMinusPlus1 = Expander.expandCodeFor(
+              ICMinusPlus1SCEV, NewIV->getType(), Preheader->getTerminator());
+        }
 
-          Value *Cond =
-              new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinus1, "exitcond");
-          BI->setCondition(Cond);
+        Value *Cond =
+            new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinusPlus1, "exitcond");
+        BI->setCondition(Cond);
 
-          if (BI->getSuccessor(1) != Header)
-            BI->swapSuccessors();
-        }
+        if (BI->getSuccessor(1) != Header)
+          BI->swapSuccessors();
       }
     }
   }
-
-  SimplifyInstructionsInBlock(Header, TLI);
-  DeleteDeadPHIs(Header, TLI);
 }
 
 // Validate the selected reductions.
All iterations must have an isomorphic
Index: test/Transforms/LoopReroll/ptrindvar.ll
===================================================================
--- /dev/null
+++ test/Transforms/LoopReroll/ptrindvar.ll
@@ -0,0 +1,86 @@
+; RUN: opt -S -loop-reroll %s | FileCheck %s
+; Loops whose induction variable is a pointer (an i32* phi) rather than an
+; integer; the FileCheck directives give the loop body expected after the pass.
+target triple = "aarch64--linux-gnu"
+
+; The pointer advances by two i32 elements per iteration and the body loads
+; elements 0 and 1; the expected body has one load and an %indvar step of 1.
+define i32 @test(i32* readonly %buf, i32* readnone %end) #0 {
+entry:
+  %cmp.9 = icmp eq i32* %buf, %end
+  br i1 %cmp.9, label %while.end, label %while.body.preheader
+
+while.body.preheader:
+  br label %while.body
+
+while.body:
+;CHECK-LABEL: while.body:
+;CHECK-NEXT:    %indvar = phi i64 [ %indvar.next, %while.body ], [ 0, %while.body.preheader ]
+;CHECK-NEXT:    %S.011 = phi i32 [ %add, %while.body ], [ undef, %while.body.preheader ]
+;CHECK-NEXT:    %scevgep = getelementptr i32, i32* %buf, i64 %indvar
+;CHECK-NEXT:    %4 = load i32, i32* %scevgep, align 4
+;CHECK-NEXT:    %add = add nsw i32 %4, %S.011
+;CHECK-NEXT:    %indvar.next = add i64 %indvar, 1
+;CHECK-NEXT:    %exitcond = icmp eq i32* %scevgep, %scevgep5
+;CHECK-NEXT:    br i1 %exitcond, label %while.end.loopexit, label %while.body
+
+  %S.011 = phi i32 [ %add2, %while.body ], [ undef, %while.body.preheader ]
+  %buf.addr.010 = phi i32* [ %add.ptr, %while.body ], [ %buf, %while.body.preheader ]
+  %0 = load i32, i32* %buf.addr.010, align 4
+  %add = add nsw i32 %0, %S.011
+  %arrayidx1 = getelementptr inbounds i32, i32* %buf.addr.010, i64 1
+  %1 = load i32, i32* %arrayidx1, align 4
+  %add2 = add nsw i32 %add, %1
+  %add.ptr = getelementptr inbounds i32, i32* %buf.addr.010, i64 2
+  %cmp = icmp eq i32* %add.ptr, %end
+  br i1 %cmp, label %while.end.loopexit, label %while.body
+
+while.end.loopexit:
+  %add2.lcssa = phi i32 [ %add2, %while.body ]
+  br label %while.end
+
+while.end:
+  %S.0.lcssa = phi i32 [ undef, %entry ], [ %add2.lcssa, %while.end.loopexit ]
+  ret i32 %S.0.lcssa
+}
+
+; Same shape walking downwards: the body loads elements 0 and -1 and the
+; pointer advances by -2; the expected body has an %indvar step of -1.
+define i32 @test2(i32* readonly %buf, i32* readnone %end) #0 {
+entry:
+  %cmp.9 = icmp eq i32* %buf, %end
+  br i1 %cmp.9, label %while.end, label %while.body.preheader
+
+while.body.preheader:
+  br label %while.body
+
+while.body:
+;CHECK-LABEL: while.body:
+;CHECK-NEXT:    %indvar = phi i64 [ %indvar.next, %while.body ], [ 0, %while.body.preheader ]
+;CHECK-NEXT:    %S.011 = phi i32 [ %add, %while.body ], [ undef, %while.body.preheader ]
+;CHECK-NEXT:    %scevgep = getelementptr i32, i32* %buf, i64 %indvar
+;CHECK-NEXT:    %4 = load i32, i32* %scevgep, align 4
+;CHECK-NEXT:    %add = add nsw i32 %4, %S.011
+;CHECK-NEXT:    %indvar.next = add i64 %indvar, -1
+;CHECK-NEXT:    %exitcond = icmp eq i32* %scevgep, %scevgep5
+;CHECK-NEXT:    br i1 %exitcond, label %while.end.loopexit, label %while.body
+
+  %S.011 = phi i32 [ %add2, %while.body ], [ undef, %while.body.preheader ]
+  %buf.addr.010 = phi i32* [ %add.ptr, %while.body ], [ %buf, %while.body.preheader ]
+  %0 = load i32, i32* %buf.addr.010, align 4
+  %add = add nsw i32 %0, %S.011
+  %arrayidx1 = getelementptr inbounds i32, i32* %buf.addr.010, i64 -1
+  %1 = load i32, i32* %arrayidx1, align 4
+  %add2 = add nsw i32 %add, %1
+  %add.ptr = getelementptr inbounds i32, i32* %buf.addr.010, i64 -2
+  %cmp = icmp eq i32* %add.ptr, %end
+  br i1 %cmp, label %while.end.loopexit, label %while.body
+
+while.end.loopexit:
+  %add2.lcssa = phi i32 [ %add2, %while.body ]
+  br label %while.end
+
+while.end:
+  %S.0.lcssa = phi i32 [ undef, %entry ], [ %add2.lcssa, %while.end.loopexit ]
+  ret i32 %S.0.lcssa
+}