diff --git a/llvm/include/llvm/IR/Checkpoint.h b/llvm/include/llvm/IR/Checkpoint.h
--- a/llvm/include/llvm/IR/Checkpoint.h
+++ b/llvm/include/llvm/IR/Checkpoint.h
@@ -59,13 +59,21 @@
 class Checkpoint {
   CheckpointTracker &ChkpntTracker;
+  /// Forwarded to CheckpointTracker::trackComponent() to request verification.
+  bool RunVerifier;
 
   /// No copies allowed because going out of scope checks if we restored or
   /// accepted the changes.
   Checkpoint(Checkpoint &) = delete;
   void operator=(const Checkpoint &) = delete;
 
 public:
-  Checkpoint(CheckpointTracker &ChkpntTracker) : ChkpntTracker(ChkpntTracker) {}
+  /// If \p RunVerifier is true we run expensive checks that compare the state
+  /// of each tracked component when track() is called against its state after
+  /// restore(). These check that checkpointing works correctly. The checks only
+  /// run in debug builds, but if the tracked components are large they can be
+  /// expensive.
+  Checkpoint(CheckpointTracker &ChkpntTracker, bool RunVerifier)
+      : ChkpntTracker(ChkpntTracker), RunVerifier(RunVerifier) {}
   ~Checkpoint();
 
   /// \p MaxNumOfTrackedChanges is used for debugging to help diagnose cases
diff --git a/llvm/include/llvm/IR/CheckpointTracker.h b/llvm/include/llvm/IR/CheckpointTracker.h
--- a/llvm/include/llvm/IR/CheckpointTracker.h
+++ b/llvm/include/llvm/IR/CheckpointTracker.h
@@ -46,6 +46,58 @@
 namespace llvm {
 
 class BasicBlock;
+
+#ifndef NDEBUG
+/// Helper abstract class that saves the textual representation of the IR when
+/// `save()` is called and compares against it when `expectNoDiff()` is called.
+/// Each component should implement its own custom checker class that inherits
+/// from this.
+class IRChecker {
+protected:
+  /// Holds the initial IR dump.
+  std::string OriginalIR;
+  /// Helper string that helps skip the list of predecessors. This is to avoid
+  /// false positives due to differences in the order of users.
+  std::string PredsRegexStr = "; preds = .*\n";
+  /// Holds the std::regex object created from `PredsRegexStr`.
+  std::regex PredsRegex;
+  /// Prints a simple line diff between \p OrigIR and \p CurrIR.
+  static void showDiff(const std::string &OrigIR, const std::string &CurrIR);
+
+public:
+  IRChecker() = default;
+  virtual ~IRChecker() = default;
+  IRChecker(const IRChecker &) = delete;
+  IRChecker(IRChecker &&) = delete;
+  /// This saves the IR state into `OriginalIR`.
+  virtual void save() = 0;
+  /// \Returns the dump of the original IR.
+  const std::string &origIR() const { return OriginalIR; }
+  /// \Returns the dump of the current IR.
+  virtual std::string currIR() const = 0;
+  /// Crashes if there is a difference between the original and current IR.
+  void expectNoDiff() const;
+};
+
+/// Checker for a single BasicBlock.
+class BBChecker : public IRChecker {
+  BasicBlock *BB = nullptr;
+  bool SkipPreds = true;
+
+  std::string dumpBB() const;
+
+public:
+  /// If \p SkipPreds is true the predecessor lists ("; preds = ...") are
+  /// blanked out in the dumps. This helps avoid false positives since
+  /// Checkpoint cannot currently preserve the order of users.
+  explicit BBChecker(BasicBlock *BB, bool SkipPreds = true);
+  /// We dump the IR of the BasicBlock into `OriginalIR`.
+  void save() override;
+  /// \Returns an IR dump of the current state of the BasicBlock.
+  std::string currIR() const override;
+};
+#endif // NDEBUG
+
 class ChangeBase;
 class Instruction;
 class CheckpointTracker;
@@ -71,6 +123,12 @@
   /// The set of components currently being tracked.
   DenseSet<ChkpntComponent> ComponentsTracked;
 
+#ifndef NDEBUG
+  /// Per-component IR checkers. An entry is created by trackComponent() only
+  /// when verification is requested and is checked by restoreComponents().
+  DenseMap<ChkpntComponent, std::unique_ptr<IRChecker>> IRCheckers;
+#endif // NDEBUG
+
   /// This is true while checkpointing is active.
   bool Active = false;
   friend class CheckpointGuard; // Needs access to `Active`.
@@ -107,7 +165,9 @@
 
   // Main API functions. These are called by `Checkpoint`.
   /// Start tracking IR changes for \p Component from this point on.
-  void trackComponent(ChkpntComponent Component);
+  /// If \p RunVerifier is true, the component's state after restore is
+  /// compared against its state when tracking started (debug builds only).
+  void trackComponent(ChkpntComponent Component, bool RunVerifier);
   /// Override the maximum number of tracked changes.
   void setMaxNumOfTrackedChanges(uint32_t MaxNumOfTrackedChanges);
   /// Accept all changes.
diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h
--- a/llvm/include/llvm/IR/LLVMContext.h
+++ b/llvm/include/llvm/IR/LLVMContext.h
@@ -325,7 +325,7 @@
 
   /// \Returns the checkpoint handle that allows us to save/restore the state of
   /// IR components being tracked.
-  Checkpoint getCheckpoint();
+  Checkpoint getCheckpoint(bool RunVerifier = true);
 
 private:
   // Module needs access to the add/removeModule methods.
diff --git a/llvm/lib/IR/Checkpoint.cpp b/llvm/lib/IR/Checkpoint.cpp
--- a/llvm/lib/IR/Checkpoint.cpp
+++ b/llvm/lib/IR/Checkpoint.cpp
@@ -21,7 +21,9 @@
   ChkpntTracker.setMaxNumOfTrackedChanges(MaxNumOfTrackedChanges);
 }
 
-void Checkpoint::track(BasicBlock *BB) { ChkpntTracker.trackComponent(BB); }
+void Checkpoint::track(BasicBlock *BB) {
+  ChkpntTracker.trackComponent(BB, RunVerifier);
+}
 
 void Checkpoint::restore() { ChkpntTracker.restoreComponents(); }
 
diff --git a/llvm/lib/IR/CheckpointTracker.cpp b/llvm/lib/IR/CheckpointTracker.cpp
--- a/llvm/lib/IR/CheckpointTracker.cpp
+++ b/llvm/lib/IR/CheckpointTracker.cpp
@@ -22,6 +22,76 @@
 using namespace llvm;
 using namespace std;
 
+#ifndef NDEBUG
+void IRChecker::showDiff(const std::string &OrigIR, const std::string &CurrIR) {
+  // Show the first line that differs, preceded by up to MaxContext lines.
+  std::stringstream OrigSS(OrigIR);
+  std::stringstream CurrSS(CurrIR);
+  std::string OrigLine;
+  std::string CurrLine;
+  SmallVector<std::string> Context;
+  static constexpr const uint32_t MaxContext = 3;
+  while (OrigSS.good() && CurrSS.good()) {
+    std::getline(OrigSS, OrigLine);
+    std::getline(CurrSS, CurrLine);
+    if (CurrLine != OrigLine) {
+      // Print context.
+      for (const std::string &ContextLine : Context)
+        dbgs() << "  " << ContextLine << "\n";
+      // Print the line difference.
+      dbgs() << "- " << OrigLine << "\n";
+      dbgs() << "+ " << CurrLine << "\n";
+      return;
+    }
+    // Lazy way to maintain context. Performance of this code does not matter.
+    Context.push_back(OrigLine);
+    if (Context.size() > MaxContext)
+      Context.erase(Context.begin());
+  }
+  // If one dump is longer than the other, print its first extra line.
+  if (!OrigSS.good() && CurrSS.good()) {
+    std::getline(CurrSS, CurrLine);
+    dbgs() << "+ " << CurrLine << "\n";
+  }
+  if (OrigSS.good() && !CurrSS.good()) {
+    std::getline(OrigSS, OrigLine);
+    dbgs() << "- " << OrigLine << "\n";
+  }
+}
+
+void IRChecker::expectNoDiff() const {
+  const std::string &OrigIR = origIR();
+  std::string CurrIR = currIR();
+  bool Same = OrigIR == CurrIR;
+  if (!Same) {
+    showDiff(OrigIR, CurrIR);
+    llvm_unreachable(
+        "Original and current IR differ! Possibly a Checkpointing bug.");
+  }
+}
+
+BBChecker::BBChecker(BasicBlock *BB, bool SkipPreds)
+    : BB(BB), SkipPreds(SkipPreds) {}
+
+std::string BBChecker::dumpBB() const {
+  std::string TmpStr;
+  raw_string_ostream SS(TmpStr);
+  BB->print(SS, /*AssemblyAnnotationWriter=*/nullptr);
+  return SkipPreds
+             ? std::regex_replace(TmpStr, PredsRegex, "; preds = \n")
+             : TmpStr;
+}
+
+void BBChecker::save() {
+  assert(BB != nullptr && "BB not set!");
+  if (SkipPreds)
+    PredsRegex = std::regex(PredsRegexStr);
+  OriginalIR = dumpBB();
+}
+
+std::string BBChecker::currIR() const { return dumpBB(); }
+#endif // NDEBUG
+
 CheckpointGuard::CheckpointGuard(bool NewState, CheckpointTracker *Chkpnt)
     : Chkpnt(Chkpnt), LastState(Chkpnt->isActive()) {
   Chkpnt->Active = NewState;
@@ -84,9 +154,19 @@
   Changes.push_back(make_unique(V, this));
 }
 
-void CheckpointTracker::trackComponent(ChkpntComponent Component) {
+void CheckpointTracker::trackComponent(ChkpntComponent Component,
+                                       bool RunVerifier) {
   assert(!trackingComponent(Component) && "Already tracking component");
   ComponentsTracked.insert(Component);
+#ifndef NDEBUG
+  if (RunVerifier) {
+    if (BasicBlock **BBPtr = std::get_if<BasicBlock *>(&Component))
+      IRCheckers[Component] = make_unique<BBChecker>(*BBPtr);
+    else
+      llvm_unreachable("Unimplemented IRChecker for `Component`");
+    IRCheckers[Component]->save();
+  }
+#endif // NDEBUG
   Active = true;
 }
 
@@ -106,6 +186,9 @@
   Changes.clear();
   ComponentsTracked.clear();
   Active = false;
+#ifndef NDEBUG
+  IRCheckers.clear();
+#endif // NDEBUG
 }
 
 void CheckpointTracker::restoreComponents() {
@@ -122,6 +205,13 @@
   Changes.clear();
   ComponentsTracked.clear();
   Active = false;
+#ifndef NDEBUG
+  // By the end of restoreComponents() every component tracked with
+  // verification enabled must match the dump taken when tracking started.
+  for (auto &Entry : IRCheckers)
+    Entry.second->expectNoDiff();
+  IRCheckers.clear();
+#endif // NDEBUG
 }
 
 CheckpointTracker::CheckpointTracker() {}
diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp
--- a/llvm/lib/IR/LLVMContext.cpp
+++ b/llvm/lib/IR/LLVMContext.cpp
@@ -377,4 +377,6 @@
   return !pImpl->getOpaquePointers();
 }
 
-Checkpoint LLVMContext::getCheckpoint() { return Checkpoint(ChkpntTracker); }
+Checkpoint LLVMContext::getCheckpoint(bool RunVerifier) {
+  return Checkpoint(ChkpntTracker, RunVerifier);
+}
diff --git a/llvm/unittests/IR/CheckpointTest.cpp b/llvm/unittests/IR/CheckpointTest.cpp
--- a/llvm/unittests/IR/CheckpointTest.cpp
+++ b/llvm/unittests/IR/CheckpointTest.cpp
@@ -51,7 +51,7 @@
 #ifndef NDEBUG
   EXPECT_DEATH(
       {
-        Checkpoint Chkpnt = C.getCheckpoint();
+        Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
         Chkpnt.track(BB0);
         Instr->setName("new");
       },
@@ -71,7 +71,7 @@
   Function *F = &*M->begin();
   BasicBlock *BB0 = &*F->begin();
   auto *Instr = &*std::next(BB0->begin(), 0);
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0);
   Chkpnt.setMaxNumOfTrackedChanges(1);
   Instr->setName("change1");
@@ -95,7 +95,7 @@
   BasicBlock *BB0 = getBBWithName(F, "bb0");
   auto *Instr = &*std::next(BB0->begin(), 0);
 
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0);
   Instr->setName("new");
   EXPECT_NE(Instr->getName(), "instr");
@@ -123,7 +123,7 @@
   auto *Instr1 = &*std::next(BB0->begin(), 0);
   auto *Instr2 = &*std::next(BB1->begin(), 0);
 
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0, BB1);
   Instr1->setName("new1");
   Instr2->setName("new2");
@@ -153,7 +153,7 @@
   auto *Instr1 = &*std::next(BB0->begin(), 0);
   auto *Instr2 = &*std::next(BB1->begin(), 0);
 
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0, BB1);
   Instr1->setName("new1");
   Instr2->setName("new2");
@@ -186,7 +186,7 @@
   BasicBlock *BB0 = getBBWithName(F, "bb0");
   auto *Instr = &*std::next(BB0->begin(), 0);
 
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0);
   Instr->setName("new");
   EXPECT_FALSE(Chkpnt.empty());
@@ -210,7 +210,7 @@
   auto *Instr1 = &*It++;
   auto *Instr2 = &*It++;
 
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0);
   Instr1->takeName(Instr2);
   EXPECT_EQ(Instr1->getName(), "instr2");
@@ -233,7 +233,7 @@
 )");
   Function *F = &*M->begin();
   BasicBlock *BB0 = &*F->begin();
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0);
   BB0->setName("NEWNAME");
   EXPECT_NE(BB0->getName(), "bb0");
@@ -254,7 +254,7 @@
   Function *F = &*M->begin();
   BasicBlock *BB0 = getBBWithName(F, "bb0");
 
-  Checkpoint Chkpnt = C.getCheckpoint();
+  Checkpoint Chkpnt = C.getCheckpoint(/*RunVerifier=*/true);
   Chkpnt.track(BB0);
   F->setName("bar");
   EXPECT_NE(F->getName(), "foo");