# Patch summary (commentary placed ahead of the first "diff --git" header).
#
# Adds a static helper KnownBits::concatBits(Lo, Hi) and switches two
# SelectionDAG call sites from hand-written concatenation to the helper.
#
#   llvm/include/llvm/Support/KnownBits.h
#     concatBits creates a KnownBits of width
#     Lo.getBitWidth() + Hi.getBitWidth(), then inserts Lo at bit 0 and Hi at
#     bit Lo.getBitWidth().
#   llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
#     Replaces the anyext(LoBits + HiBits) + insertBits(Known2, LoBits) pair
#     and removes the TODO that asked for this helper. The LoBits/HiBits
#     declarations are untouched by this hunk.
#   llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
#     Replaces the manual zext/shl/or construction of Known.Zero and Known.One.
#   llvm/unittests/Support/KnownBitsTest.cpp
#     New ConcatBits test: for every split of 4 bits into LoBits + HiBits with
#     LoBits in 1..3, checks that min/max population counts of the halves add
#     up to those of the result and that extractBits returns the original
#     One/Zero masks of each half.
#
# NOTE(review): this file arrived with its line breaks and indentation
# collapsed; the leading whitespace of context and removed lines below was
# reconstructed on the assumption of LLVM's usual 2-space style. Confirm
# against the target tree (or apply whitespace-insensitively) before use.
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -304,6 +304,16 @@
     return KnownBits(~C, C);
   }
 
+  /// Concatenate Lo (low bits) and Hi (high bits) into one wider KnownBits.
+  static KnownBits concatBits(const KnownBits &Lo, const KnownBits &Hi) {
+    unsigned LoBits = Lo.getBitWidth();
+    unsigned HiBits = Hi.getBitWidth();
+    KnownBits Concat(LoBits + HiBits);
+    Concat.insertBits(Lo, 0);      // Lo fills bits [0, LoBits).
+    Concat.insertBits(Hi, LoBits); // Hi fills bits [LoBits, LoBits + HiBits).
+    return Concat;
+  }
+
   /// Compute known bits common to LHS and RHS.
   static KnownBits commonBits(const KnownBits &LHS, const KnownBits &RHS) {
     return KnownBits(LHS.Zero & RHS.Zero, LHS.One & RHS.One);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3335,13 +3335,11 @@
     assert((Op.getResNo() == 0 || Op.getResNo() == 1) && "Unknown result");
 
     // Collect lo/hi source values and concatenate.
-    // TODO: Would a KnownBits::concatBits helper be useful?
     unsigned LoBits = Op.getOperand(0).getScalarValueSizeInBits();
     unsigned HiBits = Op.getOperand(1).getScalarValueSizeInBits();
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = Known.anyext(LoBits + HiBits);
-    Known.insertBits(Known2, LoBits);
+    Known = KnownBits::concatBits(Known, Known2);
 
     // Collect shift amount.
     Known2 = computeKnownBits(Op.getOperand(2), DemandedElts, Depth + 1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2225,11 +2225,7 @@
     if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownHi, TLO, Depth + 1))
       return true;
 
-    Known.Zero = KnownLo.Zero.zext(BitWidth) |
-                 KnownHi.Zero.zext(BitWidth).shl(HalfBitWidth);
-
-    Known.One = KnownLo.One.zext(BitWidth) |
-                KnownHi.One.zext(BitWidth).shl(HalfBitWidth);
+    Known = KnownBits::concatBits(KnownLo, KnownHi);
     break;
   }
   case ISD::ZERO_EXTEND:
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -508,4 +508,29 @@
   });
 }
 
+TEST(KnownBitsTest, ConcatBits) {
+  unsigned Bits = 4; // Total width of every concatenated value.
+  for (unsigned LoBits = 1; LoBits < Bits; ++LoBits) {
+    unsigned HiBits = Bits - LoBits;
+    ForeachKnownBits(LoBits, [&](const KnownBits &KnownLo) {
+      ForeachKnownBits(HiBits, [&](const KnownBits &KnownHi) {
+        KnownBits KnownAll = KnownBits::concatBits(KnownLo, KnownHi);
+
+        EXPECT_EQ(KnownLo.countMinPopulation() + KnownHi.countMinPopulation(),
+                  KnownAll.countMinPopulation());
+        EXPECT_EQ(KnownLo.countMaxPopulation() + KnownHi.countMaxPopulation(),
+                  KnownAll.countMaxPopulation());
+
+        KnownBits ExtractLo = KnownAll.extractBits(LoBits, 0);
+        KnownBits ExtractHi = KnownAll.extractBits(HiBits, LoBits);
+
+        EXPECT_EQ(KnownLo.One.getZExtValue(), ExtractLo.One.getZExtValue());
+        EXPECT_EQ(KnownHi.One.getZExtValue(), ExtractHi.One.getZExtValue());
+        EXPECT_EQ(KnownLo.Zero.getZExtValue(), ExtractLo.Zero.getZExtValue());
+        EXPECT_EQ(KnownHi.Zero.getZExtValue(), ExtractHi.Zero.getZExtValue());
+      });
+    });
+  }
+}
+
 } // end anonymous namespace