diff --git a/llvm/include/llvm/ADT/BitVector.h b/llvm/include/llvm/ADT/BitVector.h
--- a/llvm/include/llvm/ADT/BitVector.h
+++ b/llvm/include/llvm/ADT/BitVector.h
@@ -568,6 +568,30 @@
     return false;
   }
 
+  /// Compute \p Out as the word-wise combination of \p Arg and \p Args.
+  ///
+  /// \p f is invoked once per storage word with the corresponding word of
+  /// \p Arg followed by the corresponding word of each vector in \p Args, and
+  /// its result is stored into the same word of \p Out. \p Out is resized to
+  /// \p Arg's size and bits past that size are cleared afterwards, so \p f may
+  /// freely produce set bits there (e.g. by complementing a word).
+  ///
+  /// All operands must have the same size; this is only checked by an assert.
+  /// \returns a reference to \p Out.
+  template <typename F, typename... ArgTys>
+  static BitVector &apply(F &&f, BitVector &Out, BitVector const &Arg,
+                          ArgTys const &...Args) {
+    assert(llvm::all_of(
+               std::initializer_list<unsigned>{Args.size()...},
+               [&Arg](auto const &BV) { return Arg.size() == BV; }) &&
+           "consistent sizes");
+    Out.resize(Arg.size());
+    for (size_t i = 0, e = Out.NumBitWords(Arg.size()); i != e; ++i)
+      Out.Bits[i] = f(Arg.Bits[i], Args.Bits[i]...);
+    Out.clear_unused_bits();
+    return Out;
+  }
+
   BitVector &operator|=(const BitVector &RHS) {
     if (size() < RHS.size())
       resize(RHS.size());
diff --git a/llvm/lib/CodeGen/CFIInstrInserter.cpp b/llvm/lib/CodeGen/CFIInstrInserter.cpp
--- a/llvm/lib/CodeGen/CFIInstrInserter.cpp
+++ b/llvm/lib/CodeGen/CFIInstrInserter.cpp
@@ -264,9 +264,9 @@
   MBBInfo.OutgoingCFARegister = SetRegister;
 
   // Update outgoing CSR info.
-  MBBInfo.OutgoingCSRSaved = MBBInfo.IncomingCSRSaved;
-  MBBInfo.OutgoingCSRSaved |= CSRSaved;
-  MBBInfo.OutgoingCSRSaved.reset(CSRRestored);
+  BitVector::apply([](auto x, auto y, auto z) { return (x | y) & ~z; },
+                   MBBInfo.OutgoingCSRSaved, MBBInfo.IncomingCSRSaved, CSRSaved,
+                   CSRRestored);
 }
 
 void CFIInstrInserter::updateSuccCFAInfo(MBBCFAInfo &MBBInfo) {
@@ -294,6 +294,7 @@
   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
   bool InsertedCFIInstr = false;
 
+  BitVector SetDifference;
   for (MachineBasicBlock &MBB : MF) {
     // Skip the first MBB in a function
     if (MBB.getNumber() == MF.front().getNumber()) continue;
@@ -345,8 +346,8 @@
       continue;
     }
 
-    BitVector SetDifference = PrevMBBInfo->OutgoingCSRSaved;
-    SetDifference.reset(MBBInfo.IncomingCSRSaved);
+    BitVector::apply([](auto x, auto y) { return x & ~y; }, SetDifference,
+                     PrevMBBInfo->OutgoingCSRSaved, MBBInfo.IncomingCSRSaved);
     for (int Reg : SetDifference.set_bits()) {
       unsigned CFIIndex =
           MF.addFrameInst(MCCFIInstruction::createRestore(nullptr, Reg));
@@ -355,8 +356,8 @@
       InsertedCFIInstr = true;
     }
 
-    SetDifference = MBBInfo.IncomingCSRSaved;
-    SetDifference.reset(PrevMBBInfo->OutgoingCSRSaved);
+    BitVector::apply([](auto x, auto y) { return x & ~y; }, SetDifference,
+                     MBBInfo.IncomingCSRSaved, PrevMBBInfo->OutgoingCSRSaved);
     for (int Reg : SetDifference.set_bits()) {
       auto it = CSRLocMap.find(Reg);
       assert(it != CSRLocMap.end() && "Reg should have an entry in CSRLocMap");
diff --git a/llvm/unittests/ADT/BitVectorTest.cpp b/llvm/unittests/ADT/BitVectorTest.cpp
--- a/llvm/unittests/ADT/BitVectorTest.cpp
+++ b/llvm/unittests/ADT/BitVectorTest.cpp
@@ -1261,4 +1261,43 @@
   }
 }
 
+TEST(BitVectoryTest, Apply) {
+  // Run once with 8-bit vectors (i == 0) and once with 108-bit ones (i == 1).
+  for (int i = 0; i < 2; ++i) {
+    int j = i * 100 + 3;
+
+    const BitVector x =
+        createBitVector<BitVector>(j + 5, {{i, i + 1}, {j - 1, j}});
+    const BitVector y = createBitVector<BitVector>(j + 5, {{i, j}});
+    const BitVector z =
+        createBitVector<BitVector>(j + 5, {{i + 1, i + 2}, {j, j + 1}});
+
+    // Unary case; out0 starts at a different size, so apply must resize it.
+    auto op0 = [](auto x) { return ~x; };
+    BitVector expected0 = x;
+    expected0.flip();
+    BitVector out0(j - 2);
+    BitVector::apply(op0, out0, x);
+    EXPECT_EQ(out0, expected0);
+
+    // Binary case: set difference, cross-checked against reset().
+    auto op1 = [](auto x, auto y) { return x & ~y; };
+    BitVector expected1 = x;
+    expected1.reset(y);
+    BitVector out1;
+    BitVector::apply(op1, out1, x, y);
+    EXPECT_EQ(out1, expected1);
+
+    // Ternary case, cross-checked against the in-place operators.
+    auto op2 = [](auto x, auto y, auto z) { return (x ^ ~y) | z; };
+    BitVector expected2 = y;
+    expected2.flip();
+    expected2 ^= x;
+    expected2 |= z;
+    BitVector out2(j + 5);
+    BitVector::apply(op2, out2, x, y, z);
+    EXPECT_EQ(out2, expected2);
+  }
+}
+
 } // namespace