Index: llvm/trunk/include/llvm/CodeGen/TargetLowering.h
===================================================================
--- llvm/trunk/include/llvm/CodeGen/TargetLowering.h
+++ llvm/trunk/include/llvm/CodeGen/TargetLowering.h
@@ -2899,22 +2899,28 @@
   bool SimplifyDemandedBits(SDNode *User, unsigned OpIdx, const APInt &Demanded,
                             DAGCombinerInfo &DCI, TargetLoweringOpt &TLO) const;

-  /// Look at Op. At this point, we know that only the DemandedMask bits of the
+  /// Look at Op. At this point, we know that only the DemandedBits bits of the
   /// result of Op are ever used downstream. If we can use this information to
   /// simplify Op, create a new simplified DAG node and return true, returning
   /// the original and new nodes in Old and New. Otherwise, analyze the
   /// expression and return a mask of KnownOne and KnownZero bits for the
   /// expression (used to simplify the caller). The KnownZero/One bits may only
-  /// be accurate for those bits in the DemandedMask.
+  /// be accurate for those bits in the Demanded masks.
   /// \p AssumeSingleUse When this parameter is true, this function will
   /// attempt to simplify \p Op even if there are multiple uses.
   /// Callers are responsible for correctly updating the DAG based on the
   /// results of this function, because simply replacing TLO.Old
   /// with TLO.New will be incorrect when this parameter is true and TLO.Old
   /// has multiple uses.
-  bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedMask,
-                            KnownBits &Known,
-                            TargetLoweringOpt &TLO,
+  bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
+                            const APInt &DemandedElts, KnownBits &Known,
+                            TargetLoweringOpt &TLO, unsigned Depth = 0,
+                            bool AssumeSingleUse = false) const;
+
+  /// Helper wrapper around SimplifyDemandedBits, demanding all elements.
+  /// Adds Op back to the worklist upon success.
+  bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
+                            KnownBits &Known, TargetLoweringOpt &TLO,
                             unsigned Depth = 0,
                             bool AssumeSingleUse = false) const;

@@ -2985,13 +2991,14 @@
       SDValue Op, const APInt &DemandedElts, APInt &KnownUndef,
       APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth = 0) const;

-  /// Attempt to simplify any target nodes based on the demanded bits,
+  /// Attempt to simplify any target nodes based on the demanded bits/elts,
   /// returning true on success. Otherwise, analyze the
   /// expression and return a mask of KnownOne and KnownZero bits for the
   /// expression (used to simplify the caller). The KnownZero/One bits may only
-  /// be accurate for those bits in the DemandedMask.
+  /// be accurate for those bits in the Demanded masks.
   virtual bool SimplifyDemandedBitsForTargetNode(SDValue Op,
                                                  const APInt &DemandedBits,
+                                                 const APInt &DemandedElts,
                                                  KnownBits &Known,
                                                  TargetLoweringOpt &TLO,
                                                  unsigned Depth = 0) const;
Index: llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -496,23 +496,41 @@
   return Simplified;
 }

+bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
+                                          KnownBits &Known,
+                                          TargetLoweringOpt &TLO,
+                                          unsigned Depth,
+                                          bool AssumeSingleUse) const {
+  EVT VT = Op.getValueType();
+  APInt DemandedElts = VT.isVector()
+                           ? APInt::getAllOnesValue(VT.getVectorNumElements())
+                           : APInt(1, 1);
+  return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
+                              AssumeSingleUse);
+}
+
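The overload added above is a thin convenience wrapper: it synthesizes an all-ones element mask so existing callers keep their current behavior. A minimal standalone sketch of the masks it forwards, assuming a hypothetical 4-lane vector (illustrative, not part of the patch):

  #include "llvm/ADT/APInt.h"
  using llvm::APInt;

  // A v4i32 value demands all four lanes; a scalar counts as one "lane".
  APInt vectorDemandedElts() { return APInt::getAllOnesValue(4); } // 0b1111
  APInt scalarDemandedElts() { return APInt(1, 1); }               // 0b1
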
 /// Look at Op. At this point, we know that only the OriginalDemandedBits of the
 /// result of Op are ever used downstream. If we can use this information to
 /// simplify Op, create a new simplified DAG node and return true, returning the
 /// original and new nodes in Old and New. Otherwise, analyze the expression and
 /// return a mask of Known bits for the expression (used to simplify the
 /// caller). The Known bits may only be accurate for those bits in the
-/// DemandedMask.
-bool TargetLowering::SimplifyDemandedBits(SDValue Op,
-                                          const APInt &OriginalDemandedBits,
-                                          KnownBits &Known,
-                                          TargetLoweringOpt &TLO,
-                                          unsigned Depth,
-                                          bool AssumeSingleUse) const {
+/// OriginalDemandedBits and OriginalDemandedElts.
+bool TargetLowering::SimplifyDemandedBits(
+    SDValue Op, const APInt &OriginalDemandedBits,
+    const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
+    unsigned Depth, bool AssumeSingleUse) const {
   unsigned BitWidth = OriginalDemandedBits.getBitWidth();
   assert(Op.getScalarValueSizeInBits() == BitWidth &&
          "Mask size mismatches value type size!");
+
+  unsigned NumElts = OriginalDemandedElts.getBitWidth();
+  assert((!Op.getValueType().isVector() ||
+          NumElts == Op.getValueType().getVectorNumElements()) &&
+         "Unexpected vector size");
+
+  APInt DemandedBits = OriginalDemandedBits;
+  APInt DemandedElts = OriginalDemandedElts;
   SDLoc dl(Op);
   auto &DL = TLO.DAG.getDataLayout();
@@ -532,18 +550,19 @@
     if (Depth != 0) {
       // If not at the root, just compute the Known bits to
       // simplify things downstream.
-      TLO.DAG.computeKnownBits(Op, Known, Depth);
+      TLO.DAG.computeKnownBits(Op, Known, DemandedElts, Depth);
       return false;
     }
     // If this is the root being simplified, allow it to have multiple uses,
-    // just set the DemandedBits to all bits.
+    // just set the DemandedBits/Elts to all bits.
     DemandedBits = APInt::getAllOnesValue(BitWidth);
-  } else if (OriginalDemandedBits == 0) {
-    // Not demanding any bits from Op.
+    DemandedElts = APInt::getAllOnesValue(NumElts);
+  } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) {
+    // Not demanding any bits/elts from Op.
     if (!Op.isUndef())
       return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
     return false;
-  } else if (Depth == 6) {        // Limit search depth.
+  } else if (Depth == 6) { // Limit search depth.
     return false;
   }

@@ -573,18 +592,71 @@
       Known.One &= Known2.One;
       Known.Zero &= Known2.Zero;
     }
-    return false;   // Don't fall through, will infinitely loop.
-  case ISD::CONCAT_VECTORS:
+    return false; // Don't fall through, will infinitely loop.
+  case ISD::CONCAT_VECTORS: {
     Known.Zero.setAllBits();
     Known.One.setAllBits();
-    for (SDValue SrcOp : Op->ops()) {
-      if (SimplifyDemandedBits(SrcOp, DemandedBits, Known2, TLO, Depth + 1))
+    EVT SubVT = Op.getOperand(0).getValueType();
+    unsigned NumSubVecs = Op.getNumOperands();
+    unsigned NumSubElts = SubVT.getVectorNumElements();
+    for (unsigned i = 0; i != NumSubVecs; ++i) {
+      APInt DemandedSubElts =
+          DemandedElts.extractBits(NumSubElts, i * NumSubElts);
+      if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts,
+                               Known2, TLO, Depth + 1))
         return true;
-      // Known bits are the values that are shared by every subvector.
-      Known.One &= Known2.One;
-      Known.Zero &= Known2.Zero;
+      // Known bits are shared by every demanded subvector element.
+      if (!!DemandedSubElts) {
+        Known.One &= Known2.One;
+        Known.Zero &= Known2.Zero;
+      }
     }
     break;
+  }
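The CONCAT_VECTORS case above slices the demanded-element mask per operand with APInt::extractBits, and an operand whose slice is empty no longer narrows the intersection of known bits. A standalone sketch, assuming a hypothetical v8i32 concat of two v4i32 operands (not taken from the patch):

  #include "llvm/ADT/APInt.h"
  using llvm::APInt;

  void splitConcatDemand() {
    APInt DemandedElts(8, 0x0F);                 // demand lanes 0..3 only
    APInt Sub0 = DemandedElts.extractBits(4, 0); // 0b1111 -> operand 0
    APInt Sub1 = DemandedElts.extractBits(4, 4); // 0b0000 -> operand 1's
                                                 // known bits are ignored
  }
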
+  case ISD::VECTOR_SHUFFLE: {
+    ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
+
+    // Collect demanded elements from shuffle operands.
+    APInt DemandedLHS(NumElts, 0);
+    APInt DemandedRHS(NumElts, 0);
+    for (unsigned i = 0; i != NumElts; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      int M = ShuffleMask[i];
+      if (M < 0) {
+        // For UNDEF elements, we don't know anything about the common state of
+        // the shuffle result.
+        DemandedLHS.clearAllBits();
+        DemandedRHS.clearAllBits();
+        break;
+      }
+      assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
+      if (M < (int)NumElts)
+        DemandedLHS.setBit(M);
+      else
+        DemandedRHS.setBit(M - NumElts);
+    }
+
+    if (!!DemandedLHS || !!DemandedRHS) {
+      Known.Zero.setAllBits();
+      Known.One.setAllBits();
+      if (!!DemandedLHS) {
+        if (SimplifyDemandedBits(Op.getOperand(0), DemandedBits, DemandedLHS,
+                                 Known2, TLO, Depth + 1))
+          return true;
+        Known.One &= Known2.One;
+        Known.Zero &= Known2.Zero;
+      }
+      if (!!DemandedRHS) {
+        if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedRHS,
+                                 Known2, TLO, Depth + 1))
+          return true;
+        Known.One &= Known2.One;
+        Known.Zero &= Known2.Zero;
+      }
+    }
+    break;
+  }
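The VECTOR_SHUFFLE case above routes each demanded result lane to the operand lane named by the shuffle mask; a -1 (undef) entry conservatively clears both operand masks. A worked standalone example with a hypothetical mask, not taken from the patch:

  #include "llvm/ADT/APInt.h"
  using llvm::APInt;

  // 4-lane shuffle with mask <0, 4, 1, 5>, demanding result lanes 0 and 2:
  // lane 0 reads LHS[0] and lane 2 reads LHS[1], so we expect
  // DemandedLHS == 0b0011 and DemandedRHS == 0b0000.
  void shuffleDemandExample() {
    int Mask[4] = {0, 4, 1, 5};
    APInt DemandedElts(4, 0b0101);
    APInt DemandedLHS(4, 0), DemandedRHS(4, 0);
    for (unsigned i = 0; i != 4; ++i) {
      if (!DemandedElts[i])
        continue;
      int M = Mask[i];
      if (M < 4)
        DemandedLHS.setBit(M);   // lane comes from operand 0
      else
        DemandedRHS.setBit(M - 4); // lane comes from operand 1
    }
  }
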
   case ISD::AND: {
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
@@ -596,7 +668,7 @@
     if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1)) {
       KnownBits LHSKnown;
       // Do not increment Depth here; that can cause an infinite loop.
-      TLO.DAG.computeKnownBits(Op0, LHSKnown, Depth);
+      TLO.DAG.computeKnownBits(Op0, LHSKnown, DemandedElts, Depth);
       // If the LHS already has zeros where RHSC does, this 'and' is dead.
       if ((LHSKnown.Zero & DemandedBits) ==
           (~RHSC->getAPIntValue() & DemandedBits))
@@ -619,10 +691,10 @@
       }
     }

-    if (SimplifyDemandedBits(Op1, DemandedBits, Known, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1))
       return true;
     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
-    if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, Known2, TLO,
+    if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts, Known2, TLO,
                              Depth + 1))
       return true;
     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
@@ -653,10 +725,11 @@
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);

-    if (SimplifyDemandedBits(Op1, DemandedBits, Known, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1))
       return true;
     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
-    if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, Known2, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts, Known2, TLO,
+                             Depth + 1))
       return true;
     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
@@ -683,10 +756,10 @@
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);

-    if (SimplifyDemandedBits(Op1, DemandedBits, Known, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO, Depth + 1))
       return true;
     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
-    if (SimplifyDemandedBits(Op0, DemandedBits, Known2, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO, Depth + 1))
       return true;
     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
@@ -840,7 +913,7 @@
       }
     }

-    if (SimplifyDemandedBits(Op0, DemandedBits.lshr(ShAmt), Known, TLO,
+    if (SimplifyDemandedBits(Op0, DemandedBits.lshr(ShAmt), DemandedElts, Known, TLO,
                              Depth + 1))
       return true;
@@ -935,7 +1008,7 @@
     }

     // Compute the new bits that are at the top now.
-    if (SimplifyDemandedBits(Op0, InDemandedMask, Known, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1))
       return true;
     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
     Known.Zero.lshrInPlace(ShAmt);
@@ -974,7 +1047,7 @@
     if (DemandedBits.countLeadingZeros() < ShAmt)
       InDemandedMask.setSignBit();

-    if (SimplifyDemandedBits(Op0, InDemandedMask, Known, TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1))
       return true;
     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
     Known.Zero.lshrInPlace(ShAmt);
@@ -1221,18 +1294,26 @@
     break;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
-    // Demand the bits from every vector element.
     SDValue Src = Op.getOperand(0);
+    SDValue Idx = Op.getOperand(1);
+    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
     unsigned EltBitWidth = Src.getScalarValueSizeInBits();

+    // Demand the bits from every vector element without a constant index.
+    APInt DemandedSrcElts = APInt::getAllOnesValue(NumSrcElts);
+    if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx))
+      if (CIdx->getAPIntValue().ult(NumSrcElts))
+        DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue());
+
     // If BitWidth > EltBitWidth the value is anyext:ed. So we do not know
     // anything about the extended bits.
     APInt DemandedSrcBits = DemandedBits;
     if (BitWidth > EltBitWidth)
       DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth);

-    if (SimplifyDemandedBits(Src, DemandedSrcBits, Known2, TLO, Depth + 1))
-      return true;
+    if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO,
+                             Depth + 1))
+      return true;

     Known = Known2;
     if (BitWidth > EltBitWidth)
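With a constant in-range index, EXTRACT_VECTOR_ELT now demands a single source element rather than all of them. A sketch, assuming a hypothetical v4i32 source and lane 2 (illustrative, not from the patch):

  #include "llvm/ADT/APInt.h"
  using llvm::APInt;

  // Constant in-range index: one lane. Variable or out-of-range index:
  // fall back to demanding every lane.
  APInt extractDemand(bool HasConstIdx) {
    return HasConstIdx ? APInt::getOneBitSet(4, 2)  // 0b0100
                       : APInt::getAllOnesValue(4); // 0b1111
  }
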
@@ -1313,8 +1394,8 @@
     SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1);
     unsigned DemandedBitsLZ = DemandedBits.countLeadingZeros();
     APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ);
-    if (SimplifyDemandedBits(Op0, LoMask, Known2, TLO, Depth + 1) ||
-        SimplifyDemandedBits(Op1, LoMask, Known2, TLO, Depth + 1) ||
+    if (SimplifyDemandedBits(Op0, LoMask, DemandedElts, Known2, TLO, Depth + 1) ||
+        SimplifyDemandedBits(Op1, LoMask, DemandedElts, Known2, TLO, Depth + 1) ||
         // See if the operation should be performed at a smaller bit width.
         ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) {
       SDNodeFlags Flags = Op.getNode()->getFlags();
@@ -1354,14 +1435,14 @@
   }
   default:
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
-      if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, Known, TLO,
-                                            Depth))
+      if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
+                                            Known, TLO, Depth))
         return true;
       break;
     }

     // Just use computeKnownBits to compute output bits.
-    TLO.DAG.computeKnownBits(Op, Known, Depth);
+    TLO.DAG.computeKnownBits(Op, Known, DemandedElts, Depth);
     break;
   }

@@ -1887,8 +1968,8 @@
 }

 bool TargetLowering::SimplifyDemandedBitsForTargetNode(
-    SDValue Op, const APInt &DemandedBits, KnownBits &Known,
-    TargetLoweringOpt &TLO, unsigned Depth) const {
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
@@ -1896,9 +1977,6 @@
          "Should use SimplifyDemandedBits if you don't know whether Op"
          " is a target node!");
   EVT VT = Op.getValueType();
-  APInt DemandedElts = VT.isVector()
-                           ? APInt::getAllOnesValue(VT.getVectorNumElements())
-                           : APInt(1, 1);
   computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
   return false;
 }
Index: llvm/trunk/lib/Target/X86/X86ISelLowering.h
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.h
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.h
@@ -871,6 +871,7 @@

     bool SimplifyDemandedBitsForTargetNode(SDValue Op,
                                            const APInt &DemandedBits,
+                                           const APInt &DemandedElts,
                                            KnownBits &Known,
                                            TargetLoweringOpt &TLO,
                                            unsigned Depth) const override;
Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -32397,8 +32397,9 @@
 }

 bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
-    SDValue Op, const APInt &OriginalDemandedBits, KnownBits &Known,
-    TargetLoweringOpt &TLO, unsigned Depth) const {
+    SDValue Op, const APInt &OriginalDemandedBits,
+    const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
+    unsigned Depth) const {
   unsigned BitWidth = OriginalDemandedBits.getBitWidth();
   unsigned Opc = Op.getOpcode();
   switch(Opc) {
@@ -32424,8 +32425,8 @@
       KnownBits KnownOp;
       unsigned ShAmt = ShiftImm->getZExtValue();
       APInt DemandedMask = OriginalDemandedBits.lshr(ShAmt);
-      if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask, KnownOp, TLO,
-                               Depth + 1))
+      if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask,
+                               OriginalDemandedElts, KnownOp, TLO, Depth + 1))
         return true;
     }
     break;
@@ -32446,8 +32447,8 @@
           OriginalDemandedBits.countLeadingZeros() < ShAmt)
         DemandedMask.setSignBit();

-      if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask, KnownOp, TLO,
-                               Depth + 1))
+      if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask,
+                               OriginalDemandedElts, KnownOp, TLO, Depth + 1))
         return true;
     }
     break;
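The arithmetic-shift case keeps the pre-existing rule that the sign bit must be re-demanded whenever any of the top ShAmt result bits are demanded, because those bits are copies of it. A scalar analogue, illustrative only and assuming the usual sign-propagating behavior of >> on negative signed values:

  #include <cstdint>

  // After X >> 3, result bits 7..5 are all copies of bit 7 of X, so
  // demanding any of them means bit 7 of X is demanded too.
  int8_t sraByThree(int8_t X) { return X >> 3; }
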
@@ -32475,8 +32476,8 @@

     // MOVMSK only uses the MSB from each vector element.
     KnownBits KnownSrc;
-    if (SimplifyDemandedBits(Src, APInt::getSignMask(SrcBits), KnownSrc, TLO,
-                             Depth + 1))
+    if (SimplifyDemandedBits(Src, APInt::getSignMask(SrcBits), DemandedElts,
+                             KnownSrc, TLO, Depth + 1))
       return true;

     if (KnownSrc.One[SrcBits - 1])
@@ -32488,7 +32489,7 @@
   }

   return TargetLowering::SimplifyDemandedBitsForTargetNode(
-      Op, OriginalDemandedBits, Known, TLO, Depth);
+      Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
 }

 /// Check if a vector extract from a target-specific shuffle of a load can be
Index: llvm/trunk/test/CodeGen/X86/combine-sdiv.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/combine-sdiv.ll
+++ llvm/trunk/test/CodeGen/X86/combine-sdiv.ll
@@ -3002,7 +3002,6 @@
 ; SSE2-NEXT:    packuswb %xmm0, %xmm2
 ; SSE2-NEXT:    psrlw $7, %xmm1
 ; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
-; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    paddb %xmm2, %xmm1
 ; SSE2-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-NEXT:    retq
@@ -3033,7 +3032,6 @@
 ; SSE41-NEXT:    packuswb %xmm1, %xmm2
 ; SSE41-NEXT:    psrlw $7, %xmm0
 ; SSE41-NEXT:    pand {{.*}}(%rip), %xmm0
-; SSE41-NEXT:    pand {{.*}}(%rip), %xmm0
 ; SSE41-NEXT:    paddb %xmm2, %xmm0
 ; SSE41-NEXT:    retq
 ;
@@ -3059,7 +3057,6 @@
 ; AVX1-NEXT:    vpackuswb %xmm1, %xmm2, %xmm1
 ; AVX1-NEXT:    vpsrlw $7, %xmm0, %xmm0
 ; AVX1-NEXT:    vpand {{.*}}(%rip), %xmm0, %xmm0
-; AVX1-NEXT:    vpand {{.*}}(%rip), %xmm0, %xmm0
 ; AVX1-NEXT:    vpaddb %xmm0, %xmm1, %xmm0
 ; AVX1-NEXT:    retq
 ;
@@ -3078,7 +3075,6 @@
 ; AVX2-NEXT:    vpackuswb %xmm2, %xmm1, %xmm1
 ; AVX2-NEXT:    vpsrlw $7, %xmm0, %xmm0
 ; AVX2-NEXT:    vpand {{.*}}(%rip), %xmm0, %xmm0
-; AVX2-NEXT:    vpand {{.*}}(%rip), %xmm0, %xmm0
 ; AVX2-NEXT:    vpaddb %xmm0, %xmm1, %xmm0
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
@@ -3093,7 +3089,6 @@
 ; AVX512F-NEXT:    vpaddb %xmm0, %xmm1, %xmm0
 ; AVX512F-NEXT:    vpsrlw $7, %xmm0, %xmm1
 ; AVX512F-NEXT:    vpand {{.*}}(%rip), %xmm1, %xmm1
-; AVX512F-NEXT:    vpand {{.*}}(%rip), %xmm1, %xmm1
 ; AVX512F-NEXT:    vpmovsxbd %xmm0, %zmm0
 ; AVX512F-NEXT:    vpsravd {{.*}}(%rip), %zmm0, %zmm0
 ; AVX512F-NEXT:    vpmovdb %zmm0, %xmm0
@@ -3110,7 +3105,6 @@
 ; AVX512BW-NEXT:    vpaddb %xmm0, %xmm1, %xmm0
 ; AVX512BW-NEXT:    vpsrlw $7, %xmm0, %xmm1
 ; AVX512BW-NEXT:    vpand {{.*}}(%rip), %xmm1, %xmm1
-; AVX512BW-NEXT:    vpand {{.*}}(%rip), %xmm1, %xmm1
 ; AVX512BW-NEXT:    vpmovsxbw %xmm0, %ymm0
 ; AVX512BW-NEXT:    vpsravw {{.*}}(%rip), %ymm0, %ymm0
 ; AVX512BW-NEXT:    vpmovwb %ymm0, %xmm0
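The combine-sdiv.ll updates show the practical effect: each affected check previously matched two back-to-back pand instructions with the same constant, and with known bits now tracked through the vector shuffles the second mask is recognized as redundant and dropped. A scalar analogue of that redundancy, illustrative only:

  #include <cstdint>

  // Every bit cleared by the first mask is already known zero, so the
  // second, identical mask cannot change the result and folds away.
  uint8_t maskTwice(uint8_t X) {
    return static_cast<uint8_t>((X & 0x7F) & 0x7F); // == X & 0x7F
  }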