Add isShiftedMask overloads that also report the mask's position and length.

APInt::isShiftedMask(MaskIdx, MaskLen) and isShiftedMask_32/_64(Value, MaskIdx,
MaskLen) return true for a non-empty contiguous run of ones and, only in that
case, write the index of the lowest set bit and the length of the run. The
callers in DAGCombiner, AMDGPU, Mips and the X86 InstCombine hooks are switched
to the new overloads, and the file-local Mips helper is removed.

The unit tests compare against unsigned expectations so that EXPECT_EQ does not
instantiate a signed/unsigned comparison.

NOTE(review): leading whitespace, blank context lines and the template
arguments on context lines (dyn_cast<...>, PopulationCounter<...>, CTLog2<...>)
were reconstructed from a whitespace-collapsed copy of this patch; confirm them
against the target tree before applying.

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -496,6 +496,23 @@
     return (Ones + LeadZ + countTrailingZeros()) == BitWidth;
   }
 
+  /// Return true if this APInt value contains a non-empty sequence of ones with
+  /// the remainder zero. If true, \p MaskIdx will specify the index of the
+  /// lowest set bit and \p MaskLen is updated to specify the length of the
+  /// mask, else neither are updated.
+  bool isShiftedMask(unsigned &MaskIdx, unsigned &MaskLen) const {
+    if (isSingleWord())
+      return isShiftedMask_64(U.VAL, MaskIdx, MaskLen);
+    unsigned Ones = countPopulationSlowCase();
+    unsigned LeadZ = countLeadingZerosSlowCase();
+    unsigned TrailZ = countTrailingZerosSlowCase();
+    if ((Ones + LeadZ + TrailZ) != BitWidth)
+      return false;
+    MaskLen = Ones;
+    MaskIdx = TrailZ;
+    return true;
+  }
+
   /// Compute an APInt containing numBits highbits from this APInt.
   ///
   /// Get an APInt with the same BitWidth as this APInt, just zero mask the low
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -571,6 +571,33 @@
   return detail::PopulationCounter<T, sizeof(T)>::count(Value);
 }
 
+/// Return true if the argument contains a non-empty sequence of ones with the
+/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true.
+/// If true, \p MaskIdx will specify the index of the lowest set bit and \p
+/// MaskLen is updated to specify the length of the mask, else neither are
+/// updated.
+inline bool isShiftedMask_32(uint32_t Value, unsigned &MaskIdx,
+                             unsigned &MaskLen) {
+  if (!isShiftedMask_32(Value))
+    return false;
+  MaskIdx = countTrailingZeros(Value);
+  MaskLen = countPopulation(Value);
+  return true;
+}
+
+/// Return true if the argument contains a non-empty sequence of ones with the
+/// remainder zero (64 bit version.) If true, \p MaskIdx will specify the index
+/// of the lowest set bit and \p MaskLen is updated to specify the length of the
+/// mask, else neither are updated.
+inline bool isShiftedMask_64(uint64_t Value, unsigned &MaskIdx,
+                             unsigned &MaskLen) {
+  if (!isShiftedMask_64(Value))
+    return false;
+  MaskIdx = countTrailingZeros(Value);
+  MaskLen = countPopulation(Value);
+  return true;
+}
+
 /// Compile time Log2.
 /// Valid only for positive powers of two.
 template <size_t kValue> constexpr inline size_t CTLog2() {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12254,10 +12254,7 @@
     unsigned ActiveBits = 0;
     if (Mask.isMask()) {
       ActiveBits = Mask.countTrailingOnes();
-    } else if (Mask.isShiftedMask()) {
-      ShAmt = Mask.countTrailingZeros();
-      APInt ShiftedMask = Mask.lshr(ShAmt);
-      ActiveBits = ShiftedMask.countTrailingOnes();
+    } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
       HasShiftedOffset = true;
     } else
       return SDValue();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3281,8 +3281,9 @@
   // this improves the ability to match BFE patterns in isel.
   if (LHS.getOpcode() == ISD::AND) {
     if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
-      if (Mask->getAPIntValue().isShiftedMask() &&
-          Mask->getAPIntValue().countTrailingZeros() == ShiftAmt) {
+      unsigned MaskIdx, MaskLen;
+      if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+          MaskIdx == ShiftAmt) {
         return DAG.getNode(
             ISD::AND, SL, VT,
             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),
diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -94,18 +94,6 @@
     Mips::D16_64, Mips::D17_64, Mips::D18_64, Mips::D19_64
 };
 
-// If I is a shifted mask, set the size (Size) and the first bit of the
-// mask (Pos), and return true.
-// For example, if I is 0x003ff800, (Pos, Size) = (11, 11).
-static bool isShiftedMask(uint64_t I, uint64_t &Pos, uint64_t &Size) {
-  if (!isShiftedMask_64(I))
-    return false;
-
-  Size = countPopulation(I);
-  Pos = countTrailingZeros(I);
-  return true;
-}
-
 // The MIPS MSA ABI passes vector arguments in the integer register set.
 // The number of integer registers used is dependant on the ABI used.
 MVT MipsTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
@@ -794,14 +782,15 @@
   EVT ValTy = N->getValueType(0);
   SDLoc DL(N);
 
-  uint64_t Pos = 0, SMPos, SMSize;
+  uint64_t Pos = 0;
+  unsigned SMPos, SMSize;
   ConstantSDNode *CN;
   SDValue NewOperand;
   unsigned Opc;
 
   // Op's second operand must be a shifted mask.
   if (!(CN = dyn_cast<ConstantSDNode>(Mask)) ||
-      !isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
+      !isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
     return SDValue();
 
   if (FirstOperandOpc == ISD::SRA || FirstOperandOpc == ISD::SRL) {
@@ -875,7 +864,7 @@
     return SDValue();
 
   SDValue And0 = N->getOperand(0), And1 = N->getOperand(1);
-  uint64_t SMPos0, SMSize0, SMPos1, SMSize1;
+  unsigned SMPos0, SMSize0, SMPos1, SMSize1;
   ConstantSDNode *CN, *CN1;
 
   // See if Op's first operand matches (and $src1 , mask0).
@@ -883,7 +872,7 @@
     return SDValue();
 
   if (!(CN = dyn_cast<ConstantSDNode>(And0.getOperand(1))) ||
-      !isShiftedMask(~CN->getSExtValue(), SMPos0, SMSize0))
+      !isShiftedMask_64(~CN->getSExtValue(), SMPos0, SMSize0))
     return SDValue();
 
   // See if Op's second operand matches (and (shl $src, pos), mask1).
@@ -891,7 +880,7 @@
       And1.getOperand(0).getOpcode() == ISD::SHL) {
 
     if (!(CN = dyn_cast<ConstantSDNode>(And1.getOperand(1))) ||
-        !isShiftedMask(CN->getZExtValue(), SMPos1, SMSize1))
+        !isShiftedMask_64(CN->getZExtValue(), SMPos1, SMSize1))
       return SDValue();
 
     // The shift masks must have the same position and size.
@@ -1118,7 +1107,8 @@
   EVT ValTy = N->getValueType(0);
   SDLoc DL(N);
 
-  uint64_t Pos = 0, SMPos, SMSize;
+  uint64_t Pos = 0;
+  unsigned SMPos, SMSize;
   ConstantSDNode *CN;
   SDValue NewOperand;
 
@@ -1136,7 +1126,7 @@
 
   // AND's second operand must be a shifted mask.
   if (!(CN = dyn_cast<ConstantSDNode>(FirstOperand.getOperand(1))) ||
-      !isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
+      !isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
     return SDValue();
 
   // Return if the shifted mask does not start at bit 0 or the sum of its size
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -996,20 +996,18 @@
         return IC.replaceInstUsesWith(II, II.getArgOperand(0));
       }
 
-      if (MaskC->getValue().isShiftedMask()) {
+      unsigned MaskIdx, MaskLen;
+      if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
         // any single contingous sequence of 1s anywhere in the mask simply
         // describes a subset of the input bits shifted to the appropriate
         // position. Replace with the straight forward IR.
-        unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
         Value *Input = II.getArgOperand(0);
         Value *Masked = IC.Builder.CreateAnd(Input, II.getArgOperand(1));
-        Value *Shifted = IC.Builder.CreateLShr(Masked,
-                                               ConstantInt::get(II.getType(),
-                                                                ShiftAmount));
+        Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
+        Value *Shifted = IC.Builder.CreateLShr(Masked, ShiftAmt);
         return IC.replaceInstUsesWith(II, Shifted);
       }
 
-
       if (auto *SrcC = dyn_cast<ConstantInt>(II.getArgOperand(0))) {
         uint64_t Src = SrcC->getZExtValue();
         uint64_t Mask = MaskC->getZExtValue();
@@ -1041,15 +1039,15 @@
       if (MaskC->isAllOnesValue()) {
         return IC.replaceInstUsesWith(II, II.getArgOperand(0));
       }
-      if (MaskC->getValue().isShiftedMask()) {
+
+      unsigned MaskIdx, MaskLen;
+      if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
         // any single contingous sequence of 1s anywhere in the mask simply
         // describes a subset of the input bits shifted to the appropriate
         // position. Replace with the straight forward IR.
-        unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
         Value *Input = II.getArgOperand(0);
-        Value *Shifted = IC.Builder.CreateShl(Input,
-                                              ConstantInt::get(II.getType(),
-                                                               ShiftAmount));
+        Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
+        Value *Shifted = IC.Builder.CreateShl(Input, ShiftAmt);
         Value *Masked = IC.Builder.CreateAnd(Shifted, II.getArgOperand(1));
         return IC.replaceInstUsesWith(II, Masked);
       }
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -1746,21 +1746,43 @@
   EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask());
   EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask());
 
+  unsigned MaskIdx, MaskLen;
+  EXPECT_FALSE(APInt(32, 0x01010101).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_TRUE(APInt(32, 0xf0000000).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_EQ(28u, MaskIdx);
+  EXPECT_EQ(4u, MaskLen);
+  EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_EQ(16u, MaskIdx);
+  EXPECT_EQ(16u, MaskLen);
+  EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_EQ(1u, MaskIdx);
+  EXPECT_EQ(8u, MaskLen);
+
   for (int N : { 1, 2, 3, 4, 7, 8, 16, 32, 64, 127, 128, 129, 256 }) {
     EXPECT_FALSE(APInt(N, 0).isShiftedMask());
+    EXPECT_FALSE(APInt(N, 0).isShiftedMask(MaskIdx, MaskLen));
 
     APInt One(N, 1);
     for (int I = 1; I < N; ++I) {
       APInt MaskVal = One.shl(I) - 1;
       EXPECT_TRUE(MaskVal.isShiftedMask());
+      EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+      EXPECT_EQ(0u, MaskIdx);
+      EXPECT_EQ(static_cast<unsigned>(I), MaskLen);
     }
     for (int I = 1; I < N - 1; ++I) {
       APInt MaskVal = One.shl(I);
       EXPECT_TRUE(MaskVal.isShiftedMask());
+      EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+      EXPECT_EQ(static_cast<unsigned>(I), MaskIdx);
+      EXPECT_EQ(1u, MaskLen);
     }
     for (int I = 1; I < N; ++I) {
       APInt MaskVal = APInt::getHighBitsSet(N, I);
       EXPECT_TRUE(MaskVal.isShiftedMask());
+      EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+      EXPECT_EQ(static_cast<unsigned>(N - I), MaskIdx);
+      EXPECT_EQ(static_cast<unsigned>(I), MaskLen);
     }
   }
 }
diff --git a/llvm/unittests/Support/MathExtrasTest.cpp b/llvm/unittests/Support/MathExtrasTest.cpp
--- a/llvm/unittests/Support/MathExtrasTest.cpp
+++ b/llvm/unittests/Support/MathExtrasTest.cpp
@@ -180,6 +180,18 @@
   EXPECT_TRUE(isShiftedMask_32(0xf0000000));
   EXPECT_TRUE(isShiftedMask_32(0xffff0000));
   EXPECT_TRUE(isShiftedMask_32(0xff << 1));
+
+  unsigned MaskIdx, MaskLen;
+  EXPECT_FALSE(isShiftedMask_32(0x01010101, MaskIdx, MaskLen));
+  EXPECT_TRUE(isShiftedMask_32(0xf0000000, MaskIdx, MaskLen));
+  EXPECT_EQ(28u, MaskIdx);
+  EXPECT_EQ(4u, MaskLen);
+  EXPECT_TRUE(isShiftedMask_32(0xffff0000, MaskIdx, MaskLen));
+  EXPECT_EQ(16u, MaskIdx);
+  EXPECT_EQ(16u, MaskLen);
+  EXPECT_TRUE(isShiftedMask_32(0xff << 1, MaskIdx, MaskLen));
+  EXPECT_EQ(1u, MaskIdx);
+  EXPECT_EQ(8u, MaskLen);
 }
 
 TEST(MathExtras, isShiftedMask_64) {
@@ -187,6 +199,18 @@
   EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull));
   EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull));
   EXPECT_TRUE(isShiftedMask_64(0xffull << 55));
+
+  unsigned MaskIdx, MaskLen;
+  EXPECT_FALSE(isShiftedMask_64(0x0101010101010101ull, MaskIdx, MaskLen));
+  EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull, MaskIdx, MaskLen));
+  EXPECT_EQ(60u, MaskIdx);
+  EXPECT_EQ(4u, MaskLen);
+  EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull, MaskIdx, MaskLen));
+  EXPECT_EQ(48u, MaskIdx);
+  EXPECT_EQ(16u, MaskLen);
+  EXPECT_TRUE(isShiftedMask_64(0xffull << 55, MaskIdx, MaskLen));
+  EXPECT_EQ(55u, MaskIdx);
+  EXPECT_EQ(8u, MaskLen);
 }
 
 TEST(MathExtras, isPowerOf2_32) {