Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -3938,6 +3938,46 @@ return V->getType() < V2->getType(); } +/// \brief Try and get a reduction value from a phi node. +/// +/// Given a phi node \p P in a block \p ParentBB, consider possible reductions +/// if they come from either \p ParentBB or a containing loop latch. +/// +/// \returns A candidate reduction value if possible, or \code nullptr \endcode +/// if not possible. +static Value *getReductionValue(PHINode *P, BasicBlock *ParentBB, + LoopInfo *LI) { + Value *Rdx = nullptr; + + // Return the incoming value if it comes from the same BB as the phi node. + if (P->getIncomingBlock(0) == ParentBB) { + Rdx = P->getIncomingValue(0); + } else if (P->getIncomingBlock(1) == ParentBB) { + Rdx = P->getIncomingValue(1); + } + + if (Rdx) + return Rdx; + + // Otherwise, check whether we have a loop latch to look at. + Loop *BBL = LI->getLoopFor(ParentBB); + if (!BBL) + return Rdx; + BasicBlock *BBLatch = BBL->getLoopLatch(); + if (!BBLatch) + return Rdx; + + // There is a loop latch, return the incoming value if it comes from + // that. This reduction pattern occasionally turns up. + if (P->getIncomingBlock(0) == BBLatch) { + Rdx = P->getIncomingValue(0); + } else if (P->getIncomingBlock(1) == BBLatch) { + Rdx = P->getIncomingValue(1); + } + + return Rdx; +} + bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { bool Changed = false; SmallVector Incoming; @@ -4005,11 +4045,9 @@ // Check that the PHI is a reduction PHI. if (P->getNumIncomingValues() != 2) return Changed; - Value *Rdx = - (P->getIncomingBlock(0) == BB - ? (P->getIncomingValue(0)) - : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) - : nullptr)); + + Value *Rdx = getReductionValue(P, BB, LI); + // Check if this is a Binary Operator. 
BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx); if (!BI) Index: test/Transforms/SLPVectorizer/AArch64/horizontal.ll =================================================================== --- test/Transforms/SLPVectorizer/AArch64/horizontal.ll +++ test/Transforms/SLPVectorizer/AArch64/horizontal.ll @@ -71,3 +71,77 @@ %s.0.lcssa = phi i32 [ 0, %entry ], [ %add27, %for.end.loopexit ] ret i32 %s.0.lcssa } + +;; Check whether SLP can find a reduction phi whose incoming blocks are not +;; the same as the block containing the phi. +;; +;; Came from code like, +;; +;; int s = 0; +;; for (int j = 0; j < h; j++) { +;; s += p1[0] * p2[0]; +;; s += p1[1] * p2[1]; +;; s += p1[2] * p2[2]; +;; s += p1[3] * p2[3]; +;; if (s >= lim) +;; break; +;; p1 += lx; +;; p2 += lx; +;; } +define i32 @reduction_with_br(i32* noalias nocapture readonly %blk1, i32* noalias nocapture readonly %blk2, i32 %lx, i32 %h, i32 %lim) { +; CHECK-LABEL: reduction_with_br +; CHECK: load <4 x i32> +; CHECK: load <4 x i32> +; CHECK: mul nsw <4 x i32> +entry: + %cmp.16 = icmp sgt i32 %h, 0 + br i1 %cmp.16, label %for.body.lr.ph, label %for.end + +for.body.lr.ph: ; preds = %entry + %idx.ext = sext i32 %lx to i64 + br label %for.body + +for.body: ; preds = %for.body.lr.ph, %if.end + %s.020 = phi i32 [ 0, %for.body.lr.ph ], [ %add13, %if.end ] + %j.019 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %if.end ] + %p2.018 = phi i32* [ %blk2, %for.body.lr.ph ], [ %add.ptr16, %if.end ] + %p1.017 = phi i32* [ %blk1, %for.body.lr.ph ], [ %add.ptr, %if.end ] + %0 = load i32, i32* %p1.017, align 4 + %1 = load i32, i32* %p2.018, align 4 + %mul = mul nsw i32 %1, %0 + %add = add nsw i32 %mul, %s.020 + %arrayidx2 = getelementptr inbounds i32, i32* %p1.017, i64 1 + %2 = load i32, i32* %arrayidx2, align 4 + %arrayidx3 = getelementptr inbounds i32, i32* %p2.018, i64 1 + %3 = load i32, i32* %arrayidx3, align 4 + %mul4 = mul nsw i32 %3, %2 + %add5 = add nsw i32 %add, %mul4 + %arrayidx6 = getelementptr inbounds i32, i32* %p1.017, i64 2 + %4 = 
load i32, i32* %arrayidx6, align 4 + %arrayidx7 = getelementptr inbounds i32, i32* %p2.018, i64 2 + %5 = load i32, i32* %arrayidx7, align 4 + %mul8 = mul nsw i32 %5, %4 + %add9 = add nsw i32 %add5, %mul8 + %arrayidx10 = getelementptr inbounds i32, i32* %p1.017, i64 3 + %6 = load i32, i32* %arrayidx10, align 4 + %arrayidx11 = getelementptr inbounds i32, i32* %p2.018, i64 3 + %7 = load i32, i32* %arrayidx11, align 4 + %mul12 = mul nsw i32 %7, %6 + %add13 = add nsw i32 %add9, %mul12 + %cmp14 = icmp slt i32 %add13, %lim + br i1 %cmp14, label %if.end, label %for.end.loopexit + +if.end: ; preds = %for.body + %add.ptr = getelementptr inbounds i32, i32* %p1.017, i64 %idx.ext + %add.ptr16 = getelementptr inbounds i32, i32* %p2.018, i64 %idx.ext + %inc = add nuw nsw i32 %j.019, 1 + %cmp = icmp slt i32 %inc, %h + br i1 %cmp, label %for.body, label %for.end.loopexit + +for.end.loopexit: ; preds = %for.body, %if.end + br label %for.end + +for.end: ; preds = %for.end.loopexit, %entry + %s.1 = phi i32 [ 0, %entry ], [ %add13, %for.end.loopexit ] + ret i32 %s.1 +}