Index: include/llvm/CodeGen/SelectionDAG.h =================================================================== --- include/llvm/CodeGen/SelectionDAG.h +++ include/llvm/CodeGen/SelectionDAG.h @@ -1311,6 +1311,17 @@ /// target nodes to be understood. unsigned ComputeNumSignBits(SDValue Op, unsigned Depth = 0) const; + /// Return the number of times the sign bit of the register is replicated into + /// the other bits. We know that at least 1 bit is always equal to the sign + /// bit (itself), but other cases can give us information. For example, + /// immediately after an "SRA X, 2", we know that the top 3 bits are all equal + /// to each other, so we return 3. The DemandedElts argument allows + /// us to only collect the minimum sign bits of the requested vector elements. + /// Targets can implement the ComputeNumSignBitsForTarget method in the + /// TargetLowering class to allow target nodes to be understood. + unsigned ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, + unsigned Depth = 0) const; + /// Return true if the specified operand is an ISD::ADD with a ConstantSDNode /// on the right-hand side, or if it is an ISD::OR with a ConstantSDNode that /// is guaranteed to have the same semantics as an ADD. This handles the Index: include/llvm/Target/TargetLowering.h =================================================================== --- include/llvm/Target/TargetLowering.h +++ include/llvm/Target/TargetLowering.h @@ -2431,6 +2431,7 @@ /// This method can be implemented by targets that want to expose additional /// information about sign bits to the DAG Combiner. 
virtual unsigned ComputeNumSignBitsForTargetNode(SDValue Op, + const APInt &DemandedElts, const SelectionDAG &DAG, unsigned Depth = 0) const; Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2898,6 +2898,15 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, unsigned Depth) const { EVT VT = Op.getValueType(); + APInt DemandedElts = VT.isVector() + ? APInt::getAllOnesValue(VT.getVectorNumElements()) + : APInt(1, 1); + return ComputeNumSignBits(Op, DemandedElts, Depth); +} + +unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, + unsigned Depth) const { + EVT VT = Op.getValueType(); assert(VT.isInteger() && "Invalid VT!"); unsigned VTBits = VT.getScalarSizeInBits(); unsigned Tmp, Tmp2; @@ -2906,6 +2915,9 @@ if (Depth == 6) return 1; // Limit search depth. + if (!DemandedElts) + return 1; // No demanded elts, better to assume we don't know anything. + switch (Op.getOpcode()) { default: break; case ISD::AssertSext: @@ -2923,6 +2935,9 @@ case ISD::BUILD_VECTOR: Tmp = VTBits; for (unsigned i = 0, e = Op.getNumOperands(); (i < e) && (Tmp > 1); ++i) { + if (!DemandedElts[i]) + continue; + SDValue SrcOp = Op.getOperand(i); Tmp2 = ComputeNumSignBits(Op.getOperand(i), Depth + 1); @@ -2951,7 +2966,7 @@ return std::max(Tmp, Tmp2); case ISD::SRA: - Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1); + Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1); // SRA X, C -> adds C sign bits. if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1))) { APInt ShiftVal = C->getAPIntValue(); @@ -3114,19 +3129,62 @@ // result. 
Otherwise it gives either negative or > bitwidth result return std::max(std::min(KnownSign - rIndex * BitWidth, BitWidth), 0); } + case ISD::INSERT_VECTOR_ELT: { + SDValue InVec = Op.getOperand(0); + SDValue InVal = Op.getOperand(1); + SDValue EltNo = Op.getOperand(2); + unsigned NumElts = InVec.getValueType().getVectorNumElements(); + + ConstantSDNode *CEltNo = dyn_cast<ConstantSDNode>(EltNo); + if (CEltNo && CEltNo->getAPIntValue().ult(NumElts)) { + // If we know the element index, split the demand between the + // source vector and the inserted element. + unsigned EltIdx = CEltNo->getZExtValue(); + + // If we demand the inserted element then get its sign bits. + Tmp = UINT_MAX; + if (DemandedElts[EltIdx]) + Tmp = ComputeNumSignBits(InVal, Depth + 1); + + // If we demand the source vector then get its sign bits, and determine + // the minimum. + APInt VectorElts = DemandedElts & ~(APInt::getOneBitSet(NumElts, EltIdx)); + if (!!VectorElts) { + Tmp2 = ComputeNumSignBits(InVec, VectorElts, Depth + 1); + Tmp = std::min(Tmp, Tmp2); + } + } else { + // Unknown element index, so ignore DemandedElts and demand them all. + Tmp = ComputeNumSignBits(InVec, Depth + 1); + Tmp2 = ComputeNumSignBits(InVal, Depth + 1); + Tmp = std::min(Tmp, Tmp2); + } + assert(Tmp <= VTBits && "Failed to determine minimum sign bits"); + return Tmp; + } case ISD::EXTRACT_VECTOR_ELT: { - // At the moment we keep this simple and skip tracking the specific - // element. This way we get the lowest common denominator for all elements - // of the vector. - // TODO: get information for given vector element + SDValue InVec = Op.getOperand(0); + SDValue EltNo = Op.getOperand(1); + EVT VecVT = InVec.getValueType(); const unsigned BitWidth = Op.getValueSizeInBits(); const unsigned EltBitWidth = Op.getOperand(0).getScalarValueSizeInBits(); + const unsigned NumSrcElts = VecVT.getVectorNumElements(); + // If BitWidth > EltBitWidth the value is anyext:ed, and we do not know // anything about sign bits. 
But if the sizes match we can derive knowledge // about sign bits from the vector operand. - if (BitWidth == EltBitWidth) - return ComputeNumSignBits(Op.getOperand(0), Depth+1); - break; + if (BitWidth != EltBitWidth) + break; + + // If we know the element index, just demand that vector element, else for + // an unknown element index, ignore DemandedElts and demand them all. + APInt DemandedSrcElts = APInt::getAllOnesValue(NumSrcElts); + ConstantSDNode *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo); + if (ConstEltNo && ConstEltNo->getAPIntValue().ult(NumSrcElts)) + DemandedSrcElts = + APInt::getOneBitSet(NumSrcElts, ConstEltNo->getZExtValue()); + + return ComputeNumSignBits(InVec, DemandedSrcElts, Depth + 1); } case ISD::EXTRACT_SUBVECTOR: return ComputeNumSignBits(Op.getOperand(0), Depth + 1); @@ -3161,7 +3219,8 @@ Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN || Op.getOpcode() == ISD::INTRINSIC_W_CHAIN || Op.getOpcode() == ISD::INTRINSIC_VOID) { - unsigned NumBits = TLI->ComputeNumSignBitsForTargetNode(Op, *this, Depth); + unsigned NumBits = + TLI->ComputeNumSignBitsForTargetNode(Op, DemandedElts, *this, Depth); if (NumBits > 1) FirstAnswer = std::max(FirstAnswer, NumBits); } @@ -3169,7 +3228,7 @@ // Finally, if we can prove that the top bits of the result are 0's or 1's, // use this information. APInt KnownZero, KnownOne; - computeKnownBits(Op, KnownZero, KnownOne, Depth); + computeKnownBits(Op, KnownZero, KnownOne, DemandedElts, Depth); APInt Mask; if (KnownZero.isNegative()) { // sign bit is 0 Index: lib/CodeGen/SelectionDAG/TargetLowering.cpp =================================================================== --- lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -1337,6 +1337,7 @@ /// This method can be implemented by targets that want to expose additional /// information about sign bits to the DAG Combiner. 
unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDValue Op, + const APInt &, const SelectionDAG &, unsigned Depth) const { assert((Op.getOpcode() >= ISD::BUILTIN_OP_END || Index: lib/Target/AMDGPU/AMDGPUISelLowering.h =================================================================== --- lib/Target/AMDGPU/AMDGPUISelLowering.h +++ lib/Target/AMDGPU/AMDGPUISelLowering.h @@ -203,7 +203,8 @@ const SelectionDAG &DAG, unsigned Depth = 0) const override; - unsigned ComputeNumSignBitsForTargetNode(SDValue Op, const SelectionDAG &DAG, + unsigned ComputeNumSignBitsForTargetNode(SDValue Op, const APInt &DemandedElts, + const SelectionDAG &DAG, unsigned Depth = 0) const override; /// \brief Helper function that adds Reg to the LiveIn list of the DAG's Index: lib/Target/AMDGPU/AMDGPUISelLowering.cpp =================================================================== --- lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -3603,9 +3603,8 @@ } unsigned AMDGPUTargetLowering::ComputeNumSignBitsForTargetNode( - SDValue Op, - const SelectionDAG &DAG, - unsigned Depth) const { + SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG, + unsigned Depth) const { switch (Op.getOpcode()) { case AMDGPUISD::BFE_I32: { ConstantSDNode *Width = dyn_cast<ConstantSDNode>(Op.getOperand(2)); Index: lib/Target/X86/X86ISelLowering.h =================================================================== --- lib/Target/X86/X86ISelLowering.h +++ lib/Target/X86/X86ISelLowering.h @@ -832,6 +832,7 @@ /// Determine the number of bits in the operation that are sign bits. 
unsigned ComputeNumSignBitsForTargetNode(SDValue Op, + const APInt &DemandedElts, const SelectionDAG &DAG, unsigned Depth) const override; Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -26645,7 +26645,8 @@ } unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( - SDValue Op, const SelectionDAG &DAG, unsigned Depth) const { + SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG, + unsigned Depth) const { unsigned Opcode = Op.getOpcode(); switch (Opcode) { case X86ISD::SETCC_CARRY: Index: test/CodeGen/X86/known-bits-vector.ll =================================================================== --- test/CodeGen/X86/known-bits-vector.ll +++ test/CodeGen/X86/known-bits-vector.ll @@ -23,18 +23,14 @@ define float @knownbits_mask_extract_uitofp(<2 x i64> %a0) nounwind { ; X32-LABEL: knownbits_mask_extract_uitofp: ; X32: # BB#0: -; X32-NEXT: pushl %ebp -; X32-NEXT: movl %esp, %ebp -; X32-NEXT: andl $-8, %esp -; X32-NEXT: subl $16, %esp +; X32-NEXT: pushl %eax ; X32-NEXT: vpxor %xmm1, %xmm1, %xmm1 ; X32-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0],xmm1[1,2,3],xmm0[4,5,6,7] -; X32-NEXT: vmovq %xmm0, {{[0-9]+}}(%esp) -; X32-NEXT: fildll {{[0-9]+}}(%esp) -; X32-NEXT: fstps {{[0-9]+}}(%esp) -; X32-NEXT: flds {{[0-9]+}}(%esp) -; X32-NEXT: movl %ebp, %esp -; X32-NEXT: popl %ebp +; X32-NEXT: vmovd %xmm0, %eax +; X32-NEXT: vcvtsi2ssl %eax, %xmm2, %xmm0 +; X32-NEXT: vmovss %xmm0, (%esp) +; X32-NEXT: flds (%esp) +; X32-NEXT: popl %eax ; X32-NEXT: retl ; ; X64-LABEL: knownbits_mask_extract_uitofp: @@ -42,7 +38,7 @@ ; X64-NEXT: vpxor %xmm1, %xmm1, %xmm1 ; X64-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0],xmm1[1,2,3],xmm0[4,5,6,7] ; X64-NEXT: vmovq %xmm0, %rax -; X64-NEXT: vcvtsi2ssq %rax, %xmm2, %xmm0 +; X64-NEXT: vcvtsi2ssl %eax, %xmm2, %xmm0 ; X64-NEXT: retq %1 = and <2 x i64> %a0, <i64 65535, i64 -1> %2 = extractelement <2 x i64> %1, i32 0 Index: 
test/CodeGen/X86/known-signbits-vector.ll =================================================================== --- test/CodeGen/X86/known-signbits-vector.ll +++ test/CodeGen/X86/known-signbits-vector.ll @@ -100,27 +100,21 @@ define float @signbits_ashr_insert_ashr_extract_sitofp(i64 %a0, i64 %a1) nounwind { ; X32-LABEL: signbits_ashr_insert_ashr_extract_sitofp: ; X32: # BB#0: -; X32-NEXT: pushl %ebp -; X32-NEXT: movl %esp, %ebp -; X32-NEXT: andl $-8, %esp -; X32-NEXT: subl $16, %esp -; X32-NEXT: movl 8(%ebp), %eax -; X32-NEXT: movl 12(%ebp), %ecx +; X32-NEXT: pushl %eax +; X32-NEXT: movl {{[0-9]+}}(%esp), %eax +; X32-NEXT: movl {{[0-9]+}}(%esp), %ecx ; X32-NEXT: shrdl $30, %ecx, %eax ; X32-NEXT: sarl $30, %ecx ; X32-NEXT: vmovd %eax, %xmm0 ; X32-NEXT: vpinsrd $1, %ecx, %xmm0, %xmm0 -; X32-NEXT: vpinsrd $2, 16(%ebp), %xmm0, %xmm0 -; X32-NEXT: vpinsrd $3, 20(%ebp), %xmm0, %xmm0 -; X32-NEXT: vpsrad $3, %xmm0, %xmm1 +; X32-NEXT: vpinsrd $2, {{[0-9]+}}(%esp), %xmm0, %xmm0 +; X32-NEXT: vpinsrd $3, {{[0-9]+}}(%esp), %xmm0, %xmm0 ; X32-NEXT: vpsrlq $3, %xmm0, %xmm0 -; X32-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7] -; X32-NEXT: vmovq %xmm0, {{[0-9]+}}(%esp) -; X32-NEXT: fildll {{[0-9]+}}(%esp) -; X32-NEXT: fstps {{[0-9]+}}(%esp) -; X32-NEXT: flds {{[0-9]+}}(%esp) -; X32-NEXT: movl %ebp, %esp -; X32-NEXT: popl %ebp +; X32-NEXT: vmovd %xmm0, %eax +; X32-NEXT: vcvtsi2ssl %eax, %xmm1, %xmm0 +; X32-NEXT: vmovss %xmm0, (%esp) +; X32-NEXT: flds (%esp) +; X32-NEXT: popl %eax ; X32-NEXT: retl ; ; X64-LABEL: signbits_ashr_insert_ashr_extract_sitofp: @@ -133,7 +127,7 @@ ; X64-NEXT: vpsrlq $3, %xmm0, %xmm0 ; X64-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7] ; X64-NEXT: vmovq %xmm0, %rax -; X64-NEXT: vcvtsi2ssq %rax, %xmm2, %xmm0 +; X64-NEXT: vcvtsi2ssl %eax, %xmm2, %xmm0 ; X64-NEXT: retq %1 = ashr i64 %a0, 30 %2 = insertelement <2 x i64> undef, i64 %1, i32 0