Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1047,6 +1047,152 @@
   return nullptr;
 }
 
+// Return a Width-bit value in which, for every Stride-bit group that lies in
+// the low BW bits, the low Ones bits of the group are set. For example,
+// getCtpopGroupMask(32, 8, 4, 2) == 0x33.
+static APInt getCtpopGroupMask(unsigned Width, unsigned BW, unsigned Stride,
+                               unsigned Ones) {
+  APInt Group = APInt::getLowBitsSet(Width, Ones);
+  APInt Result(Width, 0);
+  for (unsigned Pos = 0; Pos < BW; Pos += Stride)
+    Result |= Group.shl(Pos);
+  return Result;
+}
+
+// Return true if C is an acceptable mask for one step of the ctpop pattern.
+// C must keep every bit that can hold part of a partial count (MinMask) and
+// must clear every bit belonging to the neighbouring field (outside MaxMask).
+// This accepts the textbook masks (0x55.., 0x33.., 0x0F.., ...) as well as
+// the narrower masks left once demanded-bits simplification has cleared bits
+// known to be zero (0x55.., 0x33.., 0x07.., ...).
+static bool isCtpopStepMask(const APInt &C, const APInt &MinMask,
+                            const APInt &MaxMask) {
+  return (C & MinMask) == MinMask && (C & ~MaxMask) == 0;
+}
+
+// Check if BI matches a ctpop calculation pattern for a value of width BW
+// bits. If so, return the argument V such that ctpop(V) would be a candidate
+// for replacing BI. If BW is less than the bit-width of the type of V, then
+// BI == ctpop(V) if the bits of V beyond BW are zero.
+// The matched pattern is:
+//   x0 := (V & 0x55..5) + ((V>>1) & 0x55..5)
+//   x1 := (x0 & 0x33..3) + ((x0>>2) & 0x33..3)
+//   ...
+//   xn := (xn-1 & 0x00..0FF..F) + ((xn-1>>S/2) & 0x00..0FF..F)
+// where xn is the candidate for ctpop(V).
+// Every mask is validated with isCtpopStepMask, so look-alike expressions
+// that use other masks are rejected. On the outermost step the mask on the
+// shifted operand may be absent: the inner steps leave a value below 2^BW,
+// so after the shift by BW/2 it is a count of at most BW/2 that the mask
+// would not change.
+static Value *matchCtpopW(BinaryOperator *BI, unsigned BW) {
+  unsigned TW = BI->getType()->getScalarSizeInBits();
+  if (BW < 2 || BW > TW || !isPowerOf2_32(BW))
+    return nullptr;
+
+  // Match V == (X & C0) + ((X >> S) & C1), with the add operands in either
+  // order and C0, C1 valid masks for this step, and return X. If
+  // MaskOptional is set, the shifted operand may also be a bare (X >> S).
+  auto matchStep = [](Value *V, unsigned S, const APInt &MinMask,
+                      const APInt &MaxMask, bool MaskOptional) -> Value * {
+    Value *Op0 = nullptr, *Op1 = nullptr;
+    if (!match(V, m_Add(m_Value(Op0), m_Value(Op1))))
+      return nullptr;
+
+    auto matchAndShift = [&](Value *V0, Value *V1) -> Value * {
+      Value *X = nullptr;
+      const APInt *C0 = nullptr, *C1 = nullptr;
+      if (!match(V0, m_And(m_Value(X), m_APInt(C0))) ||
+          !isCtpopStepMask(*C0, MinMask, MaxMask))
+        return nullptr;
+      if (match(V1, m_And(m_LShr(m_Specific(X), m_SpecificInt(S)),
+                          m_APInt(C1))))
+        return isCtpopStepMask(*C1, MinMask, MaxMask) ? X : nullptr;
+      if (MaskOptional && match(V1, m_LShr(m_Specific(X), m_SpecificInt(S))))
+        return X;
+      return nullptr;
+    };
+
+    if (Value *T = matchAndShift(Op0, Op1))
+      return T;
+    return matchAndShift(Op1, Op0);
+  };
+
+  // Walk from the outermost step (N == BW) down to the innermost (N == 2).
+  // The step for N adds pairs of N/2-bit fields, each holding a count of at
+  // most N/2. Such a count fits in the low log2(N) bits of its field
+  // (MinMask), and a mask must not keep anything outside the low N/2 bits of
+  // each N-bit group (MaxMask), as those bits belong to the other field.
+  Value *V = BI;
+  for (unsigned N = BW; N > 1; N /= 2) {
+    unsigned S = N / 2;
+    APInt MinMask = getCtpopGroupMask(TW, BW, N, Log2_32(N));
+    APInt MaxMask = getCtpopGroupMask(TW, BW, N, S);
+    V = matchStep(V, S, MinMask, MaxMask, /*MaskOptional=*/N == BW);
+    if (!V)
+      return nullptr;
+  }
+
+  return V;
+}
+
+// Try to replace the add BI, the final step of a bit-twiddling population
+// count, with a call to llvm.ctpop. Returns the new call, or nullptr if BI
+// is not recognized as a population count.
+static Value *optimizeToCtpop(BinaryOperator *BI,
+                              InstCombiner::BuilderTy *Builder) {
+  IntegerType *Ty = dyn_cast<IntegerType>(BI->getType());
+  if (!Ty)
+    return nullptr;
+
+  // Take the first shift amount feeding the add, and assume this is the
+  // last shift in the popcnt computation, i.e. a shift by half the width
+  // whose population count is being calculated. The shifted operand is
+  // either masked, ((x >> SH) & C), or bare, (x >> SH).
+  Value *Op0 = nullptr, *Op1 = nullptr;
+  if (!match(BI, m_Add(m_Value(Op0), m_Value(Op1))))
+    return nullptr;
+
+  uint64_t SH = 0;
+  if (!match(Op0, m_And(m_LShr(m_Value(), m_ConstantInt(SH)), m_Value())) &&
+      !match(Op1, m_And(m_LShr(m_Value(), m_ConstantInt(SH)), m_Value())) &&
+      !match(Op0, m_LShr(m_Value(), m_ConstantInt(SH))) &&
+      !match(Op1, m_LShr(m_Value(), m_ConstantInt(SH))))
+    return nullptr;
+
+  // Only counts over at least 8 bits are handled, and the counted width
+  // 2*SH must fit in the type (compare with TW/2 so 2*SH cannot overflow).
+  unsigned TW = Ty->getBitWidth();
+  if (SH < 4 || !isPowerOf2_64(SH) || SH > TW / 2)
+    return nullptr;
+  unsigned BW = 2 * SH;
+
+  Value *V = matchCtpopW(BI, BW);
+  if (!V)
+    return nullptr;
+
+  Module *M = Builder->GetInsertBlock()->getParent()->getParent();
+  if (BW < TW) {
+    // BW is the bit width of the expression whose population count is
+    // being calculated. TW is the bit width of the type associated with
+    // that expression. Usually they are the same, but for ctpop8 the
+    // type may be "unsigned", i.e. 32-bit, while the ctpop8 would only
+    // consider the low 8 bits. In that case BW=8 and TW=32. The matched
+    // expression counts only the low BW bits of V, so it equals ctpop(V)
+    // only if the high TW-BW bits of V are known to be zero.
+    APInt K0(TW, 0), K1(TW, 0);
+    computeKnownBits(V, K0, K1, M->getDataLayout());
+    APInt Need0 = APInt::getHighBitsSet(TW, TW - BW);
+    if ((K0 & Need0) != Need0)
+      return nullptr;
+  }
+
+  Value *Func = Intrinsic::getDeclaration(M, Intrinsic::ctpop, {V->getType()});
+  CallInst *CI = Builder->CreateCall(Func, {V});
+  CI->setDebugLoc(BI->getDebugLoc());
+  return CI;
+}
+
 Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
@@ -1295,6 +1441,9 @@
     }
   }
 
+  if (Value *V = optimizeToCtpop(&I, Builder))
+    return ReplaceInstUsesWith(I, V);
+
   // TODO(jingyue): Consider WillNotOverflowSignedAdd and
   // WillNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
Index: test/Transforms/InstCombine/ctpop-match.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/ctpop-match.ll
@@ -0,0 +1,178 @@
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+; unsigned pop8(unsigned char t0) {
+;   unsigned t1 = (t0 & 0x55) + ((t0>>1) & 0x55);
+;   unsigned t2 = (t1 & 0x33) + ((t1>>2) & 0x33);
+;   unsigned t3 = (t2 & 0x0F) + ((t2>>4) & 0x0F);
+;   return t3;
+; }
+;
+; CHECK: define i32 @pop8
+; CHECK: [[ARG8:%[a-zA-Z0-9_]+]] = zext i8 %t0 to i32
+; CHECK: @llvm.ctpop.i32(i32 [[ARG8]])
+define i32 @pop8(i8 zeroext %t0) #0 {
+entry:
+  %conv = zext i8 %t0 to i32
+  %and = and i32 %conv, 85
+  %shr = ashr i32 %conv, 1
+  %and2 = and i32 %shr, 85
+  %add = add nsw i32 %and, %and2
+  %and3 = and i32 %add, 51
+  %shr4 = lshr i32 %add, 2
+  %and5 = and i32 %shr4, 51
+  %add6 = add i32 %and3, %and5
+  %and7 = and i32 %add6, 15
+  %shr8 = lshr i32 %add6, 4
+  %and9 = and i32 %shr8, 15
+  %add10 = add i32 %and7, %and9
+  ret i32 %add10
+}
+
+
+; unsigned pop16(unsigned short t0) {
+;   unsigned t1 = (t0 & 0x5555) + ((t0>>1) & 0x5555);
+;   unsigned t2 = (t1 & 0x3333) + ((t1>>2) & 0x3333);
+;   unsigned t3 = (t2 & 0x0F0F) + ((t2>>4) & 0x0F0F);
+;   unsigned t4 = (t3 & 0x00FF) + ((t3>>8) & 0x00FF);
+;   return t4;
+; }
+;
+; CHECK: define i32 @pop16
+; CHECK: [[ARG16:%[a-zA-Z0-9_]+]] = zext i16 %t0 to i32
+; CHECK: @llvm.ctpop.i32(i32 [[ARG16]])
+define i32 @pop16(i16 zeroext %t0) #0 {
+entry:
+  %conv = zext i16 %t0 to i32
+  %and = and i32 %conv, 21845
+  %shr = ashr i32 %conv, 1
+  %and2 = and i32 %shr, 21845
+  %add = add nsw i32 %and, %and2
+  %and3 = and i32 %add, 13107
+  %shr4 = lshr i32 %add, 2
+  %and5 = and i32 %shr4, 13107
+  %add6 = add i32 %and3, %and5
+  %and7 = and i32 %add6, 3855
+  %shr8 = lshr i32 %add6, 4
+  %and9 = and i32 %shr8, 3855
+  %add10 = add i32 %and7, %and9
+  %and11 = and i32 %add10, 255
+  %shr12 = lshr i32 %add10, 8
+  %and13 = and i32 %shr12, 255
+  %add14 = add i32 %and11, %and13
+  ret i32 %add14
+}
+
+
+; unsigned pop32(unsigned t0) {
+;   unsigned t1 = (t0 & 0x55555555) + ((t0>>1) & 0x55555555);
+;   unsigned t2 = (t1 & 0x33333333) + ((t1>>2) & 0x33333333);
+;   unsigned t3 = (t2 & 0x0F0F0F0F) + ((t2>>4) & 0x0F0F0F0F);
+;   unsigned t4 = (t3 & 0x00FF00FF) + ((t3>>8) & 0x00FF00FF);
+;   unsigned t5 = (t4 & 0x0000FFFF) + ((t4>>16) & 0x0000FFFF);
+;   return t5;
+; }
+;
+; CHECK: define i32 @pop32
+; CHECK: @llvm.ctpop.i32(i32 %t0)
+define i32 @pop32(i32 %t0) #0 {
+entry:
+  %and = and i32 %t0, 1431655765
+  %shr = lshr i32 %t0, 1
+  %and1 = and i32 %shr, 1431655765
+  %add = add i32 %and, %and1
+  %and2 = and i32 %add, 858993459
+  %shr3 = lshr i32 %add, 2
+  %and4 = and i32 %shr3, 858993459
+  %add5 = add i32 %and2, %and4
+  %and6 = and i32 %add5, 252645135
+  %shr7 = lshr i32 %add5, 4
+  %and8 = and i32 %shr7, 252645135
+  %add9 = add i32 %and6, %and8
+  %and10 = and i32 %add9, 16711935
+  %shr11 = lshr i32 %add9, 8
+  %and12 = and i32 %shr11, 16711935
+  %add13 = add i32 %and10, %and12
+  %and14 = and i32 %add13, 65535
+  %shr15 = lshr i32 %add13, 16
+  %and16 = and i32 %shr15, 65535
+  %add17 = add i32 %and14, %and16
+  ret i32 %add17
+}
+
+
+; typedef unsigned long long u64_t;
+; u64_t pop64(u64_t t0) {
+;   u64_t t1 = (t0 & 0x5555555555555555LL) + ((t0>>1) & 0x5555555555555555LL);
+;   u64_t t2 = (t1 & 0x3333333333333333LL) + ((t1>>2) & 0x3333333333333333LL);
+;   u64_t t3 = (t2 & 0x0F0F0F0F0F0F0F0FLL) + ((t2>>4) & 0x0F0F0F0F0F0F0F0FLL);
+;   u64_t t4 = (t3 & 0x00FF00FF00FF00FFLL) + ((t3>>8) & 0x00FF00FF00FF00FFLL);
+;   u64_t t5 = (t4 & 0x0000FFFF0000FFFFLL) + ((t4>>16) & 0x0000FFFF0000FFFFLL);
+;   u64_t t6 = (t5 & 0x00000000FFFFFFFFLL) + ((t5>>32) & 0x00000000FFFFFFFFLL);
+;   return t6;
+; }
+;
+; CHECK: define i64 @pop64
+; CHECK: @llvm.ctpop.i64(i64 %t0)
+define i64 @pop64(i64 %t0) #0 {
+entry:
+  %and = and i64 %t0, 6148914691236517205
+  %shr = lshr i64 %t0, 1
+  %and1 = and i64 %shr, 6148914691236517205
+  %add = add i64 %and, %and1
+  %and2 = and i64 %add, 3689348814741910323
+  %shr3 = lshr i64 %add, 2
+  %and4 = and i64 %shr3, 3689348814741910323
+  %add5 = add i64 %and2, %and4
+  %and6 = and i64 %add5, 1085102592571150095
+  %shr7 = lshr i64 %add5, 4
+  %and8 = and i64 %shr7, 1085102592571150095
+  %add9 = add i64 %and6, %and8
+  %and10 = and i64 %add9, 71777214294589695
+  %shr11 = lshr i64 %add9, 8
+  %and12 = and i64 %shr11, 71777214294589695
+  %add13 = add i64 %and10, %and12
+  %and14 = and i64 %add13, 281470681808895
+  %shr15 = lshr i64 %add13, 16
+  %and16 = and i64 %shr15, 281470681808895
+  %add17 = add i64 %and14, %and16
+  %and18 = and i64 %add17, 4294967295
+  %shr19 = lshr i64 %add17, 32
+  %and20 = and i64 %shr19, 4294967295
+  %add21 = add i64 %and18, %and20
+  ret i64 %add21
+}
+
+
+; Same as pop32, but the third step uses 0x1F1F1F1F instead of 0x0F0F0F0F.
+; The result is not a population count, so it must not be turned into
+; llvm.ctpop.
+;
+; CHECK-LABEL: define i32 @notpop32
+; CHECK-NOT: @llvm.ctpop
+; CHECK: ret i32
+define i32 @notpop32(i32 %t0) #0 {
+entry:
+  %and = and i32 %t0, 1431655765
+  %shr = lshr i32 %t0, 1
+  %and1 = and i32 %shr, 1431655765
+  %add = add i32 %and, %and1
+  %and2 = and i32 %add, 858993459
+  %shr3 = lshr i32 %add, 2
+  %and4 = and i32 %shr3, 858993459
+  %add5 = add i32 %and2, %and4
+  %and6 = and i32 %add5, 522133279
+  %shr7 = lshr i32 %add5, 4
+  %and8 = and i32 %shr7, 522133279
+  %add9 = add i32 %and6, %and8
+  %and10 = and i32 %add9, 16711935
+  %shr11 = lshr i32 %add9, 8
+  %and12 = and i32 %shr11, 16711935
+  %add13 = add i32 %and10, %and12
+  %and14 = and i32 %add13, 65535
+  %shr15 = lshr i32 %add13, 16
+  %and16 = and i32 %shr15, 65535
+  %add17 = add i32 %and14, %and16
+  ret i32 %add17
+}
+
+attributes #0 = { nounwind }