diff --git a/llvm/include/llvm/IR/CheckpointChanges.h b/llvm/include/llvm/IR/CheckpointChanges.h
--- a/llvm/include/llvm/IR/CheckpointChanges.h
+++ b/llvm/include/llvm/IR/CheckpointChanges.h
@@ -90,6 +90,48 @@
   LLVM_DUMP_METHOD void dump() const override;
 #endif // NDEBUG
 };
+
+/// Records `V->takeName(FromV)`. Both values' names change, so both original
+/// names are saved and restored by revert().
+class TakeName : public ChangeBase {
+  /// The original name of the value that takes the name.
+  std::string OrigName;
+  /// The value whose name gets taken.
+  Value *FromV;
+  /// The original name of FromV. Saved on its own because the name the other
+  /// value ends up with can differ from it (e.g. if it gets uniqued).
+  std::string FromOrigName;
+
+public:
+  TakeName(Value *Val, Value *FromV, CheckpointTracker *CT);
+  void revert() override;
+  void apply() override {}
+  static bool classof(const ChangeBase *Other) {
+    return Other->getID() == ChangeID::TakeNameID;
+  }
+  ~TakeName() {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif // NDEBUG
+};
+
+/// Records the destruction of a value's name by ~Value(). The name is only
+/// destroyed by apply(); revert() does nothing.
+class DestroyName : public ChangeBase {
+public:
+  DestroyName(Value *Val, CheckpointTracker *CT);
+  void revert() override {}
+  void apply() override;
+  static bool classof(const ChangeBase *Other) {
+    return Other->getID() == ChangeID::DestroyNameID;
+  }
+  ~DestroyName() {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif // NDEBUG
+};
 } // namespace llvm
 
 #endif // LLVM_IR_CHECKPOINTCHANGES_H
diff --git a/llvm/include/llvm/IR/CheckpointTracker.h b/llvm/include/llvm/IR/CheckpointTracker.h
--- a/llvm/include/llvm/IR/CheckpointTracker.h
+++ b/llvm/include/llvm/IR/CheckpointTracker.h
@@ -99,6 +99,17 @@
 
   /// To be called when \p V is about to get its name updated.
   void setName(Value *V);
+  /// To be called by `V->takeName(FromV)` before the name is moved. Both
+  /// names change, so the change is recorded if either \p V or \p FromV is
+  /// tracked.
+  void takeName(Value *V, Value *FromV);
+  /// To be called by ~Value() instead of destroyValueName(). Returns true if
+  /// the change got recorded, in which case destroying the name is deferred
+  /// to DestroyName::apply(); otherwise the caller must destroy it itself.
+  bool destroyName(Value *V);
+  /// Returns true if changes to \p V are recorded, i.e., if \p V is in a
+  /// tracked component or is itself a tracked BasicBlock.
+  bool trackingValue(Value *V);
 
   // Main API functions. These are called by `Checkpoint`.
 
diff --git a/llvm/lib/IR/CheckpointChanges.cpp b/llvm/lib/IR/CheckpointChanges.cpp
--- a/llvm/lib/IR/CheckpointChanges.cpp
+++ b/llvm/lib/IR/CheckpointChanges.cpp
@@ -63,3 +63,39 @@
 
 void SetName::dump() const { dump(dbgs()); }
 #endif // NDEBUG
+
+TakeName::TakeName(Value *Val, Value *FromV, CheckpointTracker *CT)
+    : ChangeBase(Val, ChangeID::TakeNameID, CT), OrigName(Val->getName()),
+      FromV(FromV), FromOrigName(FromV->getName()) {}
+
+void TakeName::revert() {
+  // Restore V first: that releases the name it took, so FromV can get its
+  // original name back without it clashing.
+  V->setName(OrigName);
+  FromV->setName(FromOrigName);
+}
+
+#ifndef NDEBUG
+void TakeName::dump(raw_ostream &OS) const {
+  dumpCommon(OS);
+  OS << "TakeName\n";
+}
+
+void TakeName::dump() const { dump(dbgs()); }
+#endif // NDEBUG
+
+DestroyName::DestroyName(Value *Val, CheckpointTracker *CT)
+    : ChangeBase(Val, ChangeID::DestroyNameID, CT) {}
+
+// NOTE(review): this change is recorded from ~Value(), so apply() runs after
+// the destructor has returned; confirm V's memory is still valid by then.
+void DestroyName::apply() { V->destroyValueName(); }
+
+#ifndef NDEBUG
+void DestroyName::dump(raw_ostream &OS) const {
+  dumpCommon(OS);
+  OS << "DestroyName\n";
+}
+
+void DestroyName::dump() const { dump(dbgs()); }
+#endif // NDEBUG
diff --git a/llvm/lib/IR/CheckpointTracker.cpp b/llvm/lib/IR/CheckpointTracker.cpp
--- a/llvm/lib/IR/CheckpointTracker.cpp
+++ b/llvm/lib/IR/CheckpointTracker.cpp
@@ -62,6 +62,30 @@
     Changes.push_back(make_unique<SetName>(V, this));
 }
 
+bool CheckpointTracker::trackingValue(Value *V) {
+  std::optional<ChkpntComponent> Parent = getParentComponent(V);
+  if (!Parent)
+    return false;
+  // A BasicBlock is also tracked if the BB itself is a tracked component.
+  if (auto *BB = dyn_cast<BasicBlock>(V))
+    if (trackingComponent(BB))
+      return true;
+  return trackingComponent(*Parent);
+}
+
+void CheckpointTracker::takeName(Value *V, Value *FromV) {
+  // Both names change, so record the change if either value is tracked.
+  if (trackingValue(V) || trackingValue(FromV))
+    Changes.push_back(make_unique<TakeName>(V, FromV, this));
+}
+
+bool CheckpointTracker::destroyName(Value *V) {
+  if (!trackingValue(V))
+    return false;
+  Changes.push_back(make_unique<DestroyName>(V, this));
+  return true;
+}
+
 void CheckpointTracker::trackComponent(ChkpntComponent Component) {
   assert(!trackingComponent(Component) && "Already tracking component");
   ComponentsTracked.insert(Component);
diff --git a/llvm/lib/IR/Value.cpp b/llvm/lib/IR/Value.cpp
--- a/llvm/lib/IR/Value.cpp
+++ b/llvm/lib/IR/Value.cpp
@@ -104,7 +104,13 @@
 
   // If this value is named, destroy the name.  This should not be in a symtab
   // at this point.
-  destroyValueName();
+  // While a checkpoint is active the tracker may record the destruction and
+  // defer it; if it does not, destroy the name right away to avoid leaking it.
+  auto &Chkpnt = getContext().getChkpntTracker();
+  bool NameDeferred =
+      LLVM_UNLIKELY(Chkpnt.isActive()) && Chkpnt.destroyName(this);
+  if (!NameDeferred)
+    destroyValueName();
 }
 
 void Value::deleteValue() {
@@ -384,6 +390,10 @@
 
 void Value::takeName(Value *V) {
   assert(V != this && "Illegal call to this->takeName(this)!");
+  CheckpointTracker &ChkpntTracker = getContext().getChkpntTracker();
+  if (LLVM_UNLIKELY(ChkpntTracker.isActive()))
+    ChkpntTracker.takeName(this, V);
+
   ValueSymbolTable *ST = nullptr;
   // If this value has a name, drop it.
   if (hasName()) {
diff --git a/llvm/unittests/IR/CheckpointTest.cpp b/llvm/unittests/IR/CheckpointTest.cpp
--- a/llvm/unittests/IR/CheckpointTest.cpp
+++ b/llvm/unittests/IR/CheckpointTest.cpp
@@ -194,6 +194,65 @@
   EXPECT_EQ(Instr->getName(), "new");
 }
 
+TEST(CheckpointTest, BB_TakeNameInstr) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"(
+define void @foo(i32 %a, i32 %b) {
+bb0:
+  %instr1 = add i32 %a, %b
+  %instr2 = add i32 %a, %b
+  ret void
+}
+)");
+  Function *F = &*M->begin();
+  BasicBlock *BB0 = &*F->begin();
+  auto It = BB0->begin();
+  auto *Instr1 = &*It++;
+  auto *Instr2 = &*It++;
+
+  Checkpoint Chkpnt = C.getCheckpoint();
+  Chkpnt.track(BB0);
+  Instr1->takeName(Instr2);
+  EXPECT_EQ(Instr1->getName(), "instr2");
+  EXPECT_FALSE(Instr2->hasName());
+  EXPECT_FALSE(Chkpnt.empty());
+  Chkpnt.restore();
+
+  EXPECT_EQ(Instr1->getName(), "instr1");
+  EXPECT_EQ(Instr2->getName(), "instr2");
+}
+
+TEST(CheckpointTest, BB_TakeNameInstrFromTrackedBB) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"(
+define void @foo(i32 %a, i32 %b) {
+bb0:
+  %instr1 = add i32 %a, %b
+  br label %bb1
+
+bb1:
+  %instr2 = add i32 %a, %b
+  ret void
+}
+)");
+  Function *F = &*M->begin();
+  auto FIt = F->begin();
+  BasicBlock *BB0 = &*FIt++;
+  BasicBlock *BB1 = &*FIt++;
+  Instruction *Instr1 = &*BB0->begin();
+  Instruction *Instr2 = &*BB1->begin();
+
+  // Only the BB of the value whose name gets taken is tracked.
+  Checkpoint Chkpnt = C.getCheckpoint();
+  Chkpnt.track(BB1);
+  Instr1->takeName(Instr2);
+  EXPECT_FALSE(Chkpnt.empty());
+  Chkpnt.restore();
+
+  EXPECT_EQ(Instr1->getName(), "instr1");
+  EXPECT_EQ(Instr2->getName(), "instr2");
+}
+
 TEST(CheckpointTest, BB_SetNameBB) {
   LLVMContext C;
   std::unique_ptr<Module> M = parseIR(C, R"(