Index: include/llvm/Analysis/DemandedBits.h =================================================================== --- include/llvm/Analysis/DemandedBits.h +++ include/llvm/Analysis/DemandedBits.h @@ -44,6 +44,9 @@ /// Return the bits demanded from instruction I. APInt getDemandedBits(Instruction *I); + /// Return the number of sign bits in I as deduced by DemandedBits. + unsigned getNumSignBits(Instruction *I); + /// Return true if, during analysis, I could not be reached. bool isInstructionDead(Instruction *I); @@ -53,19 +56,20 @@ Function &F; AssumptionCache &AC; DominatorTree &DT; - + bool Analyzed; + void performAnalysis(); void determineLiveOperandBits(const Instruction *UserI, - const Instruction *I, unsigned OperandNo, - const APInt &AOut, APInt &AB, - APInt &KnownZero, APInt &KnownOne, - APInt &KnownZero2, APInt &KnownOne2); - - bool Analyzed; + const Instruction *I, unsigned OperandNo, + const APInt &AOut, int SOut, + APInt &AB, int &SB, + APInt &KnownZero, APInt &KnownOne, + APInt &KnownZero2, APInt &KnownOne2); // The set of visited instructions (non-integer-typed only). SmallPtrSet Visited; DenseMap AliveBits; + DenseMap SignBits; }; class DemandedBitsWrapperPass : public FunctionPass { Index: lib/Analysis/DemandedBits.cpp =================================================================== --- lib/Analysis/DemandedBits.cpp +++ lib/Analysis/DemandedBits.cpp @@ -72,8 +72,8 @@ void DemandedBits::determineLiveOperandBits( const Instruction *UserI, const Instruction *I, unsigned OperandNo, - const APInt &AOut, APInt &AB, APInt &KnownZero, APInt &KnownOne, - APInt &KnownZero2, APInt &KnownOne2) { + const APInt &AOut, int SOut, APInt &AB, int &SB, + APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2) { unsigned BitWidth = AB.getBitWidth(); // We're called once per operand, but for some instructions, we need to @@ -109,6 +109,7 @@ // The alive bits of the input are the swapped alive bits of // the output. AB = AOut.byteSwap(); + SB = 0; break; case Intrinsic::ctlz: if (OperandNo == 0) { @@ -118,6 +119,7 @@ ComputeKnownBits(BitWidth, I, nullptr); AB = APInt::getHighBitsSet(BitWidth, std::min(BitWidth, KnownOne.countLeadingZeros()+1)); + SB = 0; } break; case Intrinsic::cttz: @@ -128,6 +130,7 @@ ComputeKnownBits(BitWidth, I, nullptr); AB = APInt::getLowBitsSet(BitWidth, std::min(BitWidth, KnownOne.countTrailingZeros()+1)); + SB = 0; } break; } @@ -139,6 +142,7 @@ // bits than that (adds, and thus subtracts, ripple only to the // left). AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits()); + SB = 0; break; case Instruction::Shl: if (OperandNo == 0) @@ -146,6 +150,7 @@ dyn_cast(UserI->getOperand(1))) { uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1); AB = AOut.lshr(ShiftAmt); + SB = 0; // The sign bit has changed. // If the shift is nuw/nsw, then the high bits are not dead // (because we've promised that they *must* be zero). @@ -162,6 +167,10 @@ dyn_cast(UserI->getOperand(1))) { uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1); AB = AOut.shl(ShiftAmt); + // FIXME: LShr shifts in zero bits. If the MSB was originally + // a zero, it is shifting in more sign bits and we can be + // less conservative below. + SB = 0; // If the shift is exact, then the low bits are not dead // (they must be zero). @@ -175,6 +184,8 @@ dyn_cast(UserI->getOperand(1))) { uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1); AB = AOut.shl(ShiftAmt); + SB = std::max(0, SOut - (int)ShiftAmt); + // Because the high input bit is replicated into the // high-order bits of the result, if we need any of those // bits, then we must keep the highest input bit. @@ -190,7 +201,8 @@ break; case Instruction::And: AB = AOut; - + SB = 0; + // For bits that are known zero, the corresponding bits in the // other operand are dead (unless they're both zero, in which // case they can't both be dead, so just mark the LHS bits as @@ -206,6 +218,7 @@ break; case Instruction::Or: AB = AOut; + SB = 0; // For bits that are known one, the corresponding bits in the // other operand are dead (unless they're both one, in which @@ -221,17 +234,42 @@ } break; case Instruction::Xor: + AB = AOut; + SB = 0; + + if (OperandNo == 0 && isa(UserI->getOperand(1))) { + auto *CI = cast(UserI->getOperand(1)); + auto &C = CI->getValue(); + + // If the constant to xor with starts with a zero, the number + // of sign bits is truncated to the number of leading zeroes + // because the sign bit is not changed and neither is any bit + // xor'd with zero. + // + // If it starts with a one, the sign bit is flipped. Therefore + // the number of sign bits is truncated to the number of leading + // *ones*. + auto TruncateTo = + std::max(C.countLeadingZeros(), C.countLeadingOnes()); + SB = std::min(SOut, (int)TruncateTo); + } + + break; case Instruction::PHI: AB = AOut; + SB = SOut; break; case Instruction::Trunc: AB = AOut.zext(BitWidth); + SB = 0; break; case Instruction::ZExt: AB = AOut.trunc(BitWidth); + SB = 0; break; case Instruction::SExt: AB = AOut.trunc(BitWidth); + SB = 0; // Because the high input bit is replicated into the // high-order bits of the result, if we need any of those // bits, then we must keep the highest input bit. @@ -241,8 +279,20 @@ AB.setBit(BitWidth-1); break; case Instruction::Select: - if (OperandNo != 0) + if (OperandNo != 0) { AB = AOut; + SB = SOut; + } + break; + case Instruction::ICmp: + // Count the number of leading zeroes in each operand. + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + auto NumLeadingZeroes = std::min(KnownZero.countLeadingOnes(), + KnownZero2.countLeadingOnes()); + auto NumLeadingOnes = std::min(KnownOne.countLeadingOnes(), + KnownOne2.countLeadingOnes()); + auto NumSignBits = std::max(NumLeadingZeroes, NumLeadingOnes); + SB = NumSignBits; break; } } @@ -266,6 +316,7 @@ Visited.clear(); AliveBits.clear(); + SignBits.clear(); SmallVector Worklist; @@ -281,7 +332,11 @@ // all bits as live). if (IntegerType *IT = dyn_cast(I.getType())) { if (!AliveBits.count(&I)) { + const DataLayout &DL = I.getModule()->getDataLayout(); + auto NSB = ComputeNumSignBits(&I, DL); + AliveBits[&I] = APInt(IT->getBitWidth(), 0); + SignBits[&I] = NSB; Worklist.push_back(&I); } @@ -291,8 +346,10 @@ // Non-integer-typed instructions... for (Use &OI : I.operands()) { if (Instruction *J = dyn_cast(OI)) { - if (IntegerType *IT = dyn_cast(J->getType())) + if (IntegerType *IT = dyn_cast(J->getType())) { AliveBits[J] = APInt::getAllOnesValue(IT->getBitWidth()); + SignBits[J] = 0; + } Worklist.push_back(J); } } @@ -308,9 +365,12 @@ DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI); APInt AOut; + int SOut; if (UserI->getType()->isIntegerTy()) { AOut = AliveBits[UserI]; + SOut = SignBits[UserI]; DEBUG(dbgs() << " Alive Out: " << AOut); + DEBUG(dbgs() << " Sign Out: " << SOut); } DEBUG(dbgs() << "\n"); @@ -326,6 +386,7 @@ if (IntegerType *IT = dyn_cast(I->getType())) { unsigned BitWidth = IT->getBitWidth(); APInt AB = APInt::getAllOnesValue(BitWidth); + int SB = 0; if (UserI->getType()->isIntegerTy() && !AOut && !isAlwaysLive(UserI)) { AB = APInt(BitWidth, 0); @@ -333,7 +394,7 @@ // If all bits of the output are dead, then all bits of the input // Bits of each operand that are used to compute alive bits of the // output are alive, all others are dead. - determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB, + determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, SOut, AB, SB, KnownZero, KnownOne, KnownZero2, KnownOne2); } @@ -346,10 +407,24 @@ if (ABI != AliveBits.end()) ABPrev = ABI->second; + bool AddedToWorklist = false; APInt ABNew = AB | ABPrev; if (ABNew != ABPrev || ABI == AliveBits.end()) { AliveBits[I] = std::move(ABNew); Worklist.push_back(I); + AddedToWorklist = true; + } + + APInt SPrev(BitWidth, 0); + auto SI = SignBits.find(I); + if (SI != SignBits.end()) + SPrev = SI->second; + + auto SNew = SB; + if (SNew != SPrev || SI == SignBits.end()) { + SignBits[I] = std::move(SNew); + if (!AddedToWorklist) + Worklist.push_back(I); } } else if (!Visited.count(I)) { Worklist.push_back(I); @@ -368,6 +443,14 @@ return APInt::getAllOnesValue(DL.getTypeSizeInBits(I->getType())); } +unsigned DemandedBits::getNumSignBits(Instruction *I) { + performAnalysis(); + + if (SignBits.count(I)) + return SignBits[I]; + return 0; +} + bool DemandedBits::isInstructionDead(Instruction *I) { performAnalysis(); @@ -378,8 +461,10 @@ void DemandedBits::print(raw_ostream &OS) { performAnalysis(); for (auto &KV : AliveBits) { - OS << "DemandedBits: 0x" << utohexstr(KV.second.getLimitedValue()) << " for " - << *KV.first << "\n"; + auto SB = SignBits.find(KV.first)->second; + auto DB = KV.second; + OS << "DemandedBits: 0x" << utohexstr(DB.getLimitedValue()) << " for " + << *KV.first << " (SignBits=" << SB << ")\n"; } } Index: lib/Analysis/VectorUtils.cpp =================================================================== --- lib/Analysis/VectorUtils.cpp +++ lib/Analysis/VectorUtils.cpp @@ -363,12 +363,14 @@ // If we encounter a type that is larger than 64 bits, we can't represent // it so bail out. - if (DB.getDemandedBits(I).getBitWidth() > 64) + unsigned BW = DB.getDemandedBits(I).getBitWidth(); + if (BW > 64) return MapVector(); - uint64_t V = DB.getDemandedBits(I).getZExtValue(); - DBits[Leader] |= V; - DBits[I] = V; + APInt V1 = APInt::getLowBitsSet(BW, BW - DB.getNumSignBits(I)); + APInt V2 = DB.getDemandedBits(I); + DBits[Leader] |= (V2 & V1).getZExtValue(); + DBits[I] |= (V2 & V1).getZExtValue(); // Casts, loads and instructions outside of our range terminate a chain // successfully. Index: test/Analysis/DemandedBits/basic.ll =================================================================== --- test/Analysis/DemandedBits/basic.ll +++ test/Analysis/DemandedBits/basic.ll @@ -10,3 +10,39 @@ %3 = trunc i32 %2 to i8 ret i8 %3 } + +; CHECK-DAG: DemandedBits: 0x1 for %3 = icmp eq i32 %1, %2 +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %1 = and i32 %a, 255 (SignBits=0) +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %2 = shl i32 %1, 4 (SignBits=20) +define i1 @test_icmp1(i32 %a, i32 %b) { + %1 = and i32 %a, 255 + %2 = shl i32 %1, 4 + %3 = icmp eq i32 %1, %2 + ret i1 %3 +} + +; CHECK-DAG: DemandedBits: 0x1 for %3 = icmp eq i32 %1, %2 +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %1 = and i32 %a, 255 (SignBits=20) +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %2 = ashr i32 %1, 4 (SignBits=24) +define i1 @test_icmp2(i32 %a, i32 %b) { + %1 = and i32 %a, 255 + %2 = ashr i32 %1, 4 + %3 = icmp eq i32 %1, %2 + ret i1 %3 +} + +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %1 = and i32 %a, 255 (SignBits=0) +; CHECK-DAG: DemandedBits: 0x1 for %2 = icmp eq i32 -1, %1 +define i1 @test_icmp3(i32 %a) { + %1 = and i32 %a, 255 + %2 = icmp eq i32 -1, %1 + ret i1 %2 +} + +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %1 = or i32 %a, -1073741824 (SignBits=2) +; CHECK-DAG: DemandedBits: 0x1 for %2 = icmp eq i32 -1, %1 +define i1 @test_icmp4(i32 %a) { + %1 = or i32 %a, 3221225472 + %2 = icmp eq i32 -1, %1 + ret i1 %2 +} Index: test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll =================================================================== --- test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll +++ test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll @@ -263,5 +263,40 @@ br i1 %exitcond, label %for.cond.cleanup, label %for.body } +; CHECK-LABEL: @add_g +; CHECK: load <16 x i8> +; CHECK: xor <16 x i8> +; CHECK: icmp ult <16 x i8> +; CHECK: select <16 x i1> {{.*}}, <16 x i8> +; CHECK: store <16 x i8> +define void @add_g(i8* noalias nocapture readonly %p, i8* noalias nocapture readonly %q, i8* noalias nocapture %r, i8 %arg1, i32 %len) #0 { + %1 = icmp sgt i32 %len, 0 + br i1 %1, label %.lr.ph, label %._crit_edge + +.lr.ph: ; preds = %0 + %2 = sext i8 %arg1 to i64 + br label %3 + +._crit_edge: ; preds = %3, %0 + ret void + +;