Index: include/llvm/Transforms/Utils/LoopUtils.h =================================================================== --- include/llvm/Transforms/Utils/LoopUtils.h +++ include/llvm/Transforms/Utils/LoopUtils.h @@ -21,6 +21,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" @@ -172,15 +173,25 @@ Value *Left, Value *Right); /// Returns true if Phi is a reduction of type Kind and adds it to the - /// RecurrenceDescriptor. + /// RecurrenceDescriptor. If either \p DB is non-null or \p AC and \p DT are + /// non-null, the minimal bit width needed to compute the reduction will be + /// computed. static bool AddReductionVar(PHINode *Phi, RecurrenceKind Kind, Loop *TheLoop, bool HasFunNoNaNAttr, - RecurrenceDescriptor &RedDes); - - /// Returns true if Phi is a reduction in TheLoop. The RecurrenceDescriptor is - /// returned in RedDes. + RecurrenceDescriptor &RedDes, + DemandedBits *DB = nullptr, + AssumptionCache *AC = nullptr, + DominatorTree *DT = nullptr); + + /// Returns true if Phi is a reduction in TheLoop. The RecurrenceDescriptor + /// is returned in RedDes. If either \p DB is non-null or \p AC and \p DT are + /// non-null, the minimal bit width needed to compute the reduction will be + /// computed. static bool isReductionPHI(PHINode *Phi, Loop *TheLoop, - RecurrenceDescriptor &RedDes); + RecurrenceDescriptor &RedDes, + DemandedBits *DB = nullptr, + AssumptionCache *AC = nullptr, + DominatorTree *DT = nullptr); /// Returns true if Phi is a first-order recurrence. A first-order recurrence /// is a non-reduction recurrence relation in which the value of the @@ -218,24 +229,6 @@ /// Returns true if the recurrence kind is an arithmetic kind. 
static bool isArithmeticRecurrenceKind(RecurrenceKind Kind); - /// Determines if Phi may have been type-promoted. If Phi has a single user - /// that ANDs the Phi with a type mask, return the user. RT is updated to - /// account for the narrower bit width represented by the mask, and the AND - /// instruction is added to CI. - static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT, - SmallPtrSetImpl &Visited, - SmallPtrSetImpl &CI); - - /// Returns true if all the source operands of a recurrence are either - /// SExtInsts or ZExtInsts. This function is intended to be used with - /// lookThroughAnd to determine if the recurrence has been type-promoted. The - /// source operands are added to CI, and IsSigned is updated to indicate if - /// all source operands are SExtInsts. - static bool getSourceExtensionKind(Instruction *Start, Instruction *Exit, - Type *RT, bool &IsSigned, - SmallPtrSetImpl &Visited, - SmallPtrSetImpl &CI); - /// Returns the type of the recurrence. This type can be narrower than the /// actual type of the Phi if the recurrence has been type-promoted. 
Type *getRecurrenceType() { return RecurrenceType; } Index: lib/Transforms/Utils/LoopUtils.cpp =================================================================== --- lib/Transforms/Utils/LoopUtils.cpp +++ lib/Transforms/Utils/LoopUtils.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -30,6 +31,7 @@ #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -77,10 +79,13 @@ return false; } -Instruction * -RecurrenceDescriptor::lookThroughAnd(PHINode *Phi, Type *&RT, - SmallPtrSetImpl &Visited, - SmallPtrSetImpl &CI) { +/// Determines if Phi may have been type-promoted. If Phi has a single user +/// that ANDs the Phi with a type mask, return the user. RT is updated to +/// account for the narrower bit width represented by the mask, and the AND +/// instruction is added to CI. +static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT, + SmallPtrSetImpl &Visited, + SmallPtrSetImpl &CI) { if (!Phi->hasOneUse()) return Phi; @@ -101,70 +106,92 @@ return Phi; } -bool RecurrenceDescriptor::getSourceExtensionKind( - Instruction *Start, Instruction *Exit, Type *RT, bool &IsSigned, - SmallPtrSetImpl &Visited, - SmallPtrSetImpl &CI) { +/// Compute the minimal bit width needed to represent a reduction whose exit +/// instruction is given by Exit. 
+static std::pair computeRecurrenceType(Instruction *Exit, + DemandedBits *DB, + AssumptionCache *AC, + DominatorTree *DT) { + bool IsSigned = false; + const DataLayout &DL = Exit->getModule()->getDataLayout(); + uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType()); + + if (DB) { + // Use the demanded bits analysis to determine the bits that are live out + // of the exit instruction, rounding up to the nearest power of two. If the + // use of demanded bits results in a smaller bit width, we know the value + // must be positive (i.e., IsSigned = false), because if this were not the + // case, the sign bit would have been demanded. + auto Mask = DB->getDemandedBits(Exit); + MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); + } + + if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) { + // If demanded bits wasn't able to limit the bit width, we can try to use + // value tracking instead. This can be the case, for example, if the value + // may be negative. + auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT); + auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType()); + MaxBitWidth = NumTypeBits - NumSignBits; + KnownBits Bits = computeKnownBits(Exit, DL); + if (!Bits.isNonNegative()) { + // If the value is not known to be non-negative, we set IsSigned to true, + // meaning that we will use sext instructions instead of zext + // instructions to restore the original type. + IsSigned = true; + if (!Bits.isNegative()) + // If the value is not known to be negative, we don't know what the + // upper bit is, and therefore, we don't know what kind of extend we + // will need. In this case, just increase the bit width by one bit and + // use sext. 
+ ++MaxBitWidth; + } + } + if (!isPowerOf2_64(MaxBitWidth)) + MaxBitWidth = NextPowerOf2(MaxBitWidth); + + return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth), + IsSigned); +} + +/// Collect cast instructions that can be ignored in the vectorizer's cost +/// model, given a reduction exit value and the minimal type in which the +/// reduction can be represented. +static void collectCastsToIgnore(Loop *TheLoop, Instruction *Exit, + Type *RecurrenceType, + SmallPtrSetImpl &Casts) { SmallVector Worklist; - bool FoundOneOperand = false; - unsigned DstSize = RT->getPrimitiveSizeInBits(); + SmallPtrSet Visited; Worklist.push_back(Exit); - // Traverse the instructions in the reduction expression, beginning with the - // exit value. while (!Worklist.empty()) { - Instruction *I = Worklist.pop_back_val(); - for (Use &U : I->operands()) { - - // Terminate the traversal if the operand is not an instruction, or we - // reach the starting value. - Instruction *J = dyn_cast(U.get()); - if (!J || J == Start) - continue; - - // Otherwise, investigate the operation if it is also in the expression. - if (Visited.count(J)) { - Worklist.push_back(J); + Instruction *Val = Worklist.pop_back_val(); + Visited.insert(Val); + if (auto *Cast = dyn_cast(Val)) + if (Cast->getSrcTy() == RecurrenceType) { + // If the source type of a cast instruction is equal to the recurrence + // type, it will be eliminated, and should be ignored in the vectorizer + // cost model. + Casts.insert(Cast); continue; } - // If the operand is not in Visited, it is not a reduction operation, but - // it does feed into one. Make sure it is either a single-use sign- or - // zero-extend instruction. - CastInst *Cast = dyn_cast(J); - bool IsSExtInst = isa(J); - if (!Cast || !Cast->hasOneUse() || !(isa(J) || IsSExtInst)) - return false; - - // Ensure the source type of the extend is no larger than the reduction - // type. It is not necessary for the types to be identical. 
- unsigned SrcSize = Cast->getSrcTy()->getPrimitiveSizeInBits(); - if (SrcSize > DstSize) - return false; - - // Furthermore, ensure that all such extends are of the same kind. - if (FoundOneOperand) { - if (IsSigned != IsSExtInst) - return false; - } else { - FoundOneOperand = true; - IsSigned = IsSExtInst; - } - - // Lastly, if the source type of the extend matches the reduction type, - // add the extend to CI so that we can avoid accounting for it in the - // cost model. - if (SrcSize == DstSize) - CI.insert(Cast); - } + // Add all operands to the work list if they are loop-varying values that + // we haven't yet visited. + for (Value *O : cast(Val)->operands()) + if (auto *I = dyn_cast(O)) + if (TheLoop->contains(I) && !Visited.count(I)) + Worklist.push_back(I); } - return true; } bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind, Loop *TheLoop, bool HasFunNoNaNAttr, - RecurrenceDescriptor &RedDes) { + RecurrenceDescriptor &RedDes, + DemandedBits *DB, + AssumptionCache *AC, + DominatorTree *DT) { if (Phi->getNumIncomingValues() != 2) return false; @@ -353,14 +380,49 @@ if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction) return false; - // If we think Phi may have been type-promoted, we also need to ensure that - // all source operands of the reduction are either SExtInsts or ZEstInsts. If - // so, we will be able to evaluate the reduction in the narrower bit width. - if (Start != Phi) - if (!getSourceExtensionKind(Start, ExitInstruction, RecurrenceType, - IsSigned, VisitedInsts, CastInsts)) + if (Start != Phi) { + // If the starting value is not the same as the phi node, we speculatively + // looked through an 'and' instruction when evaluating a potential + // arithmetic reduction to determine if it may have been type-promoted. + // + // We now compute the minimal bit width that is required to represent the + // reduction. 
If this is the same width that was indicated by the 'and', we + // can represent the reduction in the smaller type. The 'and' instruction + // will be eliminated since it will essentially be a cast instruction that + // can be ignored in the cost model. If we compute a different type than we + // did when evaluating the 'and', the 'and' will not be eliminated, and we + // will end up with different kinds of operations in the recurrence + // expression (e.g., RK_IntegerAnd, RK_IntegerAdd). We give up if this is + // the case. + // + // The vectorizer relies on InstCombine to perform the actual + // type-shrinking. It does this by inserting instructions to truncate the + // exit value of the reduction to the width indicated by RecurrenceType and + // then extend this value back to the original width. If IsSigned is false, + // a 'zext' instruction will be generated; otherwise, a 'sext' will be + // used. + // + // TODO: We should not rely on InstCombine to rewrite the reduction in the + // smaller type. We should just generate a correctly typed expression + // to begin with. + Type *ComputedType; + std::tie(ComputedType, IsSigned) = + computeRecurrenceType(ExitInstruction, DB, AC, DT); + if (ComputedType != RecurrenceType) return false; + // The recurrence expression will be represented in a narrower type. If + // there are any cast instructions that will be unnecessary, collect them + // in CastInsts. Note that the 'and' instruction was already included in + // this list. + // + // TODO: A better way to represent this may be to tag in some way all the + // instructions that are a part of the reduction. The vectorizer cost + // model could then apply the recurrence type to these instructions, + // without needing a white list of instructions to ignore. + collectCastsToIgnore(TheLoop, ExitInstruction, RecurrenceType, CastInsts); + } + // We found a reduction var if we have reached the original phi node and we // only have a single instruction with out-of-loop users. 
@@ -480,47 +542,57 @@ return false; } bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, - RecurrenceDescriptor &RedDes) { + RecurrenceDescriptor &RedDes, + DemandedBits *DB, AssumptionCache *AC, + DominatorTree *DT) { BasicBlock *Header = TheLoop->getHeader(); Function &F = *Header->getParent(); bool HasFunNoNaNAttr = F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; - if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, - RedDes)) { + if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, RedDes, + DB, AC, DT)) { DEBUG(dbgs() << "Found a MINMAX reduction PHI." 
<< *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes)) { + if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { DEBUG(dbgs() << "Found an float MINMAX reduction PHI." << *Phi << "\n"); return true; } Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1539,9 +1539,10 @@ const TargetTransformInfo *TTI, std::function *GetLAA, LoopInfo *LI, OptimizationRemarkEmitter *ORE, LoopVectorizationRequirements *R, - LoopVectorizeHints *H) + LoopVectorizeHints *H, DemandedBits *DB, AssumptionCache *AC) : TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT), GetLAA(GetLAA), - ORE(ORE), InterleaveInfo(PSE, L, DT, LI), Requirements(R), Hints(H) {} + ORE(ORE), InterleaveInfo(PSE, L, DT, LI), Requirements(R), Hints(H), + DB(DB), AC(AC) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1830,6 +1831,14 @@ /// Used to emit an analysis of any legality issues. LoopVectorizeHints *Hints; + /// The demanded bits analysis is used to compute the minimum type size in + /// which a reduction can be computed. + DemandedBits *DB; + + /// The assumption cache analysis is used to compute the minimum type size in + /// which a reduction can be computed. 
+ AssumptionCache *AC; + /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic SmallPtrSet MaskedOp; @@ -5105,7 +5114,8 @@ } RecurrenceDescriptor RedDes; - if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, RedDes)) { + if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, RedDes, DB, AC, + DT)) { if (RedDes.hasUnsafeAlgebra()) Requirements->addUnsafeAlgebraInst(RedDes.getUnsafeAlgebraInst()); AllowedExit.insert(RedDes.getLoopExitInstr()); @@ -8323,7 +8333,7 @@ // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements(*ORE); LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, GetLAA, LI, ORE, - &Requirements, &Hints); + &Requirements, &Hints, DB, AC); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints, ORE); Index: test/Transforms/LoopVectorize/reduction-small-size.ll =================================================================== --- test/Transforms/LoopVectorize/reduction-small-size.ll +++ test/Transforms/LoopVectorize/reduction-small-size.ll @@ -14,7 +14,7 @@ ; CHECK-NEXT: [[TMP17]] = zext <4 x i8> [[TMP16]] to <4 x i32> ; CHECK-NEXT: br i1 {{.*}}, label %middle.block, label %vector.body ; -define void @PR34687(i1 %c, i32 %x, i32 %n) { +define i8 @PR34687(i1 %c, i32 %x, i32 %n) { entry: br label %for.body @@ -36,5 +36,38 @@ for.end: %tmp2 = phi i32 [ %r.next, %if.end ] - ret void + %tmp3 = trunc i32 %tmp2 to i8 + ret i8 %tmp3 +} + +; CHECK-LABEL: @PR35734( +; CHECK: vector.ph: +; CHECK: [[TMP3:%.*]] = insertelement <4 x i32> zeroinitializer, i32 %y, i32 0 +; CHECK-NEXT: br label %vector.body +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %vector.ph ], [ [[INDEX_NEXT:%.*]], %vector.body ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ [[TMP3]], %vector.ph ], [ [[TMP9:%.*]], %vector.body ] +; CHECK: [[TMP5:%.*]] = and <4 x i32> [[VEC_PHI]], +; CHECK-NEXT: [[TMP6:%.*]] = 
add <4 x i32> [[TMP5]], +; CHECK-NEXT: [[INDEX_NEXT]] = add i32 [[INDEX]], 4 +; CHECK: [[TMP8:%.*]] = trunc <4 x i32> [[TMP6]] to <4 x i1> +; CHECK-NEXT: [[TMP9]] = sext <4 x i1> [[TMP8]] to <4 x i32> +; CHECK-NEXT: br i1 {{.*}}, label %middle.block, label %vector.body +; +define i32 @PR35734(i32 %x, i32 %y) { +entry: + br label %for.body + +for.body: + %i = phi i32 [ %x, %entry ], [ %i.next, %for.body ] + %r = phi i32 [ %y, %entry ], [ %r.next, %for.body ] + %tmp0 = and i32 %r, 1 + %r.next = add i32 %tmp0, -1 + %i.next = add nsw i32 %i, 1 + %cond = icmp sgt i32 %i, 77 + br i1 %cond, label %for.end, label %for.body + +for.end: + %tmp1 = phi i32 [ %r.next, %for.body ] + ret i32 %tmp1 }