diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -218,6 +218,13 @@
                      One.extractBits(NumBits, BitPosition));
   }
 
+  /// Concatenate the bits from \p Lo onto the bottom of *this. The result has
+  /// NewWidth = getBitWidth() + Lo.getBitWidth() bits and is equivalent to:
+  /// (this->zext(NewWidth) << Lo.getBitWidth()) | Lo.zext(NewWidth)
+  KnownBits concat(const KnownBits &Lo) const {
+    return KnownBits(Zero.concat(Lo.Zero), One.concat(Lo.One));
+  }
+
   /// Return KnownBits based on this, but updated given that the underlying
   /// value is known to be greater than or equal to Val.
   KnownBits makeGE(const APInt &Val) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3293,13 +3293,11 @@
     assert((Op.getResNo() == 0 || Op.getResNo() == 1) && "Unknown result");
 
     // Collect lo/hi source values and concatenate.
-    // TODO: Would a KnownBits::concatBits helper be useful?
     unsigned LoBits = Op.getOperand(0).getScalarValueSizeInBits();
     unsigned HiBits = Op.getOperand(1).getScalarValueSizeInBits();
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = Known.anyext(LoBits + HiBits);
-    Known.insertBits(Known2, LoBits);
+    Known = Known2.concat(Known);
 
     // Collect shift amount.
     Known2 = computeKnownBits(Op.getOperand(2), DemandedElts, Depth + 1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2235,11 +2235,7 @@
     if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownHi, TLO, Depth + 1))
       return true;
 
-    Known.Zero = KnownLo.Zero.zext(BitWidth) |
-                 KnownHi.Zero.zext(BitWidth).shl(HalfBitWidth);
-
-    Known.One = KnownLo.One.zext(BitWidth) |
-                KnownHi.One.zext(BitWidth).shl(HalfBitWidth);
+    Known = KnownHi.concat(KnownLo);
     break;
   }
   case ISD::ZERO_EXTEND:
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -508,4 +508,35 @@
   });
 }
 
+TEST(KnownBitsTest, ConcatBits) {
+  unsigned Bits = 4;
+  for (unsigned LoBits = 1; LoBits < Bits; ++LoBits) {
+    unsigned HiBits = Bits - LoBits;
+    ForeachKnownBits(LoBits, [&](const KnownBits &KnownLo) {
+      ForeachKnownBits(HiBits, [&](const KnownBits &KnownHi) {
+        KnownBits KnownAll = KnownHi.concat(KnownLo);
+
+        EXPECT_EQ(KnownLo.countMinPopulation() + KnownHi.countMinPopulation(),
+                  KnownAll.countMinPopulation());
+        EXPECT_EQ(KnownLo.countMaxPopulation() + KnownHi.countMaxPopulation(),
+                  KnownAll.countMaxPopulation());
+
+        KnownBits ExtractLo = KnownAll.extractBits(LoBits, 0);
+        KnownBits ExtractHi = KnownAll.extractBits(HiBits, LoBits);
+
+        EXPECT_EQ(KnownLo.One.getZExtValue(), ExtractLo.One.getZExtValue());
+        EXPECT_EQ(KnownHi.One.getZExtValue(), ExtractHi.One.getZExtValue());
+        EXPECT_EQ(KnownLo.Zero.getZExtValue(), ExtractLo.Zero.getZExtValue());
+        EXPECT_EQ(KnownHi.Zero.getZExtValue(), ExtractHi.Zero.getZExtValue());
+
+        // Must agree with the anyext + insertBits idiom concat replaces.
+        KnownBits Ref = KnownLo.anyext(Bits);
+        Ref.insertBits(KnownHi, LoBits);
+        EXPECT_EQ(Ref.One.getZExtValue(), KnownAll.One.getZExtValue());
+        EXPECT_EQ(Ref.Zero.getZExtValue(), KnownAll.Zero.getZExtValue());
+      });
+    });
+  }
+}
+
 } // end anonymous namespace