diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2223,12 +2223,23 @@
-/// Splat/Merge neighboring bits to widen/narrow the bitmask represented
-/// by \param A to \param NewBitWidth bits.
-///
-/// e.g. ScaleBitMask(0b0101, 8) -> 0b00110011
-/// e.g. ScaleBitMask(0b00011011, 4) -> 0b0111
-/// A.getBitwidth() or NewBitWidth must be a whole multiples of the other.
-///
-/// TODO: Do we need a mode where all bits must be set when merging down?
-APInt ScaleBitMask(const APInt &A, unsigned NewBitWidth);
+/// How ScaleBitMask merges neighboring bits when narrowing a bitmask.
+enum class BitMergingApproach {
+  /// The destination bit is set if *any* of its source bits is set.
+  Greedy,
+  /// The destination bit is set only if *all* of its source bits are set.
+  Lossy,
+};
+
+/// Splat/Merge neighboring bits to widen/narrow the bitmask represented
+/// by \p A to \p NewBitWidth bits.
+///
+/// When widening, every bit is splatted and \p Approach has no effect.
+/// When narrowing, \p Approach decides how each group of bits is merged.
+///
+/// e.g. ScaleBitMask(0b0101, 8, BitMergingApproach::Greedy) -> 0b00110011
+/// e.g. ScaleBitMask(0b00011011, 4, BitMergingApproach::Greedy) -> 0b0111
+/// e.g. ScaleBitMask(0b00011011, 4, BitMergingApproach::Lossy) -> 0b0001
+/// A.getBitWidth() or NewBitWidth must be a whole multiple of the other.
+APInt ScaleBitMask(const APInt &A, unsigned NewBitWidth,
+                   BitMergingApproach Approach);
 } // namespace APIntOps
 
 // See friend declaration above. This additional declaration is required in
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2991,7 +2991,8 @@
       // sub sections we actually care about.
       unsigned SubScale = SubBitWidth / BitWidth;
       APInt SubDemandedElts =
-          APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale);
+          APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale,
+                                 APIntOps::BitMergingApproach::Greedy);
       Known2 = computeKnownBits(N0, SubDemandedElts, Depth + 1);
 
       Known.Zero.setAllBits(); Known.One.setAllBits();
@@ -3799,8 +3800,8 @@
       assert(VT.isVector() && "Expected bitcast to vector");
 
       unsigned Scale = SrcBits / VTBits;
-      APInt SrcDemandedElts =
-          APIntOps::ScaleBitMask(DemandedElts, NumElts / Scale);
+      APInt SrcDemandedElts = APIntOps::ScaleBitMask(
+          DemandedElts, NumElts / Scale, APIntOps::BitMergingApproach::Greedy);
 
       // Fast case - sign splat can be simply split across the small elements.
       Tmp = ComputeNumSignBits(N0, SrcDemandedElts, Depth + 1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2475,7 +2475,8 @@
     // must demand a source element if any DemandedElt maps to it.
     if ((NumElts % NumSrcElts) == 0) {
       unsigned Scale = NumElts / NumSrcElts;
-      SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      SrcDemandedElts = APIntOps::ScaleBitMask(
+          DemandedElts, NumSrcElts, APIntOps::BitMergingApproach::Greedy);
       if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
                                      TLO, Depth + 1))
         return true;
@@ -2515,7 +2516,8 @@
     // of this vector.
     if ((NumSrcElts % NumElts) == 0) {
       unsigned Scale = NumSrcElts / NumElts;
-      SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      SrcDemandedElts = APIntOps::ScaleBitMask(
+          DemandedElts, NumSrcElts, APIntOps::BitMergingApproach::Greedy);
       if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
                                      TLO, Depth + 1))
         return true;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -2959,7 +2959,8 @@
   return A.getBitWidth() - ((A ^ B).countLeadingZeros() + 1);
 }
 
-APInt llvm::APIntOps::ScaleBitMask(const APInt &A, unsigned NewBitWidth) {
+APInt llvm::APIntOps::ScaleBitMask(const APInt &A, unsigned NewBitWidth,
+                                   BitMergingApproach Approach) {
   unsigned OldBitWidth = A.getBitWidth();
   assert((((OldBitWidth % NewBitWidth) == 0) ||
           ((NewBitWidth % OldBitWidth) == 0)) &&
@@ -2985,9 +2986,13 @@
   } else {
-    // Merge bits - if any old bit is set, then set scale equivalent new bit.
+    // Merge bits - set each new bit if any (Greedy) or all (Lossy) of the
+    // old bits it covers are set.
     unsigned Scale = OldBitWidth / NewBitWidth;
-    for (unsigned i = 0; i != NewBitWidth; ++i)
-      if (!A.extractBits(Scale, i * Scale).isNullValue())
+    for (unsigned i = 0; i != NewBitWidth; ++i) {
+      APInt Part = A.extractBits(Scale, i * Scale);
+      if (Approach == BitMergingApproach::Greedy ? !Part.isNullValue()
+                                                 : Part.isAllOnes())
         NewA.setBit(i);
+    }
   }
 
   return NewA;
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -39414,7 +39414,8 @@
     // Aggressively peek through ops to get at the demanded elts.
     if (!DemandedElts.isAllOnesValue()) {
       unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
-      APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      APInt DemandedSrcElts = APIntOps::ScaleBitMask(
+          DemandedElts, NumSrcElts, APIntOps::BitMergingApproach::Greedy);
       SDValue NewLHS = SimplifyMultipleUseDemandedVectorElts(
           LHS, DemandedSrcElts, TLO.DAG, Depth + 1);
       SDValue NewRHS = SimplifyMultipleUseDemandedVectorElts(
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2994,24 +2994,82 @@
   EXPECT_EQ(0U, MZW1.getBitWidth());
 }
 
-TEST(APIntTest, ScaleBitMask) {
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x00), 8), APInt(8, 0x00));
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x01), 8), APInt(8, 0x0F));
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x02), 8), APInt(8, 0xF0));
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x03), 8), APInt(8, 0xFF));
+TEST(APIntTest, ScaleBitMaskGreedily) {
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x00), 8,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(8, 0x00));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x01), 8,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(8, 0x0F));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x02), 8,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(8, 0xF0));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x03), 8,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(8, 0xFF));
 
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 4), APInt(4, 0x00));
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xFF), 4), APInt(4, 0x0F));
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xE4), 4), APInt(4, 0x0E));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 4,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(4, 0x00));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xFF), 4,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(4, 0x0F));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xE4), 4,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(4, 0x0E));
 
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 8), APInt(8, 0x00));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 8,
+                                   APIntOps::BitMergingApproach::Greedy),
+            APInt(8, 0x00));
 
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getNullValue(1024), 4096),
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getNullValue(1024), 4096,
+                                   APIntOps::BitMergingApproach::Greedy),
             APInt::getNullValue(4096));
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getAllOnes(4096), 256),
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getAllOnes(4096), 256,
+                                   APIntOps::BitMergingApproach::Greedy),
             APInt::getAllOnes(256));
-  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getOneBitSet(4096, 32), 256),
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getOneBitSet(4096, 32), 256,
+                                   APIntOps::BitMergingApproach::Greedy),
             APInt::getOneBitSet(256, 2));
 }
 
+TEST(APIntTest, ScaleBitMaskLossily) {
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x00), 8,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(8, 0x00));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x01), 8,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(8, 0x0F));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x02), 8,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(8, 0xF0));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x03), 8,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(8, 0xFF));
+
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 4,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(4, 0x00));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xFF), 4,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(4, 0x0F));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xE4), 4,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(4, 0x08));
+
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 8,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt(8, 0x00));
+
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getNullValue(1024), 4096,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt::getNullValue(4096));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getAllOnes(4096), 256,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt::getAllOnes(256));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getOneBitSet(4096, 32), 256,
+                                   APIntOps::BitMergingApproach::Lossy),
+            APInt::getNullValue(256));
+}
+
 } // end anonymous namespace