Index: include/llvm/Analysis/DemandedBits.h =================================================================== --- include/llvm/Analysis/DemandedBits.h +++ include/llvm/Analysis/DemandedBits.h @@ -46,6 +46,21 @@ /// Return the bits demanded from instruction I. APInt getDemandedBits(Instruction *I); + /// Return the minimum bitwidth for instruction I as deduced by DemandedBits. + /// + /// A bit is not demanded if it can be either zero or one without + /// affecting the result; this query tightens that definition. + /// Unlike getDemandedBits, this assumes that any non-demanded uppermost bits + /// will NOT be changed. + /// + /// An area where this is useful is where a bit is part of a zero or sign + /// extension. The bit is not significant in the result as long as it is not + /// changed. + /// + /// The most obvious use of this information is to truncate an instruction to + /// a smaller bitwidth. + unsigned getMinimumBitWidth(Instruction *I); + /// Return true if, during analysis, I could not be reached. bool isInstructionDead(Instruction *I); @@ -53,7 +68,8 @@ void performAnalysis(); void determineLiveOperandBits(const Instruction *UserI, const Instruction *I, unsigned OperandNo, - const APInt &AOut, APInt &AB, + const APInt &AOut, const APInt &MKOut, + APInt &AB, APInt &MKB, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2); @@ -65,6 +81,17 @@ // The set of visited instructions (non-integer-typed only). SmallPtrSet Visited; DenseMap AliveBits; + /// For each instruction, the bits that must not be modified (are not + /// don't-care). + /// + /// These are a subset of the non-demanded bits + /// ((AliveBits & MustKeepBits) == 0) that do not contribute to the + /// instruction's result as-is, but if they were modified they would + /// change the result.
A common reason for this is that they are + /// statically known to be either zero or the same as the sign bit - + /// therefore they can be removed without changing the instruction's + /// result but not modified. See also getMinimumBitWidth(). + DenseMap MustKeepBits; }; /// Create a demanded bits analysis pass. Index: lib/Analysis/DemandedBits.cpp =================================================================== --- lib/Analysis/DemandedBits.cpp +++ lib/Analysis/DemandedBits.cpp @@ -69,8 +69,8 @@ void DemandedBits::determineLiveOperandBits( const Instruction *UserI, const Instruction *I, unsigned OperandNo, - const APInt &AOut, APInt &AB, APInt &KnownZero, APInt &KnownOne, - APInt &KnownZero2, APInt &KnownOne2) { + const APInt &AOut, const APInt &MKOut, APInt &AB, APInt &MKB, + APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2) { unsigned BitWidth = AB.getBitWidth(); // We're called once per operand, but for some instructions, we need to @@ -106,6 +106,7 @@ // The alive bits of the input are the swapped alive bits of // the output. AB = AOut.byteSwap(); + MKB = MKOut.byteSwap(); break; case Intrinsic::ctlz: if (OperandNo == 0) { @@ -115,6 +116,7 @@ ComputeKnownBits(BitWidth, I, nullptr); AB = APInt::getHighBitsSet(BitWidth, std::min(BitWidth, KnownOne.countLeadingZeros()+1)); + MKB = 0; } break; case Intrinsic::cttz: @@ -125,6 +127,7 @@ ComputeKnownBits(BitWidth, I, nullptr); AB = APInt::getLowBitsSet(BitWidth, std::min(BitWidth, KnownOne.countTrailingZeros()+1)); + MKB = 0; } break; } @@ -136,6 +139,7 @@ // bits than that (adds, and thus subtracts, ripple only to the // left). 
AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits()); + MKB = MKOut & ~AB; break; case Instruction::Shl: if (OperandNo == 0) @@ -143,6 +147,7 @@ dyn_cast(UserI->getOperand(1))) { uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1); AB = AOut.lshr(ShiftAmt); + MKB = MKOut.lshr(ShiftAmt); // If the shift is nuw/nsw, then the high bits are not dead // (because we've promised that they *must* be zero). @@ -159,7 +164,8 @@ dyn_cast(UserI->getOperand(1))) { uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1); AB = AOut.shl(ShiftAmt); - + MKB = MKOut.shl(ShiftAmt); + // If the shift is exact, then the low bits are not dead // (they must be zero). if (cast(UserI)->isExact()) @@ -172,6 +178,8 @@ dyn_cast(UserI->getOperand(1))) { uint64_t ShiftAmt = CI->getLimitedValue(BitWidth-1); AB = AOut.shl(ShiftAmt); + MKB = MKOut.shl(ShiftAmt); + // Because the high input bit is replicated into the // high-order bits of the result, if we need any of those // bits, then we must keep the highest input bit. 
@@ -187,7 +195,8 @@ break; case Instruction::And: AB = AOut; - + MKB = MKOut; + // For bits that are known zero, the corresponding bits in the // other operand are dead (unless they're both zero, in which // case they can't both be dead, so just mark the LHS bits as @@ -195,14 +204,17 @@ if (OperandNo == 0) { ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); AB &= ~KnownZero2; + MKB &= ~KnownZero2; } else { if (!isa(UserI->getOperand(0))) ComputeKnownBits(BitWidth, UserI->getOperand(0), I); AB &= ~(KnownZero & ~KnownZero2); + MKB &= ~(KnownZero & ~KnownZero2); } break; case Instruction::Or: AB = AOut; + MKB = MKOut; // For bits that are known one, the corresponding bits in the // other operand are dead (unless they're both one, in which @@ -211,24 +223,30 @@ if (OperandNo == 0) { ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); AB &= ~KnownOne2; + MKB &= ~KnownOne2; } else { if (!isa(UserI->getOperand(0))) ComputeKnownBits(BitWidth, UserI->getOperand(0), I); AB &= ~(KnownOne & ~KnownOne2); + MKB &= ~(KnownOne & ~KnownOne2); } break; case Instruction::Xor: case Instruction::PHI: AB = AOut; + MKB = MKOut; break; case Instruction::Trunc: AB = AOut.zext(BitWidth); + MKB = MKOut.zext(BitWidth); break; case Instruction::ZExt: AB = AOut.trunc(BitWidth); + MKB = MKOut.trunc(BitWidth); break; case Instruction::SExt: AB = AOut.trunc(BitWidth); + MKB = MKOut.trunc(BitWidth); // Because the high input bit is replicated into the // high-order bits of the result, if we need any of those // bits, then we must keep the highest input bit. @@ -238,8 +256,23 @@ AB.setBit(BitWidth-1); break; case Instruction::Select: - if (OperandNo != 0) + if (OperandNo != 0) { AB = AOut; + MKB = MKOut; + } + break; + case Instruction::ICmp: { + // Count the number of leading zeroes in each operand.
+ ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + unsigned NumLeadingZeroes = std::min(KnownZero.countLeadingOnes(), + KnownZero2.countLeadingOnes()); + // For a signed predicate one known-zero bit must stay alive to act as the + // sign bit of any narrower comparison, otherwise truncation flips it. + if (cast<ICmpInst>(UserI)->isSigned() && NumLeadingZeroes > 0) + --NumLeadingZeroes; + AB = ~APInt::getHighBitsSet(BitWidth, NumLeadingZeroes); + MKB = ~AB; break; + } } } @@ -260,6 +293,7 @@ Visited.clear(); AliveBits.clear(); + MustKeepBits.clear(); SmallVector Worklist; @@ -276,6 +310,7 @@ if (IntegerType *IT = dyn_cast(I.getType())) { if (!AliveBits.count(&I)) { AliveBits[&I] = APInt(IT->getBitWidth(), 0); + MustKeepBits[&I] = APInt(IT->getBitWidth(), 0); Worklist.push_back(&I); } @@ -285,8 +320,10 @@ // Non-integer-typed instructions... for (Use &OI : I.operands()) { if (Instruction *J = dyn_cast(OI)) { - if (IntegerType *IT = dyn_cast(J->getType())) + if (IntegerType *IT = dyn_cast(J->getType())) { AliveBits[J] = APInt::getAllOnesValue(IT->getBitWidth()); + MustKeepBits[J] = APInt::getAllOnesValue(IT->getBitWidth()); + } Worklist.push_back(J); } } @@ -301,10 +338,12 @@ Instruction *UserI = Worklist.pop_back_val(); DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI); - APInt AOut; + APInt AOut, MKOut; if (UserI->getType()->isIntegerTy()) { AOut = AliveBits[UserI]; + MKOut = MustKeepBits[UserI]; DEBUG(dbgs() << " Alive Out: " << AOut); + DEBUG(dbgs() << " Must Keep Out: " << MKOut); } DEBUG(dbgs() << "\n"); @@ -320,6 +359,7 @@ if (IntegerType *IT = dyn_cast(I->getType())) { unsigned BitWidth = IT->getBitWidth(); APInt AB = APInt::getAllOnesValue(BitWidth); + APInt MKB = APInt(BitWidth, 0); if (UserI->getType()->isIntegerTy() && !AOut && !isAlwaysLive(UserI)) { AB = APInt(BitWidth, 0); @@ -327,7 +367,7 @@ // If all bits of the output are dead, then all bits of the input // Bits of each operand that are used to compute alive bits of the // output are alive, all others are dead.
- determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB, + determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, MKOut, AB, MKB, KnownZero, KnownOne, KnownZero2, KnownOne2); } @@ -340,10 +375,24 @@ if (ABI != AliveBits.end()) ABPrev = ABI->second; + bool AddedToWorklist = false; APInt ABNew = AB | ABPrev; if (ABNew != ABPrev || ABI == AliveBits.end()) { AliveBits[I] = std::move(ABNew); Worklist.push_back(I); + AddedToWorklist = true; + } + + APInt MKPrev(BitWidth, 0); + auto MKI = MustKeepBits.find(I); + if (MKI != MustKeepBits.end()) + MKPrev = MKI->second; + + APInt MKNew = MKB | MKPrev; + if (MKNew != MKPrev || MKI == MustKeepBits.end()) { + MustKeepBits[I] = std::move(MKNew); + if (!AddedToWorklist) + Worklist.push_back(I); } } else if (!Visited.count(I)) { Worklist.push_back(I); @@ -358,10 +407,20 @@ const DataLayout &DL = I->getParent()->getModule()->getDataLayout(); if (AliveBits.count(I)) - return AliveBits[I]; + return AliveBits[I] | MustKeepBits[I]; return APInt::getAllOnesValue(DL.getTypeSizeInBits(I->getType())); } +unsigned DemandedBits::getMinimumBitWidth(Instruction *I) { + performAnalysis(); + + const DataLayout &DL = I->getParent()->getModule()->getDataLayout(); + unsigned Sz = DL.getTypeSizeInBits(I->getType()); + if (AliveBits.count(I)) + return Sz - AliveBits[I].countLeadingZeros(); + return Sz; +} + bool DemandedBits::isInstructionDead(Instruction *I) { performAnalysis(); @@ -374,8 +433,11 @@ // just because of this one debugging method. 
const_cast(this)->performAnalysis(); for (auto &KV : AliveBits) { - OS << "DemandedBits: 0x" << utohexstr(KV.second.getLimitedValue()) << " for " - << *KV.first << "\n"; + auto MKB = MustKeepBits.find(KV.first)->second; + auto DB = KV.second | MKB; + OS << "DemandedBits: 0x" << utohexstr(DB.getLimitedValue()) << " for " + << *KV.first << " (MustKeepBits=0x" << utohexstr(MKB.getLimitedValue()) + << ")\n"; } } Index: lib/Analysis/VectorUtils.cpp =================================================================== --- lib/Analysis/VectorUtils.cpp +++ lib/Analysis/VectorUtils.cpp @@ -494,11 +494,12 @@ // If we encounter a type that is larger than 64 bits, we can't represent // it so bail out. - if (DB.getDemandedBits(I).getBitWidth() > 64) + unsigned BW = DB.getDemandedBits(I).getBitWidth(); + if (BW > 64) return MapVector(); - uint64_t V = DB.getDemandedBits(I).getZExtValue(); - DBits[Leader] |= V; + APInt V = APInt::getLowBitsSet(BW, DB.getMinimumBitWidth(I)); + DBits[Leader] |= V.getZExtValue(); // Casts, loads and instructions outside of our range terminate a chain // successfully. 
Index: test/Analysis/DemandedBits/basic.ll =================================================================== --- test/Analysis/DemandedBits/basic.ll +++ test/Analysis/DemandedBits/basic.ll @@ -10,3 +10,34 @@ %3 = trunc i32 %2 to i8 ret i8 %3 } + +; CHECK-LABEL: 'test_icmp1' +; CHECK-DAG: DemandedBits: 0x1 for %3 = icmp eq i32 %1, %2 +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %1 = and i32 %a, 255 (MustKeepBits=0xFFFFFF00) +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %2 = shl i32 %1, 4 (MustKeepBits=0xFFFFF000) +define i1 @test_icmp1(i32 %a, i32 %b) { + %1 = and i32 %a, 255 + %2 = shl i32 %1, 4 + %3 = icmp eq i32 %1, %2 + ret i1 %3 +} + +; CHECK-LABEL: 'test_icmp2' +; CHECK-DAG: DemandedBits: 0x1 for %3 = icmp eq i32 %1, %2 +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %1 = and i32 %a, 255 (MustKeepBits=0xFFFFFF00) +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %2 = ashr i32 %1, 4 (MustKeepBits=0xFFFFFF00) +define i1 @test_icmp2(i32 %a, i32 %b) { + %1 = and i32 %a, 255 + %2 = ashr i32 %1, 4 + %3 = icmp eq i32 %1, %2 + ret i1 %3 +} + +; CHECK-LABEL: 'test_icmp3' +; CHECK-DAG: DemandedBits: 0xFFFFFFFF for %1 = and i32 %a, 255 (MustKeepBits=0x0) +; CHECK-DAG: DemandedBits: 0x1 for %2 = icmp eq i32 -1, %1 +define i1 @test_icmp3(i32 %a) { + %1 = and i32 %a, 255 + %2 = icmp eq i32 -1, %1 + ret i1 %2 +} \ No newline at end of file Index: test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll =================================================================== --- test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll +++ test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll @@ -205,5 +205,39 @@ br i1 %exitcond, label %for.cond.cleanup, label %for.body } +; CHECK-LABEL: @add_g +; CHECK: load <16 x i8> +; CHECK: xor <16 x i8> +; CHECK: icmp ult <16 x i8> +; CHECK: select <16 x i1> {{.*}}, <16 x i8> +; CHECK: store <16 x i8> +define void @add_g(i8* noalias nocapture readonly %p, i8* noalias nocapture readonly %q, i8* noalias nocapture %r, i8 
%arg1, i32 %len) #0 { + %1 = icmp sgt i32 %len, 0 + br i1 %1, label %.lr.ph, label %._crit_edge + +.lr.ph: ; preds = %0 + %2 = sext i8 %arg1 to i64 + br label %3 + +._crit_edge: ; preds = %3, %0 + ret void + +;