Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10164,10 +10164,30 @@
           if (IPred->getLHS() == Expr)
             return IPred->getRHS();
     }
-
+    if (dyn_cast<PHINode>(Expr->getValue())) {
+      if (const SCEV *AR = createAddRecFromPHIWithCasts(Expr))
+        return AR;
+    }
     return Expr;
   }
 
+  const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
+    const SCEV *Operand = visit(Expr->getOperand());
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
+    if (AR && AR->getLoop() == L && AR->isAffine()) {
+      // Check (by adding a runtime overflow assumption) that the AR will not
+      // overflow the truncated type.
+      const SCEV *Step = AR->getStepRecurrence(SE);
+      Type *Ty = Expr->getType();
+      const SCEV *TruncAR = SE.getAddRecExpr(
+          SE.getTruncateExpr(AR->getStart(), Ty), SE.getTruncateExpr(Step, Ty),
+          L, AR->getNoWrapFlags());
+      if (addOverflowAssumption(cast<SCEVAddRecExpr>(TruncAR),
+                                SCEVWrapPredicate::IncrementNUSW))
+        return TruncAR;
+    }
+    return SE.getTruncateExpr(Operand, Expr->getType());
+  }
+
   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
     const SCEV *Operand = visit(Expr->getOperand());
     const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
@@ -10201,6 +10221,202 @@
   }
 
 private:
+
+  // This is a helper function of createAddRecFromPHIWithCasts. We have a phi
+  // node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
+  // the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
+  // way. This function examines \p Op, an operand of this SCEVAddExpr,
+  // and checks whether it follows one of the following patterns:
+  //   Op == (Sext ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+  //   Op == (Zext ix (Trunc iy (%SymbolicPHI) to ix) to iy)
+  // If the SCEV expression of \p Op conforms to one of the expected patterns
+  // we return the type of the truncation operation, and indicate whether the
+  // truncated type should be treated as signed/unsigned by setting
+  // \p Signed to true/false, respectively.
+  Type *isSimpleCastedPHI(const SCEV *Op, const SCEV *SymbolicPHI,
+                          bool &Signed) {
+    const SCEVSignExtendExpr *Sext = dyn_cast<SCEVSignExtendExpr>(Op);
+    const SCEVZeroExtendExpr *Zext = dyn_cast<SCEVZeroExtendExpr>(Op);
+    if (!Sext && !Zext)
+      return nullptr;
+    const SCEVTruncateExpr *Trunc =
+        Sext ? dyn_cast<SCEVTruncateExpr>(Sext->getOperand())
+             : dyn_cast<SCEVTruncateExpr>(Zext->getOperand());
+    if (!Trunc)
+      return nullptr;
+    const SCEV *X = Trunc->getOperand();
+    if (X != SymbolicPHI)
+      return nullptr;
+    unsigned SourceBits = SE.getTypeSizeInBits(X->getType());
+    unsigned NewBits = Sext ? SE.getTypeSizeInBits(Sext->getType())
+                            : SE.getTypeSizeInBits(Zext->getType());
+    if (SourceBits != NewBits)
+      return nullptr;
+    Signed = Sext ? true : false;
+    return Trunc->getType();
+  }
+
+  // Similar to createAddRecFromPHI, but with the additional flexibility of
+  // adding runtime overflow checks in case casts are encountered.
+  //
+  // Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
+  // computation that updates the phi follows one of the following patterns:
+  //   (Sext ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
+  //   (Zext ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
+  // which correspond to a phi->trunc->sext/zext->add->phi update chain.
+  //
+  // Example Flow:
+  //   Say the rewriter is called for the following SCEV:
+  //         8 * ((sext i32 (trunc i64 %x to i32) to i64) + %step)
+  //   We visitMul->visitAdd->visitSext->visitTrunc->visitUnknown(%x),
+  //   and call this function with %SymbolicPHI = %x.
+  //
+  //   The analysis will find that the value coming around the backedge has
+  //   the following SCEV:
+  //         BEValue = ((sext i32 (trunc i64 %x to i32) to i64) + %step)
+  //   And we'll create the following predicate:
+  //         P1: AR: {0,+,(trunc i64 %step to i32)} Added Flags: <nssw>
+  //   and return the following new AddRec to our caller, visitUnknown:
+  //         Returned AddRec: {0,+,%step}
+  //
+  // TODO:
+  // 1) [CHECKME] Cache analysis results within a single call to the rewriter
+  //    to avoid redundant predicates:
+  //    Consider the example flow above. While going back through the visit
+  //    call chain, addOverflowAssumption will get called two additional times:
+  //    First, from visitTrunc:
+  //         P2: AR: {0,+,(trunc i64 %step to i32)} Added Flags: <nusw>
+  //    And then from visitSext:
+  //         P3: AR: {0,+,(trunc i64 %step to i32)} Added Flags: <nssw>
+  //    The last predicate (P3) is implied by the first (P1), so it will not
+  //    get added, but even P2 should be redundant.
+  //
+  // 2) Cache analysis results across calls to the rewriter to reduce compile
+  //    time: In future calls to the rewriter it is better not to repeat the
+  //    call to this function. Going back to the example above, we want
+  //    %x to be associated with the predicate P1, so that in the future we
+  //    could directly return the AddRec {0,+,%step} without repeating this
+  //    analysis. For this we need to slightly extend the WrapPredicate to also
+  //    hold a non-AddRec LHS member (%x in this case) that can be rewritten
+  //    to the AR member. Alternatively, we could extend the EqualPredicate to
+  //    support a non-constant RHS ({0,+,%step} in this case).
+  //
+  // 3) Extend the Induction descriptor to also support inductions that involve
+  //    casts: When needed (namely, when we are called in the context of the
+  //    vectorizer induction analysis), a Set of cast instructions will be
+  //    populated by this method, and provided back to isInductionPHI. This is
+  //    needed to allow the vectorizer to properly record them to be ignored by
+  //    the cost model and to avoid vectorizing them (otherwise these casts,
+  //    which are redundant under the runtime overflow checks, will be
+  //    vectorized, which can be costly).
+  //
+  // 4) Support additional induction/PHISCEV patterns: We also want to support
+  //    inductions where the sext-trunc / zext-trunc operations (partly) occur
+  //    after the induction update operation (the induction increment):
+  //
+  //      (Trunc iy (Sext ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
+  //      (Trunc iy (Zext ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
+  //    which correspond to a phi->add->trunc->sext/zext->phi update chain.
+  //
+  //      (Trunc iy ((Sext ix (%SymbolicPHI) to iy) + InvariantAccum) to ix)
+  //      (Trunc iy ((Zext ix (%SymbolicPHI) to iy) + InvariantAccum) to ix)
+  //    which correspond to a phi->trunc->add->sext/zext->phi update chain.
+  //
+  const SCEV *createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
+    auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+    if (L->getHeader() != PN->getParent())
+      return nullptr;
+
+    // The loop may have multiple entrances or multiple exits; we can analyze
+    // this phi as an addrec if it has a unique entry value and a unique
+    // backedge value.
+    Value *BEValueV = nullptr, *StartValueV = nullptr;
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+      Value *V = PN->getIncomingValue(i);
+      if (L->contains(PN->getIncomingBlock(i))) {
+        if (!BEValueV) {
+          BEValueV = V;
+        } else if (BEValueV != V) {
+          BEValueV = nullptr;
+          break;
+        }
+      } else if (!StartValueV) {
+        StartValueV = V;
+      } else if (StartValueV != V) {
+        StartValueV = nullptr;
+        break;
+      }
+    }
+    if (!BEValueV || !StartValueV)
+      return nullptr;
+
+    const SCEV *BEValue = SE.getSCEV(BEValueV);
+
+    // If the value coming around the backedge is an add with the symbolic
+    // value we just inserted, possibly with casts that we can ignore under
+    // an appropriate runtime guard, then we found a simple induction variable!
+    const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
+    if (!Add)
+      return nullptr;
+
+    // If there is a single occurrence of the symbolic value, possibly
+    // casted, replace it with a recurrence.
+    unsigned FoundIndex = Add->getNumOperands();
+    Type *TruncTy = nullptr;
+    bool Signed;
+    for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+      if (Add->getOperand(i) == SymbolicPHI ||
+          (TruncTy =
+               isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed)))
+        if (FoundIndex == e) {
+          FoundIndex = i;
+          break;
+        }
+
+    if (FoundIndex == Add->getNumOperands())
+      return nullptr;
+
+    // Create an add with everything but the specified operand.
+    SmallVector<const SCEV *, 8> Ops;
+    for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+      if (i != FoundIndex)
+        Ops.push_back(Add->getOperand(i));
+    const SCEV *Accum = SE.getAddExpr(Ops);
+
+    // This is not a valid addrec if the step amount is varying each
+    // loop iteration, but is not itself an addrec in this loop.
+    if (!SE.isLoopInvariant(Accum, L) &&
+        !(isa<SCEVAddRecExpr>(Accum) &&
+          cast<SCEVAddRecExpr>(Accum)->getLoop() == L))
+      return nullptr;
+
+    // Create a truncated addrec for which we will add a no overflow check.
+    const SCEV *StartVal = SE.getSCEV(StartValueV);
+    const SCEV *PHISCEV =
+        SE.getAddRecExpr(SE.getTruncateExpr(StartVal, TruncTy),
+                         SE.getTruncateExpr(Accum, TruncTy), L, SCEV::FlagNSW);
+    const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV);
+    SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
+        Signed ? SCEVWrapPredicate::IncrementNSSW
+               : SCEVWrapPredicate::IncrementNUSW;
+
+    // Now go ahead and try to create the overflow check, and if successful,
+    // return a new addrec in which the casts have been folded away.
+    //
+    // Note: It seems that we could do something better here, and return the
+    // above truncated addrec, with a narrower type than the one we return
+    // below. But then the type of the addrec would not match the type of the
+    // phi node, which would break some assumptions about the induction
+    // variable later on in the vectorizer, so that would require more changes.
+    // Also, we reach this point from a visitUnknown called by a visitTruncate;
+    // returning an already truncated addrec to visitTruncate would require
+    // also changing visitTruncate, since it expects to get the (wider) type
+    // before truncation.
+    if (AR && addOverflowAssumption(AR, AddedFlags))
+      return SE.getAddRecExpr(StartVal, Accum, L, SCEV::FlagNSW);
+
+    return nullptr;
+  }
+
   bool addOverflowAssumption(const SCEVAddRecExpr *AR,
                              SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
     auto *A = SE.getWrapPredicate(AR, AddedFlags);
Index: llvm/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll
+++ llvm/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll
@@ -0,0 +1,107 @@
+; RUN: opt -S -loop-vectorize -force-vector-width=4 -force-vector-interleave=1 < %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+; Check that the vectorizer identifies the %p.09 phi as an induction variable,
+; despite the potential overflow due to the truncation from 32 bits to 8 bits.
+; SCEV will detect the pattern "sext(trunc(%p.09)) + %step" and generate the
+; required runtime overflow check under which we can assume no overflow.
+; See pr30654.
+;
+; int a[N];
+; void doit1(int n, int step) {
+;   int i;
+;   char p = 0;
+;   for (i = 0; i < n; i++) {
+;     a[i] = p;
+;     p = p + step;
+;   }
+; }
+;
+
+; CHECK-LABEL: @doit1
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+@a = common local_unnamed_addr global [250 x i32] zeroinitializer, align 16
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit1(i32 %n, i32 %step) local_unnamed_addr {
+entry:
+  %cmp7 = icmp sgt i32 %n, 0
+  br i1 %cmp7, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+  %wide.trip.count = zext i32 %n to i64
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %p.09 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+  %sext = shl i32 %p.09, 24
+  %conv = ashr exact i32 %sext, 24
+  %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+  store i32 %conv, i32* %arrayidx, align 4
+  %add = add nsw i32 %conv, %step
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; Same as above, but checking the SCEV "zext(trunc(%p.09)) + %step":
+;
+; int a[N];
+; void doit2(int n, int step) {
+;   int i;
+;   unsigned char p = 0;
+;   for (i = 0; i < n; i++) {
+;     a[i] = p;
+;     p = p + step;
+;   }
+; }
+;
+
+; CHECK-LABEL: @doit2
+; CHECK: vector.scevcheck
+; CHECK: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}})
+; CHECK: vector.body:
+; CHECK: <4 x i32>
+
+; Function Attrs: norecurse nounwind uwtable
+define void @doit2(i32 %n, i32 %step) local_unnamed_addr {
+entry:
+  %cmp7 = icmp sgt i32 %n, 0
+  br i1 %cmp7, label %for.body.preheader, label %for.end
+
+for.body.preheader:
+  %wide.trip.count = zext i32 %n to i64
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %p.09 = phi i32 [ %add, %for.body ], [ 0, %for.body.preheader ]
+  %conv = and i32 %p.09, 255
+  %arrayidx = getelementptr inbounds [250 x i32], [250 x i32]* @a, i64 0, i64 %indvars.iv
+  store i32 %conv, i32* %arrayidx, align 4
+  %add = add nsw i32 %conv, %step
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
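+
+; For reference, a rough hand-written sketch (not actual vectorizer output;
+; the value names below are illustrative only) of the kind of runtime check
+; the CHECK lines above expect in the vector.scevcheck block. The wrap
+; predicate on the truncated addrec (for @doit1, roughly
+; {0,+,(trunc i32 %step to i8)}<nssw>) is materialized by multiplying the
+; truncated step with the truncated trip count via @llvm.umul.with.overflow.i8
+; and testing the overflow bit; if the combined check fails, the vector loop
+; is skipped and the original scalar loop runs instead:
+;
+;   %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %trunc.step, i8 %trunc.tc)
+;   %mul.result = extractvalue { i8, i1 } %mul, 0
+;   %mul.overflow = extractvalue { i8, i1 } %mul, 1
+;   ; ... additional add/sub wrap conditions or'ed into one flag ...
+;   br i1 %scev.check.failed, label %scalar.ph, label %vector.ph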