Index: include/llvm/Transforms/Utils/LoopUtils.h
===================================================================
--- include/llvm/Transforms/Utils/LoopUtils.h
+++ include/llvm/Transforms/Utils/LoopUtils.h
@@ -85,13 +85,16 @@
   RecurrenceDescriptor()
       : StartValue(nullptr), LoopExitInstr(nullptr), Kind(RK_NoRecurrence),
-        MinMaxKind(MRK_Invalid), UnsafeAlgebraInst(nullptr) {}
+        MinMaxKind(MRK_Invalid), UnsafeAlgebraInst(nullptr),
+        RecurrenceType(nullptr), IsSigned(false) {}
 
   RecurrenceDescriptor(Value *Start, Instruction *Exit, RecurrenceKind K,
-                       MinMaxRecurrenceKind MK,
-                       Instruction *UAI /*Unsafe Algebra Inst*/)
+                       MinMaxRecurrenceKind MK, Instruction *UAI, Type *RT,
+                       bool Signed, SmallPtrSetImpl<Instruction *> &CI)
       : StartValue(Start), LoopExitInstr(Exit), Kind(K), MinMaxKind(MK),
-        UnsafeAlgebraInst(UAI) {}
+        UnsafeAlgebraInst(UAI), RecurrenceType(RT), IsSigned(Signed) {
+    CastInsts.insert(CI.begin(), CI.end());
+  }
 
   /// This POD struct holds information about a potential recurrence operation.
   class InstDesc {
@@ -184,6 +187,45 @@
   /// Returns first unsafe algebra instruction in the PHI node's use-chain.
   Instruction *getUnsafeAlgebraInst() { return UnsafeAlgebraInst; }
 
+  /// Returns true if the recurrence kind is an integer kind.
+  static bool isIntegerRecurrenceKind(RecurrenceKind Kind);
+
+  /// Returns true if the recurrence kind is a floating point kind.
+  static bool isFloatingPointRecurrenceKind(RecurrenceKind Kind);
+
+  /// Returns true if the recurrence kind is an arithmetic kind.
+  static bool isArithmeticRecurrenceKind(RecurrenceKind Kind);
+
+  /// Determines if Phi may have been type-promoted. If Phi has a single user,
+  /// and this user ANDs the Phi by a power of two, return the user. RT is
+  /// updated to account for the narrower bit width represented by the AND
+  /// operation. The AND operation is also added to CI.
+  static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT,
+                                     SmallPtrSetImpl<Instruction *> &Visited,
+                                     SmallPtrSetImpl<Instruction *> &CI);
+
+  /// Determines if all the source operands of a recurrence are either sign- or
+  /// zero-extends. This function is intended to be used with lookThroughAnd to
+  /// determine if the recurrence has been type-promoted. Returns 1 if all
+  /// source operands are sign-extends, 0 if all source operands are
+  /// zero-extends, and a value less than 0 otherwise.
+  static int getSourceExtensionKind(Instruction *Start, Instruction *Exit,
+                                    Type *RT,
+                                    SmallPtrSetImpl<Instruction *> &Visited,
+                                    SmallPtrSetImpl<Instruction *> &CI);
+
+  /// Return the type of the recurrence. This type can be narrower than the
+  /// actual type of the Phi if the recurrence has been type-promoted.
+  Type *getRecurrenceType() { return RecurrenceType; }
+
+  /// Return a reference to the instructions used for type-promoting the
+  /// recurrence. These instructions should be ignored by the LoopVectorizer
+  /// cost model since they will eventually be erased.
+  SmallPtrSet<Instruction *, 8> &getCastInsts() { return CastInsts; }
+
+  /// Return true if the recurrence has been type-promoted from sign-extends.
+  bool isSigned() { return IsSigned; }
+
 private:
   // The starting value of the recurrence.
   // It does not have to be zero!
@@ -196,6 +238,12 @@
   MinMaxRecurrenceKind MinMaxKind;
   // First occurance of unasfe algebra in the PHI's use-chain.
   Instruction *UnsafeAlgebraInst;
+  // The type of the recurrence.
+  Type *RecurrenceType;
+  // True if the recurrence has been type-promoted from sign-extends.
+  bool IsSigned;
+  // Instructions used for type-promoting the recurrence.
+  SmallPtrSet<Instruction *, 8> CastInsts;
 };
 
 BasicBlock *InsertPreheaderForLoop(Loop *L, Pass *P);
Index: lib/Transforms/Utils/LoopUtils.cpp
===================================================================
--- lib/Transforms/Utils/LoopUtils.cpp
+++ lib/Transforms/Utils/LoopUtils.cpp
@@ -34,6 +34,112 @@
   return true;
 }
 
+bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurrenceKind Kind) {
+  bool IsIntegerRecurrenceKind = false;
+  switch (Kind) {
+  case RK_IntegerAdd:
+  case RK_IntegerMult:
+  case RK_IntegerOr:
+  case RK_IntegerAnd:
+  case RK_IntegerXor:
+  case RK_IntegerMinMax:
+    IsIntegerRecurrenceKind = true;
+  default:
+    break;
+  }
+  return IsIntegerRecurrenceKind;
+}
+
+bool RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind Kind) {
+  return (Kind != RK_NoRecurrence) && !isIntegerRecurrenceKind(Kind);
+}
+
+bool RecurrenceDescriptor::isArithmeticRecurrenceKind(RecurrenceKind Kind) {
+  bool IsArithmeticRecurrenceKind = false;
+  switch (Kind) {
+  case RK_IntegerAdd:
+  case RK_IntegerMult:
+  case RK_FloatAdd:
+  case RK_FloatMult:
+    IsArithmeticRecurrenceKind = true;
+  default:
+    break;
+  }
+  return IsArithmeticRecurrenceKind;
+}
+
+Instruction *
+RecurrenceDescriptor::lookThroughAnd(PHINode *Phi, Type *&RT,
+                                     SmallPtrSetImpl<Instruction *> &Visited,
+                                     SmallPtrSetImpl<Instruction *> &CI) {
+  if (Phi->hasOneUse()) {
+    const APInt *M = nullptr;
+    Instruction *I, *J = cast<Instruction>(Phi->use_begin()->getUser());
+
+    // Matches either (I & <mask>) or (<mask> & I). If we find a match, we
+    // update RT with a new integer type of the corresponding bit width.
+    if (match(J, m_CombineOr(m_And(m_Instruction(I), m_APInt(M)),
+                             m_And(m_APInt(M), m_Instruction(I))))) {
+      int32_t Bits = (*M + 1).exactLogBase2();
+      if (Bits > 0) {
+        RT = IntegerType::get(Phi->getContext(), Bits);
+        Visited.insert(Phi);
+        CI.insert(J);
+        return J;
+      }
+    }
+  }
+  return Phi;
+}
+
+int RecurrenceDescriptor::getSourceExtensionKind(
+    Instruction *Start, Instruction *Exit, Type *RT,
+    SmallPtrSetImpl<Instruction *> &Visited,
+    SmallPtrSetImpl<Instruction *> &CI) {
+
+  SmallVector<Instruction *, 8> Worklist;
+  int IsSigned = -1;
+  Worklist.push_back(Exit);
+
+  // Traverse the instructions in the reduction expression, beginning with the
+  // exit value.
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.pop_back_val();
+    for (Use &U : I->operands()) {
+
+      // Terminate the traversal if the operand is not an instruction, or we
+      // reach the starting value.
+      Instruction *J = dyn_cast<Instruction>(U.get());
+      if (!J || J == Start)
+        continue;
+
+      // Otherwise, investigate the operation if it is also in the expression.
+      if (Visited.count(J)) {
+        Worklist.push_back(J);
+        continue;
+      }
+
+      // If the operand is not in Visited, it is not a reduction operation, but
+      // it does feed into one. Make sure it is either a sign- or zero-extend,
+      // and record which one. Furthermore, ensure that all such extends are of
+      // the same kind.
+      CastInst *Cast = dyn_cast<CastInst>(J);
+      bool IsSExtInst = isa<SExtInst>(J);
+      if (Cast && Cast->hasOneUse() && Cast->getSrcTy() == RT &&
+          (isa<ZExtInst>(J) || IsSExtInst)) {
+        CI.insert(Cast);
+        if (IsSigned == -1)
+          IsSigned = IsSExtInst;
+        else
+          IsSigned = ((unsigned)IsSigned == IsSExtInst) ? IsSigned : -2;
+      } else {
+        IsSigned = -2;
+      }
+    }
+  }
+  return IsSigned;
+}
+
 bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
                                            Loop *TheLoop, bool HasFunNoNaNAttr,
                                            RecurrenceDescriptor &RedDes) {
@@ -68,10 +174,32 @@
   unsigned NumCmpSelectPatternInst = 0;
   InstDesc ReduxDesc(false, nullptr);
 
+  // Data used for determining if the recurrence has been type-promoted.
+  Type *RecurrenceType = Phi->getType();
+  SmallPtrSet<Instruction *, 4> CastInsts;
+  Instruction *Start = Phi;
+  int IsSigned = -1;
+
   SmallPtrSet<Instruction *, 8> VisitedInsts;
   SmallVector<Instruction *, 8> Worklist;
-  Worklist.push_back(Phi);
-  VisitedInsts.insert(Phi);
+
+  // Return early if the recurrence kind does not match the type of Phi. If the
+  // recurrence kind is arithmetic, we attempt to look through AND operations
+  // resulting from the type promotion performed by InstCombine. Vector
+  // operations are not limited to the legal integer widths, so we may be able
+  // to evaluate the reduction in the narrower width.
+  if (RecurrenceType->isFloatingPointTy()) {
+    if (!isFloatingPointRecurrenceKind(Kind))
+      return false;
+  } else {
+    if (!isIntegerRecurrenceKind(Kind))
+      return false;
+    if (isArithmeticRecurrenceKind(Kind))
+      Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts);
+  }
+
+  Worklist.push_back(Start);
+  VisitedInsts.insert(Start);
 
   // A value in the reduction can be used:
   //  - By the reduction:
@@ -110,10 +238,14 @@
         !VisitedInsts.count(dyn_cast<Instruction>(Cur->getOperand(0))))
       return false;
 
-    // Any reduction instruction must be of one of the allowed kinds.
-    ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
-    if (!ReduxDesc.isRecurrence())
-      return false;
+    // Any reduction instruction must be of one of the allowed kinds. We ignore
+    // the starting value (the Phi or an AND instruction if the Phi has been
+    // type-promoted).
+    if (Cur != Start) {
+      ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
+      if (!ReduxDesc.isRecurrence())
+        return false;
+    }
 
     // A reduction operation must only have one use of the reduction value.
     if (!IsAPhi && Kind != RK_IntegerMinMax && Kind != RK_FloatMinMax &&
@@ -131,7 +263,7 @@
       ++NumCmpSelectPatternInst;
 
     // Check whether we found a reduction operator.
-    FoundReduxOp |= !IsAPhi;
+    FoundReduxOp |= (!IsAPhi && (Cur != Start));
 
     // Process users of current instruction. Push non-PHI nodes after PHI nodes
    // onto the stack.
This way we are going to have seen all inputs to PHI
@@ -193,6 +325,17 @@
   if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
     return false;
 
+  // If we found a truncation that may have resulted from a type promotion, we
+  // need to ensure that all source values of the reduction are sign- or
+  // zero-extends. If so, we can evaluate the reduction in the narrower bit
+  // width.
+  if (Start != Phi) {
+    IsSigned = getSourceExtensionKind(Start, ExitInstruction, RecurrenceType,
+                                      VisitedInsts, CastInsts);
+    if (IsSigned < 0)
+      return false;
+  }
+
   // We found a reduction var if we have reached the original phi node and we
   // only have a single instruction with out-of-loop users.
 
@@ -200,10 +343,9 @@
   // is saved as part of the RecurrenceDescriptor.
 
   // Save the description of this reduction variable.
-  RecurrenceDescriptor RD(RdxStart, ExitInstruction, Kind,
-                          ReduxDesc.getMinMaxKind(),
-                          ReduxDesc.getUnsafeAlgebraInst());
-
+  RecurrenceDescriptor RD(
+      RdxStart, ExitInstruction, Kind, ReduxDesc.getMinMaxKind(),
+      ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts);
   RedDes = RD;
 
   return true;
@@ -272,9 +414,6 @@
   default:
     return InstDesc(false, I);
   case Instruction::PHI:
-    if (FP &&
-        (Kind != RK_FloatMult && Kind != RK_FloatAdd && Kind != RK_FloatMinMax))
-      return InstDesc(false, I);
     return InstDesc(I, Prev.getMinMaxKind());
   case Instruction::Sub:
   case Instruction::Add:
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1377,11 +1377,10 @@
                              LoopVectorizationLegality *Legal,
                              const TargetTransformInfo &TTI,
                              const TargetLibraryInfo *TLI, AssumptionCache *AC,
-                             const Function *F, const LoopVectorizeHints *Hints)
+                             const Function *F, const LoopVectorizeHints *Hints,
+                             SmallPtrSetImpl<const Value *> &ValuesToIgnore)
       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI),
-        TheFunction(F), Hints(Hints)
{
-    CodeMetrics::collectEphemeralValues(L, AC, EphValues);
-  }
+        TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
 
   /// Information about vectorization costs
   struct VectorizationFactor {
@@ -1450,9 +1449,6 @@
     emitAnalysisDiag(TheFunction, TheLoop, *Hints, Message);
   }
 
-  /// Values used only by @llvm.assume calls.
-  SmallPtrSet<const Value *, 16> EphValues;
-
   /// The loop that we evaluate.
   Loop *TheLoop;
   /// Scev analysis.
@@ -1468,6 +1464,8 @@
   const Function *TheFunction;
   // Loop Vectorize Hint.
   const LoopVectorizeHints *Hints;
+  // Values to ignore in the cost model.
+  const SmallPtrSetImpl<const Value *> &ValuesToIgnore;
 };
 
 /// \brief This holds vectorization requirements that must be verified late in
@@ -1713,8 +1711,19 @@
       return false;
     }
 
+    // Collect values we want to ignore in the cost model. This includes
+    // type-promoting instructions we identified during reduction detection.
+    SmallPtrSet<const Value *, 16> ValuesToIgnore;
+    CodeMetrics::collectEphemeralValues(L, AC, ValuesToIgnore);
+    for (auto &Reduction : *LVL.getReductionVars()) {
+      RecurrenceDescriptor &RedDes = Reduction.second;
+      SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts();
+      ValuesToIgnore.insert(Casts.begin(), Casts.end());
+    }
+
     // Use the cost model.
-    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints);
+    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints,
+                                  ValuesToIgnore);
 
     // Check the function attributes to find out if this function should be
     // optimized for size.
@@ -3300,12 +3309,11 @@
   // instructions.
   Builder.SetInsertPoint(LoopMiddleBlock->getFirstInsertionPt());
 
-  VectorParts RdxParts;
+  VectorParts RdxParts, &RdxExitVal = getVectorValue(LoopExitInst);
   setDebugLocFromInst(Builder, LoopExitInst);
   for (unsigned part = 0; part < UF; ++part) {
     // This PHINode contains the vectorized reduction variable, or
    // the initial value vector, if we bypass the vector loop.
- VectorParts &RdxExitVal = getVectorValue(LoopExitInst); PHINode *NewPhi = Builder.CreatePHI(VecTy, 2, "rdx.vec.exit.phi"); Value *StartVal = (part == 0) ? VectorStart : Identity; for (unsigned I = 1, E = LoopBypassBlocks.size(); I != E; ++I) @@ -3315,6 +3323,32 @@ RdxParts.push_back(NewPhi); } + // We restrict scalar arithmetic to the supported native integer widths + // identified in the data layout. However, this limitation does not hold + // for vector arithmetic. Below, if the vector reduction can be performed + // in a smaller type, we truncate then extend the loop exit value to enable + // InstCombine to come back and evaluate the entire expression in the + // smaller type. + if ((VF > 1) && (RdxPhi->getType() != RdxDesc.getRecurrenceType())) { + Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); + Builder.SetInsertPoint(LoopVectorBody.back()->getTerminator()); + for (unsigned part = 0; part < UF; ++part) { + Value *Trunc = Builder.CreateTrunc(RdxExitVal[part], RdxVecTy); + Value *Extnd = RdxDesc.isSigned() + ? Builder.CreateSExt(Trunc, VecTy) + : Builder.CreateZExt(Trunc, VecTy); + for (Value::user_iterator UI = RdxExitVal[part]->user_begin(); + UI != RdxExitVal[part]->user_end();) + if (*UI != Trunc) + (*UI++)->replaceUsesOfWith(RdxExitVal[part], Extnd); + else + ++UI; + } + Builder.SetInsertPoint(LoopMiddleBlock->getFirstInsertionPt()); + for (unsigned part = 0; part < UF; ++part) + RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy); + } + // Reduce all of the unrolled parts into a single vector. Value *ReducedPartRdx = RdxParts[0]; unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); @@ -3365,6 +3399,15 @@ // The result is in the first element of the vector. ReducedPartRdx = Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); + + // If the vector reduction can be performed in a smaller type, we still + // have to extend the reduction to the wider type before we jump into the + // original loop. 
+    if (RdxPhi->getType() != RdxDesc.getRecurrenceType())
+      ReducedPartRdx =
+          RdxDesc.isSigned()
+              ? Builder.CreateSExt(ReducedPartRdx, RdxPhi->getType())
+              : Builder.CreateZExt(ReducedPartRdx, RdxPhi->getType());
   }
 
   // Create a phi node that merges control-flow from the backedge-taken check
@@ -4739,18 +4782,22 @@
     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
       Type *T = it->getType();
 
-      // Ignore ephemeral values.
-      if (EphValues.count(it))
+      // Skip ignored values.
+      if (ValuesToIgnore.count(it))
         continue;
 
       // Only examine Loads, Stores and PHINodes.
       if (!isa<LoadInst>(it) && !isa<StoreInst>(it) && !isa<PHINode>(it))
         continue;
 
-      // Examine PHI nodes that are reduction variables.
-      if (PHINode *PN = dyn_cast<PHINode>(it))
+      // Examine PHI nodes that are reduction variables. Update the type to
+      // account for the recurrence type.
+      if (PHINode *PN = dyn_cast<PHINode>(it)) {
         if (!Legal->getReductionVars()->count(PN))
           continue;
+        RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[PN];
+        T = RdxDesc.getRecurrenceType();
+      }
 
       // Examine the stored values.
       if (StoreInst *ST = dyn_cast<StoreInst>(it))
@@ -5011,8 +5058,8 @@
       // Ignore instructions that are never used within the loop.
       if (!Ends.count(I))
         continue;
 
-      // Ignore ephemeral values.
-      if (EphValues.count(I))
+      // Skip ignored values.
+      if (ValuesToIgnore.count(I))
         continue;
 
       // Remove all of the instructions that end at this location.
@@ -5055,8 +5102,8 @@
       if (isa<DbgInfoIntrinsic>(it))
         continue;
 
-      // Ignore ephemeral values.
-      if (EphValues.count(it))
+      // Skip ignored values.
+ if (ValuesToIgnore.count(it)) continue; unsigned C = getInstructionCost(it, VF); Index: test/Transforms/LoopVectorize/reduction-small-size.ll =================================================================== --- /dev/null +++ test/Transforms/LoopVectorize/reduction-small-size.ll @@ -0,0 +1,128 @@ +; RUN: opt < %s -loop-vectorize -force-vector-interleave=1 -dce -instcombine -S | FileCheck %s + +target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" +target triple = "aarch64--linux-gnu" + +; CHECK-LABEL: @reduction_i8 +; +; char reduction_i8(char *a, char *b, int n) { +; char sum = 0; +; for (int i = 0; i < n; ++i) +; sum += (a[i] + b[i]); +; return sum; +; } +; +; CHECK: vector.body: +; CHECK: phi <16 x i8> +; CHECK: load <16 x i8> +; CHECK: load <16 x i8> +; CHECK: add <16 x i8> +; CHECK: add <16 x i8> +; +; CHECK: middle.block: +; CHECK: shufflevector <16 x i8> +; CHECK: add <16 x i8> +; CHECK: shufflevector <16 x i8> +; CHECK: add <16 x i8> +; CHECK: shufflevector <16 x i8> +; CHECK: add <16 x i8> +; CHECK: shufflevector <16 x i8> +; CHECK: add <16 x i8> +; CHECK: [[Rdx:%[a-zA-Z0-9.]+]] = extractelement <16 x i8> +; CHECK: zext i8 [[Rdx]] to i32 +; +define i8 @reduction_i8(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %n) { +entry: + %cmp.12 = icmp sgt i32 %n, 0 + br i1 %cmp.12, label %for.body.preheader, label %for.cond.cleanup + +for.body.preheader: + br label %for.body + +for.cond.for.cond.cleanup_crit_edge: + %add5.lcssa = phi i32 [ %add5, %for.body ] + %conv6 = trunc i32 %add5.lcssa to i8 + br label %for.cond.cleanup + +for.cond.cleanup: + %sum.0.lcssa = phi i8 [ %conv6, %for.cond.for.cond.cleanup_crit_edge ], [ 0, %entry ] + ret i8 %sum.0.lcssa + +for.body: + %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ] + %sum.013 = phi i32 [ %add5, %for.body ], [ 0, %for.body.preheader ] + %arrayidx = getelementptr inbounds i8, i8* %a, i64 %indvars.iv + %0 = load i8, i8* %arrayidx, align 1 + %conv = zext i8 %0 to i32 + 
%arrayidx2 = getelementptr inbounds i8, i8* %b, i64 %indvars.iv + %1 = load i8, i8* %arrayidx2, align 1 + %conv3 = zext i8 %1 to i32 + %conv4 = and i32 %sum.013, 255 + %add = add nuw nsw i32 %conv, %conv4 + %add5 = add nuw nsw i32 %add, %conv3 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %lftr.wideiv = trunc i64 %indvars.iv.next to i32 + %exitcond = icmp eq i32 %lftr.wideiv, %n + br i1 %exitcond, label %for.cond.for.cond.cleanup_crit_edge, label %for.body +} + +; CHECK-LABEL: @reduction_i16 +; +; short reduction_i16(short *a, short *b, int n) { +; short sum = 0; +; for (int i = 0; i < n; ++i) +; sum += (a[i] + b[i]); +; return sum; +; } +; +; CHECK: vector.body: +; CHECK: phi <8 x i16> +; CHECK: load <8 x i16> +; CHECK: load <8 x i16> +; CHECK: add <8 x i16> +; CHECK: add <8 x i16> +; +; CHECK: middle.block: +; CHECK: shufflevector <8 x i16> +; CHECK: add <8 x i16> +; CHECK: shufflevector <8 x i16> +; CHECK: add <8 x i16> +; CHECK: shufflevector <8 x i16> +; CHECK: add <8 x i16> +; CHECK: [[Rdx:%[a-zA-Z0-9.]+]] = extractelement <8 x i16> +; CHECK: zext i16 [[Rdx]] to i32 +; +define i16 @reduction_i16(i16* nocapture readonly %a, i16* nocapture readonly %b, i32 %n) { +entry: + %cmp.16 = icmp sgt i32 %n, 0 + br i1 %cmp.16, label %for.body.preheader, label %for.cond.cleanup + +for.body.preheader: + br label %for.body + +for.cond.for.cond.cleanup_crit_edge: + %add5.lcssa = phi i32 [ %add5, %for.body ] + %conv6 = trunc i32 %add5.lcssa to i16 + br label %for.cond.cleanup + +for.cond.cleanup: + %sum.0.lcssa = phi i16 [ %conv6, %for.cond.for.cond.cleanup_crit_edge ], [ 0, %entry ] + ret i16 %sum.0.lcssa + +for.body: + %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ] + %sum.017 = phi i32 [ %add5, %for.body ], [ 0, %for.body.preheader ] + %arrayidx = getelementptr inbounds i16, i16* %a, i64 %indvars.iv + %0 = load i16, i16* %arrayidx, align 2 + %conv.14 = zext i16 %0 to i32 + %arrayidx2 = getelementptr inbounds i16, i16* %b, i64 
%indvars.iv + %1 = load i16, i16* %arrayidx2, align 2 + %conv3.15 = zext i16 %1 to i32 + %conv4.13 = and i32 %sum.017, 65535 + %add = add nuw nsw i32 %conv.14, %conv4.13 + %add5 = add nuw nsw i32 %add, %conv3.15 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %lftr.wideiv = trunc i64 %indvars.iv.next to i32 + %exitcond = icmp eq i32 %lftr.wideiv, %n + br i1 %exitcond, label %for.cond.for.cond.cleanup_crit_edge, label %for.body +}