Index: include/llvm/Transforms/Utils/LoopUtils.h
===================================================================
--- include/llvm/Transforms/Utils/LoopUtils.h
+++ include/llvm/Transforms/Utils/LoopUtils.h
@@ -85,13 +85,16 @@
   RecurrenceDescriptor()
       : StartValue(nullptr), LoopExitInstr(nullptr), Kind(RK_NoRecurrence),
-        MinMaxKind(MRK_Invalid), UnsafeAlgebraInst(nullptr) {}
+        MinMaxKind(MRK_Invalid), UnsafeAlgebraInst(nullptr),
+        RecurrenceType(nullptr), IsSigned(false) {}
 
   RecurrenceDescriptor(Value *Start, Instruction *Exit, RecurrenceKind K,
-                       MinMaxRecurrenceKind MK,
-                       Instruction *UAI /*Unsafe Algebra Inst*/)
+                       MinMaxRecurrenceKind MK, Instruction *UAI, Type *RT,
+                       bool Signed, SmallPtrSetImpl<Instruction *> &CI)
       : StartValue(Start), LoopExitInstr(Exit), Kind(K), MinMaxKind(MK),
-        UnsafeAlgebraInst(UAI) {}
+        UnsafeAlgebraInst(UAI), RecurrenceType(RT), IsSigned(Signed) {
+    CastInsts.insert(CI.begin(), CI.end());
+  }
 
   /// This POD struct holds information about a potential recurrence operation.
   class InstDesc {
@@ -184,6 +187,47 @@
   /// Returns first unsafe algebra instruction in the PHI node's use-chain.
   Instruction *getUnsafeAlgebraInst() { return UnsafeAlgebraInst; }
 
+  /// Returns true if the recurrence kind is an integer kind.
+  static bool isIntegerRecurrenceKind(RecurrenceKind Kind);
+
+  /// Returns true if the recurrence kind is a floating point kind.
+  static bool isFloatingPointRecurrenceKind(RecurrenceKind Kind);
+
+  /// Returns true if the recurrence kind is an arithmetic kind. Arithmetic
+  /// recurrence kinds include addition and multiplication. We distinguish
+  /// these from the bitwise (and, or, xor) and logical (min/max) kinds because
+  /// the arithmetic operations are subject to type-promotion.
+  static bool isArithmeticRecurrenceKind(RecurrenceKind Kind);
+
+  /// Determines if Phi may have been type-promoted. If Phi has a single user
+  /// that ANDs the Phi with a type mask, return the user. RT is updated to
+  /// account for the narrower bit width represented by the mask, and the AND
+  /// instruction is added to CI.
+  static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT,
+                                     SmallPtrSetImpl<Instruction *> &Visited,
+                                     SmallPtrSetImpl<Instruction *> &CI);
+
+  /// Returns true if all the source operands of a recurrence are either
+  /// SExtInsts or ZExtInsts. This function is intended to be used with
+  /// lookThroughAnd to determine if the recurrence has been type-promoted. The
+  /// source operands are added to CI, and IsSigned is updated to indicate if
+  /// all source operands are SExtInsts.
+  static bool getSourceExtensionKind(Instruction *Start, Instruction *Exit,
+                                     Type *RT, bool &IsSigned,
+                                     SmallPtrSetImpl<Instruction *> &Visited,
+                                     SmallPtrSetImpl<Instruction *> &CI);
+
+  /// Returns the type of the recurrence. This type can be narrower than the
+  /// actual type of the Phi if the recurrence has been type-promoted.
+  Type *getRecurrenceType() { return RecurrenceType; }
+
+  /// Returns a reference to the instructions used for type-promoting the
+  /// recurrence.
+  SmallPtrSet<Instruction *, 8> &getCastInsts() { return CastInsts; }
+
+  /// Returns true if all source operands of the recurrence are SExtInsts.
+  bool isSigned() { return IsSigned; }
+
 private:
   // The starting value of the recurrence.
   // It does not have to be zero!
@@ -196,6 +240,12 @@
   MinMaxRecurrenceKind MinMaxKind;
   // First occurance of unasfe algebra in the PHI's use-chain.
   Instruction *UnsafeAlgebraInst;
+  // The type of the recurrence.
+  Type *RecurrenceType;
+  // True if all source operands of the recurrence are SExtInsts.
+  bool IsSigned;
+  // Instructions used for type-promoting the recurrence.
+  SmallPtrSet<Instruction *, 8> CastInsts;
 };
 
 /// A struct for saving information about induction variables.
Index: lib/Transforms/Utils/LoopUtils.cpp
===================================================================
--- lib/Transforms/Utils/LoopUtils.cpp
+++ lib/Transforms/Utils/LoopUtils.cpp
@@ -34,6 +34,116 @@
   return true;
 }
 
+bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurrenceKind Kind) {
+  switch (Kind) {
+  default:
+    break;
+  case RK_IntegerAdd:
+  case RK_IntegerMult:
+  case RK_IntegerOr:
+  case RK_IntegerAnd:
+  case RK_IntegerXor:
+  case RK_IntegerMinMax:
+    return true;
+  }
+  return false;
+}
+
+bool RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind Kind) {
+  return (Kind != RK_NoRecurrence) && !isIntegerRecurrenceKind(Kind);
+}
+
+bool RecurrenceDescriptor::isArithmeticRecurrenceKind(RecurrenceKind Kind) {
+  switch (Kind) {
+  default:
+    break;
+  case RK_IntegerAdd:
+  case RK_IntegerMult:
+  case RK_FloatAdd:
+  case RK_FloatMult:
+    return true;
+  }
+  return false;
+}
+
+Instruction *
+RecurrenceDescriptor::lookThroughAnd(PHINode *Phi, Type *&RT,
+                                     SmallPtrSetImpl<Instruction *> &Visited,
+                                     SmallPtrSetImpl<Instruction *> &CI) {
+  if (!Phi->hasOneUse())
+    return Phi;
+
+  const APInt *M = nullptr;
+  Instruction *I, *J = cast<Instruction>(Phi->use_begin()->getUser());
+
+  // Matches either I & 2^x-1 or 2^x-1 & I. If we find a match, we update RT
+  // with a new integer type of the corresponding bit width.
+  if (match(J, m_CombineOr(m_And(m_Instruction(I), m_APInt(M)),
+                           m_And(m_APInt(M), m_Instruction(I))))) {
+    int32_t Bits = (*M + 1).exactLogBase2();
+    if (Bits > 0) {
+      RT = IntegerType::get(Phi->getContext(), Bits);
+      Visited.insert(Phi);
+      CI.insert(J);
+      return J;
+    }
+  }
+  return Phi;
+}
+
+bool RecurrenceDescriptor::getSourceExtensionKind(
+    Instruction *Start, Instruction *Exit, Type *RT, bool &IsSigned,
+    SmallPtrSetImpl<Instruction *> &Visited,
+    SmallPtrSetImpl<Instruction *> &CI) {
+
+  SmallVector<Instruction *, 8> Worklist;
+  bool FoundOneOperand = false;
+  Worklist.push_back(Exit);
+
+  // Traverse the instructions in the reduction expression, beginning with the
+  // exit value.
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.pop_back_val();
+    for (Use &U : I->operands()) {
+
+      // Terminate the traversal if the operand is not an instruction, or we
+      // reach the starting value.
+      Instruction *J = dyn_cast<Instruction>(U.get());
+      if (!J || J == Start)
+        continue;
+
+      // Otherwise, investigate the operation if it is also in the expression.
+      if (Visited.count(J)) {
+        Worklist.push_back(J);
+        continue;
+      }
+
+      // If the operand is not in Visited, it is not a reduction operation, but
+      // it does feed into one. Make sure it is either a single-use sign- or
+      // zero-extend of the recurrence type.
+      CastInst *Cast = dyn_cast<CastInst>(J);
+      bool IsSExtInst = isa<SExtInst>(J);
+      if (!Cast || !Cast->hasOneUse() || Cast->getSrcTy() != RT ||
+          !(isa<ZExtInst>(J) || IsSExtInst))
+        return false;
+
+      // Furthermore, ensure that all such extends are of the same kind.
+      if (FoundOneOperand) {
+        if (IsSigned != IsSExtInst)
+          return false;
+      } else {
+        FoundOneOperand = true;
+        IsSigned = IsSExtInst;
+      }
+
+      // Lastly, add the sign- or zero-extend to CI so that we can avoid
+      // accounting for it in the cost model.
+      CI.insert(Cast);
+    }
+  }
+  return true;
+}
+
 bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
                                            Loop *TheLoop, bool HasFunNoNaNAttr,
                                            RecurrenceDescriptor &RedDes) {
@@ -68,10 +178,32 @@
   unsigned NumCmpSelectPatternInst = 0;
   InstDesc ReduxDesc(false, nullptr);
 
+  // Data used for determining if the recurrence has been type-promoted.
+  Type *RecurrenceType = Phi->getType();
+  SmallPtrSet<Instruction *, 4> CastInsts;
+  Instruction *Start = Phi;
+  bool IsSigned = false;
+
   SmallPtrSet<Instruction *, 8> VisitedInsts;
   SmallVector<Instruction *, 8> Worklist;
-  Worklist.push_back(Phi);
-  VisitedInsts.insert(Phi);
+
+  // Return early if the recurrence kind does not match the type of Phi. If the
+  // recurrence kind is arithmetic, we attempt to look through AND operations
+  // resulting from the type promotion performed by InstCombine. Vector
+  // operations are not limited to the legal integer widths, so we may be able
+  // to evaluate the reduction in the narrower width.
+  if (RecurrenceType->isFloatingPointTy()) {
+    if (!isFloatingPointRecurrenceKind(Kind))
+      return false;
+  } else {
+    if (!isIntegerRecurrenceKind(Kind))
+      return false;
+    if (isArithmeticRecurrenceKind(Kind))
+      Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts);
+  }
+
+  Worklist.push_back(Start);
+  VisitedInsts.insert(Start);
 
   // A value in the reduction can be used:
   //  - By the reduction:
@@ -110,10 +242,14 @@
         !VisitedInsts.count(dyn_cast<Instruction>(Cur->getOperand(0))))
       return false;
 
-    // Any reduction instruction must be of one of the allowed kinds.
-    ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
-    if (!ReduxDesc.isRecurrence())
-      return false;
+    // Any reduction instruction must be of one of the allowed kinds. We ignore
+    // the starting value (the Phi or an AND instruction if the Phi has been
+    // type-promoted).
+    if (Cur != Start) {
+      ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr);
+      if (!ReduxDesc.isRecurrence())
+        return false;
+    }
 
     // A reduction operation must only have one use of the reduction value.
     if (!IsAPhi && Kind != RK_IntegerMinMax && Kind != RK_FloatMinMax &&
@@ -131,7 +267,7 @@
       ++NumCmpSelectPatternInst;
 
     // Check whether we found a reduction operator.
-    FoundReduxOp |= !IsAPhi;
+    FoundReduxOp |= !IsAPhi && Cur != Start;
 
     // Process users of current instruction. Push non-PHI nodes after PHI nodes
    // onto the stack. This way we are going to have seen all inputs to PHI
@@ -193,6 +329,14 @@
  if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
    return false;

+  // If we think Phi may have been type-promoted, we also need to ensure that
+  // all source operands of the reduction are either SExtInsts or ZExtInsts. If
+  // so, we will be able to evaluate the reduction in the narrower bit width.
+  if (Start != Phi)
+    if (!getSourceExtensionKind(Start, ExitInstruction, RecurrenceType,
+                                IsSigned, VisitedInsts, CastInsts))
+      return false;
+
  // We found a reduction var if we have reached the original phi node and we
  // only have a single instruction with out-of-loop users.

@@ -200,10 +344,9 @@
  // is saved as part of the RecurrenceDescriptor.

  // Save the description of this reduction variable.
-  RecurrenceDescriptor RD(RdxStart, ExitInstruction, Kind,
-                          ReduxDesc.getMinMaxKind(),
-                          ReduxDesc.getUnsafeAlgebraInst());
-
+  RecurrenceDescriptor RD(
+      RdxStart, ExitInstruction, Kind, ReduxDesc.getMinMaxKind(),
+      ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts);
   RedDes = RD;
 
   return true;
@@ -272,9 +415,6 @@
   default:
     return InstDesc(false, I);
   case Instruction::PHI:
-    if (FP &&
-        (Kind != RK_FloatMult && Kind != RK_FloatAdd && Kind != RK_FloatMinMax))
-      return InstDesc(false, I);
     return InstDesc(I, Prev.getMinMaxKind());
   case Instruction::Sub:
   case Instruction::Add:
Index: lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- lib/Transforms/Vectorize/LoopVectorize.cpp
+++ lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1308,11 +1308,10 @@
                              LoopVectorizationLegality *Legal,
                              const TargetTransformInfo &TTI,
                              const TargetLibraryInfo *TLI, AssumptionCache *AC,
-                             const Function *F, const LoopVectorizeHints *Hints)
+                             const Function *F, const LoopVectorizeHints *Hints,
+                             SmallPtrSetImpl<const Value *> &ValuesToIgnore)
       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI),
-        TheFunction(F), Hints(Hints) {
-    CodeMetrics::collectEphemeralValues(L, AC, EphValues);
-  }
+        TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
 
   /// Information about vectorization costs
   struct VectorizationFactor {
@@ -1381,9 +1380,6 @@
     emitAnalysisDiag(TheFunction, TheLoop, *Hints, Message);
   }
 
-  /// Values used only by @llvm.assume calls.
-  SmallPtrSet<const Value *, 16> EphValues;
-
   /// The loop that we evaluate.
   Loop *TheLoop;
   /// Scev analysis.
@@ -1399,6 +1395,8 @@
   const Function *TheFunction;
   // Loop Vectorize Hint.
   const LoopVectorizeHints *Hints;
+  // Values to ignore in the cost model.
+  const SmallPtrSetImpl<const Value *> &ValuesToIgnore;
 };
 
 /// \brief This holds vectorization requirements that must be verified late in
@@ -1643,8 +1641,19 @@
       return false;
     }
 
+    // Collect values we want to ignore in the cost model. This includes
+    // type-promoting instructions we identified during reduction detection.
+    SmallPtrSet<const Value *, 32> ValuesToIgnore;
+    CodeMetrics::collectEphemeralValues(L, AC, ValuesToIgnore);
+    for (auto &Reduction : *LVL.getReductionVars()) {
+      RecurrenceDescriptor &RedDes = Reduction.second;
+      SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts();
+      ValuesToIgnore.insert(Casts.begin(), Casts.end());
+    }
+
     // Use the cost model.
-    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints);
+    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints,
+                                  ValuesToIgnore);
 
     // Check the function attributes to find out if this function should be
     // optimized for size.
@@ -3234,12 +3243,11 @@
     // instructions.
     Builder.SetInsertPoint(LoopMiddleBlock->getFirstInsertionPt());
 
-    VectorParts RdxParts;
+    VectorParts RdxParts, &RdxExitVal = getVectorValue(LoopExitInst);
     setDebugLocFromInst(Builder, LoopExitInst);
     for (unsigned part = 0; part < UF; ++part) {
       // This PHINode contains the vectorized reduction variable, or
       // the initial value vector, if we bypass the vector loop.
-      VectorParts &RdxExitVal = getVectorValue(LoopExitInst);
       PHINode *NewPhi = Builder.CreatePHI(VecTy, 2, "rdx.vec.exit.phi");
       Value *StartVal = (part == 0) ? VectorStart : Identity;
       for (unsigned I = 1, E = LoopBypassBlocks.size(); I != E; ++I)
@@ -3249,6 +3257,34 @@
       RdxParts.push_back(NewPhi);
     }
 
+    // If the vector reduction can be performed in a smaller type, we truncate
+    // then extend the loop exit value to enable InstCombine to evaluate the
+    // entire expression in the smaller type.
+    //
+    // FIXME: This isn't an ideal solution. It would be preferable to generate
+    //        expressions in the smaller type on the fly as we vectorize the
+    //        loop. In the code below, the truncate/extend is a necessary cue
+    //        for InstCombiner::EvaluateInDifferentType().
+    if (VF > 1 && RdxPhi->getType() != RdxDesc.getRecurrenceType()) {
+      Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF);
+      Builder.SetInsertPoint(LoopVectorBody.back()->getTerminator());
+      for (unsigned part = 0; part < UF; ++part) {
+        Value *Trunc = Builder.CreateTrunc(RdxExitVal[part], RdxVecTy);
+        Value *Extnd = RdxDesc.isSigned()
+                           ? Builder.CreateSExt(Trunc, VecTy)
+                           : Builder.CreateZExt(Trunc, VecTy);
+        for (Value::user_iterator UI = RdxExitVal[part]->user_begin();
+             UI != RdxExitVal[part]->user_end();)
+          if (*UI != Trunc)
+            (*UI++)->replaceUsesOfWith(RdxExitVal[part], Extnd);
+          else
+            ++UI;
+      }
+      Builder.SetInsertPoint(LoopMiddleBlock->getFirstInsertionPt());
+      for (unsigned part = 0; part < UF; ++part)
+        RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy);
+    }
+
     // Reduce all of the unrolled parts into a single vector.
     Value *ReducedPartRdx = RdxParts[0];
     unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK);
@@ -3299,6 +3335,14 @@
       // The result is in the first element of the vector.
       ReducedPartRdx = Builder.CreateExtractElement(TmpVec,
                                                     Builder.getInt32(0));
+
+      // If the reduction can be performed in a smaller type, we need to extend
+      // the reduction to the wider type before we branch to the original loop.
+      if (RdxPhi->getType() != RdxDesc.getRecurrenceType())
+        ReducedPartRdx =
+            RdxDesc.isSigned()
+                ? Builder.CreateSExt(ReducedPartRdx, RdxPhi->getType())
+                : Builder.CreateZExt(ReducedPartRdx, RdxPhi->getType());
     }
 
     // Create a phi node that merges control-flow from the backedge-taken check
@@ -4652,18 +4696,22 @@
   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
     Type *T = it->getType();
 
-    // Ignore ephemeral values.
-    if (EphValues.count(it))
+    // Skip ignored values.
+    if (ValuesToIgnore.count(it))
       continue;
 
     // Only examine Loads, Stores and PHINodes.
     if (!isa<LoadInst>(it) && !isa<StoreInst>(it) && !isa<PHINode>(it))
      continue;

-    // Examine PHI nodes that are reduction variables.
-    if (PHINode *PN = dyn_cast<PHINode>(it))
+    // Examine PHI nodes that are reduction variables. Update the type to
+    // account for the recurrence type.
+    if (PHINode *PN = dyn_cast<PHINode>(it)) {
      if (!Legal->getReductionVars()->count(PN))
        continue;
+      RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[PN];
+      T = RdxDesc.getRecurrenceType();
+    }

    // Examine the stored values.
    if (StoreInst *ST = dyn_cast<StoreInst>(it))
@@ -4924,8 +4972,8 @@
      // Ignore instructions that are never used within the loop.
      if (!Ends.count(I))
        continue;

-      // Ignore ephemeral values.
-      if (EphValues.count(I))
+      // Skip ignored values.
+      if (ValuesToIgnore.count(I))
        continue;

      // Remove all of the instructions that end at this location.
@@ -4968,8 +5016,8 @@
      if (isa<DbgInfoIntrinsic>(it))
        continue;

-      // Ignore ephemeral values.
-      if (EphValues.count(it))
+      // Skip ignored values.
+      if (ValuesToIgnore.count(it))
        continue;

      unsigned C = getInstructionCost(it, VF);
Index: test/Transforms/LoopVectorize/reduction-small-size.ll
===================================================================
--- /dev/null
+++ test/Transforms/LoopVectorize/reduction-small-size.ll
@@ -0,0 +1,128 @@
+; RUN: opt < %s -loop-vectorize -force-vector-interleave=1 -dce -instcombine -S | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+; CHECK-LABEL: @reduction_i8
+;
+; char reduction_i8(char *a, char *b, int n) {
+;   char sum = 0;
+;   for (int i = 0; i < n; ++i)
+;     sum += (a[i] + b[i]);
+;   return sum;
+; }
+;
+; CHECK: vector.body:
+; CHECK: phi <16 x i8>
+; CHECK: load <16 x i8>
+; CHECK: load <16 x i8>
+; CHECK: add <16 x i8>
+; CHECK: add <16 x i8>
+;
+; CHECK: middle.block:
+; CHECK: shufflevector <16 x i8>
+; CHECK: add <16 x i8>
+; CHECK: shufflevector <16 x i8>
+; CHECK: add <16 x i8>
+; CHECK: shufflevector <16 x i8>
+; CHECK: add <16 x i8>
+; CHECK: shufflevector <16 x i8>
+; CHECK: add <16 x i8>
+; CHECK: [[Rdx:%[a-zA-Z0-9.]+]] = extractelement <16 x i8>
+; CHECK: zext i8 [[Rdx]] to i32
+;
+define i8 @reduction_i8(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %n) {
+entry:
+  %cmp.12 = icmp sgt i32 %n, 0
+  br i1 %cmp.12, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.for.cond.cleanup_crit_edge:
+  %add5.lcssa = phi i32 [ %add5, %for.body ]
+  %conv6 = trunc i32 %add5.lcssa to i8
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  %sum.0.lcssa = phi i8 [ %conv6, %for.cond.for.cond.cleanup_crit_edge ], [ 0, %entry ]
+  ret i8 %sum.0.lcssa
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %sum.013 = phi i32 [ %add5, %for.body ], [ 0, %for.body.preheader ]
+  %arrayidx = getelementptr inbounds i8, i8* %a, i64 %indvars.iv
+  %0 = load i8, i8* %arrayidx, align 1
+  %conv = zext i8 %0 to i32
+  %arrayidx2 = getelementptr inbounds i8, i8* %b, i64 %indvars.iv
+  %1 = load i8, i8* %arrayidx2, align 1
+  %conv3 = zext i8 %1 to i32
+  %conv4 = and i32 %sum.013, 255
+  %add = add nuw nsw i32 %conv, %conv4
+  %add5 = add nuw nsw i32 %add, %conv3
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %n
+  br i1 %exitcond, label %for.cond.for.cond.cleanup_crit_edge, label %for.body
+}
+
+; CHECK-LABEL: @reduction_i16
+;
+; short reduction_i16(short *a, short *b, int n) {
+;   short sum = 0;
+;   for (int i = 0; i < n; ++i)
+;     sum += (a[i] + b[i]);
+;   return sum;
+; }
+;
+; CHECK: vector.body:
+; CHECK: phi <8 x i16>
+; CHECK: load <8 x i16>
+; CHECK: load <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: add <8 x i16>
+;
+; CHECK: middle.block:
+; CHECK: shufflevector <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: shufflevector <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: shufflevector <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: [[Rdx:%[a-zA-Z0-9.]+]] = extractelement <8 x i16>
+; CHECK: zext i16 [[Rdx]] to i32
+;
+define i16 @reduction_i16(i16* nocapture readonly %a, i16* nocapture readonly %b, i32 %n) {
+entry:
+  %cmp.16 = icmp sgt i32 %n, 0
+  br i1 %cmp.16, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.for.cond.cleanup_crit_edge:
+  %add5.lcssa = phi i32 [ %add5, %for.body ]
+  %conv6 = trunc i32 %add5.lcssa to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  %sum.0.lcssa = phi i16 [ %conv6, %for.cond.for.cond.cleanup_crit_edge ], [ 0, %entry ]
+  ret i16 %sum.0.lcssa
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %sum.017 = phi i32 [ %add5, %for.body ], [ 0, %for.body.preheader ]
+  %arrayidx = getelementptr inbounds i16, i16* %a, i64 %indvars.iv
+  %0 = load i16, i16* %arrayidx, align 2
+  %conv.14 = zext i16 %0 to i32
+  %arrayidx2 = getelementptr inbounds i16, i16* %b, i64 %indvars.iv
+  %1 = load i16, i16* %arrayidx2, align 2
+  %conv3.15 = zext i16 %1 to i32
+  %conv4.13 = and i32 %sum.017, 65535
+  %add = add nuw nsw i32 %conv.14, %conv4.13
+  %add5 = add nuw nsw i32 %add, %conv3.15
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %n
+  br i1 %exitcond, label %for.cond.for.cond.cleanup_crit_edge, label %for.body
+}