Index: include/llvm/Analysis/DemandedBits.h
===================================================================
--- include/llvm/Analysis/DemandedBits.h
+++ include/llvm/Analysis/DemandedBits.h
@@ -41,6 +41,8 @@
   DemandedBits(Function &F, AssumptionCache &AC, DominatorTree &DT) :
     F(F), AC(AC), DT(DT), Analyzed(false) {}
 
+  virtual ~DemandedBits() {}
+
   /// Return the bits demanded from instruction I.
   APInt getDemandedBits(Instruction *I);
 
@@ -49,13 +51,13 @@
 
   void print(raw_ostream &OS);
 
-private:
+protected:
   Function &F;
   AssumptionCache &AC;
   DominatorTree &DT;
 
   void performAnalysis();
-  void determineLiveOperandBits(const Instruction *UserI,
+  virtual void determineLiveOperandBits(const Instruction *UserI,
     const Instruction *I, unsigned OperandNo,
     const APInt &AOut, APInt &AB,
     APInt &KnownZero, APInt &KnownOne,
@@ -68,6 +70,34 @@
   DenseMap AliveBits;
 };
 
+/// A version of DemandedBits that can be fed facts about edges in
+/// the use-def graph that it may not be able to compute itself.
+class DemandedBitsWithAssumptions : public DemandedBits {
+public:
+  /// A fact about an edge in the use-def graph. Read as "<User> demands
+  /// <DemandedBits> from <Operand>".
+  struct EdgeAssumption {
+    Value *User, *Operand;
+    APInt DemandedBits;
+  };
+
+  /// Create a new DemandedBits implementation based upon \c DB, but with
+  /// facts from \c Assumptions.
+  DemandedBitsWithAssumptions(DemandedBits &DB,
+                              ArrayRef<EdgeAssumption> Assumptions);
+
+private:
+  SmallVector<EdgeAssumption, 4> Assumptions;
+
+  /// If an assumption is recorded for the (\p UserI, \p I) edge, report its
+  /// DemandedBits in \p AB; otherwise defer to DemandedBits.
+  void determineLiveOperandBits(const Instruction *UserI,
+                                const Instruction *I, unsigned OperandNo,
+                                const APInt &AOut, APInt &AB,
+                                APInt &KnownZero, APInt &KnownOne,
+                                APInt &KnownZero2, APInt &KnownOne2) override;
+};
+
 class DemandedBitsWrapperPass : public FunctionPass {
 private:
   mutable Optional DB;
Index: lib/Analysis/DemandedBits.cpp
===================================================================
--- lib/Analysis/DemandedBits.cpp
+++ lib/Analysis/DemandedBits.cpp
@@ -383,6 +383,32 @@
   }
 }
 
+DemandedBitsWithAssumptions::DemandedBitsWithAssumptions(
+    DemandedBits &DB, ArrayRef<EdgeAssumption> Assumptions_)
+    : DemandedBits(DB),
+      Assumptions(Assumptions_.begin(), Assumptions_.end()) {
+  // Always recompute the DB graph; anything already analyzed in the copy of
+  // DB does not take Assumptions into account.
+  Analyzed = false;
+}
+
+void DemandedBitsWithAssumptions::determineLiveOperandBits(
+    const Instruction *UserI, const Instruction *I, unsigned OperandNo,
+    const APInt &AOut, APInt &AB, APInt &KnownZero, APInt &KnownOne,
+    APInt &KnownZero2, APInt &KnownOne2) {
+  // An assumption for this exact use-def edge replaces the computed answer.
+  for (const EdgeAssumption &A : Assumptions) {
+    if (A.User == UserI && A.Operand == I) {
+      AB = A.DemandedBits;
+      return;
+    }
+  }
+
+  DemandedBits::determineLiveOperandBits(UserI, I, OperandNo, AOut, AB,
+                                         KnownZero, KnownOne, KnownZero2,
+                                         KnownOne2);
+}
+
 FunctionPass *llvm::createDemandedBitsWrapperPass() {
   return new DemandedBitsWrapperPass();
 }
Index: lib/Analysis/VectorUtils.cpp
===================================================================
--- lib/Analysis/VectorUtils.cpp
+++ lib/Analysis/VectorUtils.cpp
@@ -319,7 +319,11 @@
   DenseMap DBits;
   SmallPtrSet InstructionSet;
   MapVector MinBWs;
+  SmallVector Assumptions;
 
+  assert(!Blocks.empty() && "Need at least one block!");
+  const DataLayout &DL = Blocks[0]->getParent()->getParent()->getDataLayout();
+
   // Determine the roots. We work bottom-up, from truncs or icmps.
bool SeenExtFromIllegalType = false; for (auto *BB : Blocks) @@ -341,12 +345,59 @@ Worklist.push_back(&I); Roots.insert(&I); + + // DemandedBits can never truncate ICmps on its own; its model + // is that any bit not demanded becomes undefined. Some ICmps + // have unneeded high bits, but those bits can only be + // removed, not arbitrarily modified. + // + // ICmps where both operands have more than one leading sign + // bit (for signed comparisons) or at least one leading zero + // (for unsigned comparisons) can be executed on truncated + // types. + // + // We identify these here and pass the information to DemandedBits + // using the helper class DemandedBitsWithAssumptions. + if (auto *ICI = dyn_cast(&I)) { + if (isa(ICI->getOperand(0)->getType())) { + unsigned NSB0 = ComputeNumSignBits(ICI->getOperand(0), DL); + unsigned NSB1 = ComputeNumSignBits(ICI->getOperand(1), DL); + unsigned NSB = std::min(NSB0, NSB1); + + bool CanTruncate = true; + if (ICI->isUnsigned()) { + // We know how many sign bits there are. But do we know that the + // sign bit is zero? If it's not, then an unsigned comparison + // cannot be truncated. + bool KnownZero0 = false, KnownZero1 = false, KnownOne = false; + ComputeSignBit(ICI->getOperand(0), KnownZero0, KnownOne, DL); + ComputeSignBit(ICI->getOperand(1), KnownZero1, KnownOne, DL); + + CanTruncate = KnownZero0 && KnownZero1; + } else { + // Make sure at least one sign bit remains after the truncation! + --NSB; + } + + if (CanTruncate) { + auto BW = ICI->getOperand(0)->getType()->getIntegerBitWidth(); + APInt B = APInt::getLowBitsSet(BW, BW - NSB); + + Assumptions.push_back({ICI, ICI->getOperand(0), B}); + Assumptions.push_back({ICI, ICI->getOperand(1), B}); + } + } + } + } } + // Early exit. if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) return MinBWs; + DemandedBitsWithAssumptions DBA(DB, Assumptions); + // Now proceed breadth-first, unioning values together. 
while (!Worklist.empty()) { Value *Val = Worklist.pop_back_val(); @@ -363,10 +414,10 @@ // If we encounter a type that is larger than 64 bits, we can't represent // it so bail out. - if (DB.getDemandedBits(I).getBitWidth() > 64) + if (DBA.getDemandedBits(I).getBitWidth() > 64) return MapVector(); - uint64_t V = DB.getDemandedBits(I).getZExtValue(); + uint64_t V = DBA.getDemandedBits(I).getZExtValue(); DBits[Leader] |= V; DBits[I] = V; Index: test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll =================================================================== --- test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll +++ test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll @@ -263,5 +263,41 @@ br i1 %exitcond, label %for.cond.cleanup, label %for.body } +; CHECK-LABEL: @add_g +; CHECK: load <16 x i8> +; CHECK: xor <16 x i8> +; CHECK: icmp ult <16 x i8> +; CHECK: select <16 x i1> {{.*}}, <16 x i8> +; CHECK: store <16 x i8> +define void @add_g(i8* noalias nocapture readonly %p, i8* noalias nocapture readonly %q, i8* noalias nocapture +%r, i8 %arg1, i32 %len) #0 { + %1 = icmp sgt i32 %len, 0 + br i1 %1, label %.lr.ph, label %._crit_edge + +.lr.ph: ; preds = %0 + %2 = sext i8 %arg1 to i64 + br label %3 + +._crit_edge: ; preds = %3, %0 + ret void + +;