Index: llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -3659,16 +3659,17 @@ unsigned ReductionOpcode; /// The opcode of the values we perform a reduction on. unsigned ReducedValueOpcode; - /// The width of one full horizontal reduction operation. - unsigned ReduxWidth; /// Should we model this reduction as a pairwise reduction tree or a tree that /// splits the vector in halves and adds those halves. bool IsPairwiseReduction; public: + /// The width of one full horizontal reduction operation. + unsigned ReduxWidth; + HorizontalReduction() : ReductionRoot(nullptr), ReductionPHI(nullptr), ReductionOpcode(0), - ReducedValueOpcode(0), ReduxWidth(0), IsPairwiseReduction(false) {} + ReducedValueOpcode(0), IsPairwiseReduction(false), ReduxWidth(0) {} /// \brief Try to find a reduction tree. bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { @@ -3825,8 +3826,11 @@ return VectorizedTree != nullptr; } -private: + unsigned numReductionValues() const { + return ReducedVals.size(); + } +private: /// \brief Calculate the cost of a reduction. int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal) { Type *ScalarTy = FirstReducedVal->getType(); @@ -3973,6 +3977,30 @@ return Rdx; } +/// \brief Try to match and vectorize a horizontal reduction. +/// If horizontal vectorization is enabled, match a reduction with reduction +/// operator BI feeding the phi node P (P may be null) and try to reduce it. +/// Not a pure query: a successful reduction rewrites the IR. +/// \returns true if a horizontal reduction was matched and reduced. +/// \returns false if none was matched, or it was matched but not reduced. 
+static bool tryToVectorizeHorReduction(PHINode *P, BinaryOperator *BI, + BoUpSLP &R, TargetTransformInfo *TTI) { + if (!ShouldVectorizeHor) + return false; + + HorizontalReduction HorRdx; + if (!HorRdx.matchAssociativeReduction(P, BI)) + return false; + + // Use the largest power of 2 not exceeding the number of reduction values + // as the width, but never less than 4. Oversized vectors are fine: the + // backend splits them to legal sizes. + HorRdx.ReduxWidth = + std::max<uint64_t>(4, PowerOf2Floor(HorRdx.numReductionValues())); + + return HorRdx.tryToReduce(R, TTI); +} + bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { bool Changed = false; SmallVector<Value *, 4> Incoming; @@ -4049,9 +4077,7 @@ continue; // Try to match and vectorize a horizontal reduction. - HorizontalReduction HorRdx; - if (ShouldVectorizeHor && HorRdx.matchAssociativeReduction(P, BI) && - HorRdx.tryToReduce(R, TTI)) { + if (tryToVectorizeHorReduction(P, BI, R, TTI)) { Changed = true; it = BB->begin(); e = BB->end(); @@ -4074,15 +4100,13 @@ continue; } // Try to vectorize horizontal reductions feeding into a store. 
if (ShouldStartVectorizeHorAtStore) if (StoreInst *SI = dyn_cast<StoreInst>(it)) if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(SI->getValueOperand())) { - HorizontalReduction HorRdx; - if (((HorRdx.matchAssociativeReduction(nullptr, BinOp) && - HorRdx.tryToReduce(R, TTI)) || - tryToVectorize(BinOp, R))) { + if (tryToVectorizeHorReduction(nullptr, BinOp, R, TTI) || + tryToVectorize(BinOp, R)) { Changed = true; it = BB->begin(); e = BB->end(); Index: llvm/trunk/test/Transforms/SLPVectorizer/AArch64/horizontal.ll =================================================================== --- llvm/trunk/test/Transforms/SLPVectorizer/AArch64/horizontal.ll +++ llvm/trunk/test/Transforms/SLPVectorizer/AArch64/horizontal.ll @@ -145,3 +145,126 @@ %s.1 = phi i32 [ 0, %entry ], [ %add13, %for.end.loopexit ] ret i32 %s.1 } + +; CHECK: test_unrolled_select +; CHECK: load <8 x i8> +; CHECK: load <8 x i8> +; CHECK: select <8 x i1> +define i32 @test_unrolled_select(i8* noalias nocapture readonly %blk1, i8* noalias nocapture readonly %blk2, i32 %lx, i32 %h, i32 %lim) #0 { +entry: + %cmp.43 = icmp sgt i32 %h, 0 + br i1 %cmp.43, label %for.body.lr.ph, label %for.end + +for.body.lr.ph: ; preds = %entry + %idx.ext = sext i32 %lx to i64 + br label %for.body + +for.body: ; preds = %for.body.lr.ph, %if.end.86 + %s.047 = phi i32 [ 0, %for.body.lr.ph ], [ %add82, %if.end.86 ] + %j.046 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %if.end.86 ] + %p2.045 = phi i8* [ %blk2, %for.body.lr.ph ], [ %add.ptr88, %if.end.86 ] + %p1.044 = phi i8* [ %blk1, %for.body.lr.ph ], [ %add.ptr, %if.end.86 ] + %0 = load i8, i8* %p1.044, align 1 + %conv = zext i8 %0 to i32 + %1 = load i8, i8* %p2.045, align 1 + %conv2 = zext i8 %1 to i32 + %sub = sub nsw i32 %conv, %conv2 + %cmp3 = icmp slt i32 %sub, 0 + %sub5 = sub nsw i32 0, %sub + %sub5.sub = select i1 %cmp3, i32 %sub5, i32 %sub + %add = add nsw i32 %sub5.sub, %s.047 + %arrayidx6 = getelementptr inbounds i8, i8* %p1.044, i64 1 + %2 = load i8, i8* %arrayidx6, align 1 + %conv7 = zext i8 %2 
to i32 + %arrayidx8 = getelementptr inbounds i8, i8* %p2.045, i64 1 + %3 = load i8, i8* %arrayidx8, align 1 + %conv9 = zext i8 %3 to i32 + %sub10 = sub nsw i32 %conv7, %conv9 + %cmp11 = icmp slt i32 %sub10, 0 + %sub14 = sub nsw i32 0, %sub10 + %v.1 = select i1 %cmp11, i32 %sub14, i32 %sub10 + %add16 = add nsw i32 %add, %v.1 + %arrayidx17 = getelementptr inbounds i8, i8* %p1.044, i64 2 + %4 = load i8, i8* %arrayidx17, align 1 + %conv18 = zext i8 %4 to i32 + %arrayidx19 = getelementptr inbounds i8, i8* %p2.045, i64 2 + %5 = load i8, i8* %arrayidx19, align 1 + %conv20 = zext i8 %5 to i32 + %sub21 = sub nsw i32 %conv18, %conv20 + %cmp22 = icmp slt i32 %sub21, 0 + %sub25 = sub nsw i32 0, %sub21 + %sub25.sub21 = select i1 %cmp22, i32 %sub25, i32 %sub21 + %add27 = add nsw i32 %add16, %sub25.sub21 + %arrayidx28 = getelementptr inbounds i8, i8* %p1.044, i64 3 + %6 = load i8, i8* %arrayidx28, align 1 + %conv29 = zext i8 %6 to i32 + %arrayidx30 = getelementptr inbounds i8, i8* %p2.045, i64 3 + %7 = load i8, i8* %arrayidx30, align 1 + %conv31 = zext i8 %7 to i32 + %sub32 = sub nsw i32 %conv29, %conv31 + %cmp33 = icmp slt i32 %sub32, 0 + %sub36 = sub nsw i32 0, %sub32 + %v.3 = select i1 %cmp33, i32 %sub36, i32 %sub32 + %add38 = add nsw i32 %add27, %v.3 + %arrayidx39 = getelementptr inbounds i8, i8* %p1.044, i64 4 + %8 = load i8, i8* %arrayidx39, align 1 + %conv40 = zext i8 %8 to i32 + %arrayidx41 = getelementptr inbounds i8, i8* %p2.045, i64 4 + %9 = load i8, i8* %arrayidx41, align 1 + %conv42 = zext i8 %9 to i32 + %sub43 = sub nsw i32 %conv40, %conv42 + %cmp44 = icmp slt i32 %sub43, 0 + %sub47 = sub nsw i32 0, %sub43 + %sub47.sub43 = select i1 %cmp44, i32 %sub47, i32 %sub43 + %add49 = add nsw i32 %add38, %sub47.sub43 + %arrayidx50 = getelementptr inbounds i8, i8* %p1.044, i64 5 + %10 = load i8, i8* %arrayidx50, align 1 + %conv51 = zext i8 %10 to i32 + %arrayidx52 = getelementptr inbounds i8, i8* %p2.045, i64 5 + %11 = load i8, i8* %arrayidx52, align 1 + %conv53 = zext i8 %11 
to i32 + %sub54 = sub nsw i32 %conv51, %conv53 + %cmp55 = icmp slt i32 %sub54, 0 + %sub58 = sub nsw i32 0, %sub54 + %v.5 = select i1 %cmp55, i32 %sub58, i32 %sub54 + %add60 = add nsw i32 %add49, %v.5 + %arrayidx61 = getelementptr inbounds i8, i8* %p1.044, i64 6 + %12 = load i8, i8* %arrayidx61, align 1 + %conv62 = zext i8 %12 to i32 + %arrayidx63 = getelementptr inbounds i8, i8* %p2.045, i64 6 + %13 = load i8, i8* %arrayidx63, align 1 + %conv64 = zext i8 %13 to i32 + %sub65 = sub nsw i32 %conv62, %conv64 + %cmp66 = icmp slt i32 %sub65, 0 + %sub69 = sub nsw i32 0, %sub65 + %sub69.sub65 = select i1 %cmp66, i32 %sub69, i32 %sub65 + %add71 = add nsw i32 %add60, %sub69.sub65 + %arrayidx72 = getelementptr inbounds i8, i8* %p1.044, i64 7 + %14 = load i8, i8* %arrayidx72, align 1 + %conv73 = zext i8 %14 to i32 + %arrayidx74 = getelementptr inbounds i8, i8* %p2.045, i64 7 + %15 = load i8, i8* %arrayidx74, align 1 + %conv75 = zext i8 %15 to i32 + %sub76 = sub nsw i32 %conv73, %conv75 + %cmp77 = icmp slt i32 %sub76, 0 + %sub80 = sub nsw i32 0, %sub76 + %v.7 = select i1 %cmp77, i32 %sub80, i32 %sub76 + %add82 = add nsw i32 %add71, %v.7 + %cmp83 = icmp slt i32 %add82, %lim + br i1 %cmp83, label %if.end.86, label %for.end.loopexit + +if.end.86: ; preds = %for.body + %add.ptr = getelementptr inbounds i8, i8* %p1.044, i64 %idx.ext + %add.ptr88 = getelementptr inbounds i8, i8* %p2.045, i64 %idx.ext + %inc = add nuw nsw i32 %j.046, 1 + %cmp = icmp slt i32 %inc, %h + br i1 %cmp, label %for.body, label %for.end.loopexit + +for.end.loopexit: ; preds = %for.body, %if.end.86 + br label %for.end + +for.end: ; preds = %for.end.loopexit, %entry + %s.1 = phi i32 [ 0, %entry ], [ %add82, %for.end.loopexit ] + ret i32 %s.1 +} +