Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -1250,6 +1250,8 @@
     setOperationAction(ISD::TRUNCATE, MVT::v8i16, Custom);
     setOperationAction(ISD::TRUNCATE, MVT::v4i32, Custom);
 
+    setOperationAction(ISD::CTPOP, MVT::v8i32, Custom);
+
     if (Subtarget->hasFMA() || Subtarget->hasFMA4()) {
       setOperationAction(ISD::FMA, MVT::v8f32, Legal);
       setOperationAction(ISD::FMA, MVT::v4f64, Legal);
@@ -18852,6 +18854,127 @@
   return SDValue();
 }
 
+static SDValue LowerCTPOP(SDValue Op, const X86Subtarget *Subtarget,
+                          SelectionDAG &DAG) {
+  SDNode *Node = Op.getNode();
+  SDLoc dl(Node);
+
+  Op = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+  assert((VT.is128BitVector() || VT.is256BitVector()) &&
+         "CTPOP lowering only implemented for 128/256-bit wide vector types");
+
+  unsigned NumElts = VT.getVectorNumElements();
+  EVT EltVT = VT.getVectorElementType();
+  unsigned Len = EltVT.getSizeInBits();
+
+  assert(EltVT.isInteger() && Len == 32 &&
+         "CTPOP not implemented for this vector element type.");
+
+  if (VT.is256BitVector() && !Subtarget->hasInt256()) {
+    // Split v8i32 into two 128-bit CTPOPs and concatenate the results. Note
+    // that ctpop.v4i32 is not custom lowered by default; we do it here anyway
+    // because two ctpop.v4i32 are still better for v8i32 than the default
+    // expansion.
+    SDValue V0 = Extract128BitVector(Op, 0, DAG, dl);
+    SDValue V1 = Extract128BitVector(Op, NumElts/2, DAG, dl);
+    V0 = LowerCTPOP(DAG.getNode(ISD::CTPOP, dl, V0.getValueType(), V0),
+                    Subtarget, DAG);
+    V1 = LowerCTPOP(DAG.getNode(ISD::CTPOP, dl, V1.getValueType(), V1),
+                    Subtarget, DAG);
+    return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, V0, V1);
+  }
+
+  // X86 canonicalizes ANDs to vXi64, so generate the appropriate bitcasts to
+  // avoid extra legalization.
+  MVT BitcastVT = VT.is256BitVector() ? MVT::v4i64 : MVT::v2i64;
+
+  // This is the "best" algorithm from
+  // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
+  // with a minor tweak to use a series of adds + shifts instead of vector
+  // multiplications.
+  SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), EltVT);
+  SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), EltVT);
+  SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), EltVT);
+
+  // v = v - ((v >> 1) & 0x55555555...)
+  SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, EltVT));
+  SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ones);
+  SDValue Srl = DAG.getNode(ISD::SRL, dl, VT, Op, OnesV);
+  Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
+
+  SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
+  SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask55);
+  M55 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M55);
+
+  SDValue And = DAG.getNode(ISD::AND, dl, Srl.getValueType(), Srl, M55);
+  And = DAG.getNode(ISD::BITCAST, dl, VT, And);
+  SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, Op, And);
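+
+  // Each 2-bit field of Sub now holds the popcount of the corresponding two
+  // input bits: for a bit pair (b1, b0) the field value is
+  // (2*b1 + b0) - b1 = b1 + b0.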
+
+  // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
+  SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
+  SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask33);
+  SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, EltVT));
+  SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Twos);
+
+  Srl = DAG.getNode(ISD::SRL, dl, VT, Sub, TwosV);
+  Srl = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Srl);
+  M33 = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M33);
+  Sub = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Sub);
+
+  SDValue AndRHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Srl, M33);
+  SDValue AndLHS = DAG.getNode(ISD::AND, dl, M33.getValueType(), Sub, M33);
+  if (VT != AndRHS.getValueType()) {
+    AndRHS = DAG.getNode(ISD::BITCAST, dl, VT, AndRHS);
+    AndLHS = DAG.getNode(ISD::BITCAST, dl, VT, AndLHS);
+  }
+  SDValue Add = DAG.getNode(ISD::ADD, dl, VT, AndLHS, AndRHS);
+
+  // v = (v + (v >> 4)) & 0x0F0F0F0F...
+  SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, EltVT));
+  SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Fours);
+  Srl = DAG.getNode(ISD::SRL, dl, VT, Add, FoursV);
+  Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
+
+  SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
+  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Mask0F);
+  Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
+  M0F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M0F);
+  And = DAG.getNode(ISD::AND, dl, M0F.getValueType(), Add, M0F);
+  And = DAG.getNode(ISD::BITCAST, dl, VT, And);
+
+  // The algorithm mentioned above uses:
+  //    v = (v * 0x01010101...) >> (Len - 8)
+  //
+  // Change it to use vector adds + vector shifts, which yield faster results
+  // on Haswell than using vector integer multiplication.
+  //
+  // For i32 elements:
+  //    v = v + (v >> 8)
+  //    v = v + (v >> 16)
+  //
+  Add = And;
+  SmallVector<SDValue, 8> Csts;
+  for (unsigned i = 8; i <= Len/2; i *= 2) {
+    Csts.assign(NumElts, DAG.getConstant(i, EltVT));
+    SDValue CstsV = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Csts);
+    Srl = DAG.getNode(ISD::SRL, dl, VT, Add, CstsV);
+    Add = DAG.getNode(ISD::ADD, dl, VT, Add, Srl);
+    Csts.clear();
+  }
+
+  // The result is in the least significant 6 bits.
+  SDValue Cst3F = DAG.getConstant(APInt(Len, 0x3F), EltVT);
+  SmallVector<SDValue, 8> Cst3FV(NumElts, Cst3F);
+  SDValue M3F = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Cst3FV);
+  Add = DAG.getNode(ISD::BITCAST, dl, BitcastVT, Add);
+  M3F = DAG.getNode(ISD::BITCAST, dl, BitcastVT, M3F);
+  And = DAG.getNode(ISD::AND, dl, M3F.getValueType(), Add, M3F);
+  And = DAG.getNode(ISD::BITCAST, dl, VT, And);
+
+  return And;
+}
+
 static SDValue LowerLOAD_SUB(SDValue Op, SelectionDAG &DAG) {
   SDNode *Node = Op.getNode();
   SDLoc dl(Node);
@@ -18979,6 +19102,7 @@
   case ISD::ATOMIC_FENCE:       return LowerATOMIC_FENCE(Op, Subtarget, DAG);
   case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
                                 return LowerCMP_SWAP(Op, Subtarget, DAG);
+  case ISD::CTPOP:              return LowerCTPOP(Op, Subtarget, DAG);
   case ISD::ATOMIC_LOAD_SUB:    return LowerLOAD_SUB(Op,DAG);
   case ISD::ATOMIC_STORE:       return LowerATOMIC_STORE(Op,DAG);
   case ISD::BUILD_VECTOR:       return LowerBUILD_VECTOR(Op, DAG);
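
To see why the adds + shifts tail is a safe replacement for the bithacks multiply, here is that final step in scalar form. This is an illustrative sketch, not code from the patch; the helper names sumBytesMul and sumBytesAddShift are invented for the example, and it assumes 32-bit unsigned elements exactly as the lowering above does:

#include <cassert>
#include <cstdint>

// After the 0x0F0F0F0F masking step, each byte of v holds a partial count
// in the range [0, 8]. Both tails below sum the four bytes; since the total
// is at most 32, no carry ever crosses a byte boundary and the two forms
// produce the same result.
static uint32_t sumBytesMul(uint32_t v) {
  return (v * 0x01010101u) >> 24; // bithacks: one multiply + shift
}

static uint32_t sumBytesAddShift(uint32_t v) {
  v = v + (v >> 8);  // fold adjacent bytes
  v = v + (v >> 16); // fold the two halves
  return v & 0x3Fu;  // keep the 6 significant bits
}

int main() {
  // Exhaustively check every valid intermediate value (all bytes in [0, 8]).
  for (uint32_t b3 = 0; b3 <= 8; ++b3)
    for (uint32_t b2 = 0; b2 <= 8; ++b2)
      for (uint32_t b1 = 0; b1 <= 8; ++b1)
        for (uint32_t b0 = 0; b0 <= 8; ++b0) {
          uint32_t v = (b3 << 24) | (b2 << 16) | (b1 << 8) | b0;
          assert(sumBytesMul(v) == sumBytesAddShift(v));
        }
  return 0;
}

The two tails agree on every value the lowering can produce at that point; the second form simply trades the vector integer multiply, which the comment above notes is slow on Haswell, for two shift + add pairs.
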
Index: test/CodeGen/X86/vector-ctpop.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/vector-ctpop.ll
@@ -0,0 +1,78 @@
+; RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=corei7-avx | FileCheck -check-prefix=AVX1 %s
+; RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=core-avx2 | FileCheck -check-prefix=AVX2 %s
+
+; This tests the x86 custom lowering for the llvm.ctpop.v8i32 intrinsic. It
+; implements a vectorized version of the following algorithm:
+;
+; v = v - ((v >> 1) & 0x55555555);
+; v = (v & 0x33333333) + ((v >> 2) & 0x33333333);
+; v = (v + (v >> 4)) & 0x0F0F0F0F;
+; v = v + (v >> 8);
+; v = v + (v >> 16);
+; v = v & 0x0000003F;
+
+define <8 x i32> @test0(<8 x i32> %x) {
+; AVX2-LABEL: @test0
+; AVX1-LABEL: @test0
+entry:
+; AVX2: vpsrld $1, %ymm
+; AVX2-NEXT: vpbroadcastd
+; AVX2-NEXT: vpand
+; AVX2-NEXT: vpsubd
+; AVX2-NEXT: vpbroadcastd
+; AVX2-NEXT: vpand
+; AVX2-NEXT: vpsrld $2
+; AVX2-NEXT: vpand
+; AVX2-NEXT: vpaddd
+; AVX2-NEXT: vpsrld $4
+; AVX2-NEXT: vpaddd
+; AVX2-NEXT: vpbroadcastd
+; AVX2-NEXT: vpand
+; AVX2-NEXT: vpsrld $8
+; AVX2-NEXT: vpaddd
+; AVX2-NEXT: vpsrld $16
+; AVX2-NEXT: vpaddd
+; AVX2-NEXT: vpbroadcastd
+; AVX2-NEXT: vpand
+; AVX1: vextractf128 $1, %ymm
+; AVX1-NEXT: vpsrld $1, %xmm
+; AVX1-NEXT: vmovdqa
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpsubd
+; AVX1-NEXT: vmovdqa
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpsrld $2
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vpsrld $4
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vmovdqa
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpsrld $8
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vpsrld $16
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vmovdqa
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpsrld $1, %xmm
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpsubd
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpsrld $2
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vpsrld $4
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vpsrld $8
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vpsrld $16
+; AVX1-NEXT: vpaddd
+; AVX1-NEXT: vpand
+; AVX1-NEXT: vinsertf128 $1
+  %y = call <8 x i32> @llvm.ctpop.v8i32(<8 x i32> %x)
+  ret <8 x i32> %y
+}
+
+declare <8 x i32> @llvm.ctpop.v8i32(<8 x i32>)
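
As a sanity check of the algorithm documented in the test header, the full scalar sequence can be compared against a naive per-bit loop. Again an illustrative standalone sketch rather than part of the patch; the function name ctpop32 is invented for the example:

#include <cassert>
#include <cstdint>

// Scalar version of the sequence the test comment describes.
static uint32_t ctpop32(uint32_t v) {
  v = v - ((v >> 1) & 0x55555555);                // 2-bit partial counts
  v = (v & 0x33333333) + ((v >> 2) & 0x33333333); // 4-bit partial counts
  v = (v + (v >> 4)) & 0x0F0F0F0F;                // 8-bit partial counts
  v = v + (v >> 8);                               // sum the bytes
  v = v + (v >> 16);
  return v & 0x0000003F;                          // result fits in 6 bits
}

int main() {
  const uint32_t samples[] = {0u, 1u, 0x80000000u, 0xDEADBEEFu, 0xFFFFFFFFu};
  for (uint32_t v : samples) {
    // Naive reference: count set bits one at a time.
    uint32_t naive = 0;
    for (uint32_t t = v; t != 0; t >>= 1)
      naive += t & 1;
    assert(ctpop32(v) == naive);
  }
  return 0;
}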