diff --git a/llvm/include/llvm/IR/BasicBlock.h b/llvm/include/llvm/IR/BasicBlock.h --- a/llvm/include/llvm/IR/BasicBlock.h +++ b/llvm/include/llvm/IR/BasicBlock.h @@ -19,6 +19,7 @@ #include "llvm/ADT/ilist_node.h" #include "llvm/ADT/iterator.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/IR/Checkpoint.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/SymbolTableListTraits.h" #include "llvm/IR/Value.h" diff --git a/llvm/include/llvm/IR/Checkpoint.h b/llvm/include/llvm/IR/Checkpoint.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/Checkpoint.h @@ -0,0 +1,100 @@ +//===- Checkpoint.h ---------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// About +// ----- +// This file declares the checkpoint handle class for local checkpointing of +// IR components. It provides a simple API: track() / restore() / accept(). +// - track(Component1, Component2, ...) starts tracking changes made to the +// components passed as arguments. For example `track(BB)` will track any +// changes made to the contents of `BB`. +// - restore() reverts the state of all tracked components to the state when +// we started tracking them. This works by reverting one-by-one, in reverse +// order, all individual changes recorded. +// - accept() stops tracking and accepts all changes made to all components. +// +// How to use +// ---------- +// - Get a checkpoint handle using `LLVMContext::getCheckpoint()`. +// For example, given LLVMContext `C`: +// auto Chkpnt = C.getCheckpoint(); +// - Start tracking changes in `BB`'s state using `Chkpnt.track(BB)`. This will +// track changes made to the contents of the `BB` from this point on. +// - Modify the `BB` in any way (e.g.
move instructions, change instruction +// operands, etc.). +// - Restore the original state of the IR using: +// Chkpnt.restore(); +// Or accept the current state using: +// Chkpnt.accept(); +// - Don't let the handle go out of scope without calling accept() or restore(). +// - Supported components have a corresponding `Checkpoint::track()` +// function. +// +// Complexity +// ---------- +// Tracking uses linear space to store the changes. Restoring state takes linear +// time in the number of changes tracked. +// + +#ifndef LLVM_IR_CHECKPOINT_H +#define LLVM_IR_CHECKPOINT_H + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/CheckpointCommon.h" +#include "llvm/Support/Compiler.h" +#include + +namespace llvm { + +class raw_ostream; +class BasicBlock; +class CheckpointTracker; + +class Checkpoint { + CheckpointTracker &ChkpntTracker; + /// No copies allowed because going out of scope checks if we restored or + /// accepted the changes. + Checkpoint(Checkpoint &) = delete; + void operator=(const Checkpoint &) = delete; + +public: + Checkpoint(CheckpointTracker &ChkpntTracker) : ChkpntTracker(ChkpntTracker) {} + ~Checkpoint(); + + /// \p MaxNumOfTrackedChanges is used for debugging to help diagnose cases + /// where the user forgets to accept() or restore(). It triggers an assertion + /// failure if we record more changes than this number. + void setMaxNumOfTrackedChanges(uint32_t MaxNumOfTrackedChanges); + /// Activates checkpointing and starts tracking changes made to \p BB. This + /// includes all instructions in the block and their state, including their + /// operands, users, names, and other instruction-specific state. + void track(BasicBlock *BB); + /// Track multiple components by calling track() on each of them. + template + void track(BasicBlock *BB0, ComponentTs... Other) { + track(Other...); + track(BB0); + } + /// Reverts the state of all tracked components and stops tracking them.
+ void restore(); + /// Accepts all changes and stops tracking for all components. + void accept(); + /// \Returns true if there are no changes tracked. + bool empty() const; +#ifndef NDEBUG + /// \Returns the number of changes tracked. + uint32_t size() const; + /// Debug printers. + void dump(raw_ostream &OS) const; + LLVM_DUMP_METHOD void dump() const; +#endif // NDEBUG +}; +} // namespace llvm +#endif // LLVM_IR_CHECKPOINT_H diff --git a/llvm/include/llvm/IR/CheckpointChanges.h b/llvm/include/llvm/IR/CheckpointChanges.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/CheckpointChanges.h @@ -0,0 +1,95 @@ +//===- CheckpointChanges.h --------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares all the Change classes that are used for tracking changes +// to the components being tracked. +// For example, `SetName` tracks a change in the name of an llvm::Value. +// +// ChangeBase is an abstract class that declares a simple interface with the +// main functionality being: `revert()` and `apply()`. +// All change classes, like for example `SetName`, inherit from ChangeBase. + +#ifndef LLVM_IR_CHECKPOINTCHANGES_H +#define LLVM_IR_CHECKPOINTCHANGES_H + +#include "llvm/ADT/iterator_range.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CheckpointCommon.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/Support/Debug.h" + +namespace llvm { + +class Function; +class BasicBlock; +class Instruction; +class Value; +class User; +class Use; +class PHINode; +class CallBase; +class CheckpointTracker; + +enum class ChangeID : uint8_t { + SetNameID, + TakeNameID, + DestroyNameID, +}; + +/// Abstract class used as a base class for all change classes.
+class ChangeBase { +protected: + /// The value whose state is saved by this change. + Value *V; + /// For isa<>, cast<> etc. + ChangeID ID; +#ifndef NDEBUG + /// The tracker that created this change. Used for debugging. + CheckpointTracker *Parent; +#endif // NDEBUG + +#ifndef NDEBUG + /// \Returns the unique ID of this object. Used for debugging. + LLVM_DUMP_METHOD uint32_t getUid() const; + void dumpCommon(raw_ostream &OS) const; + void addDump(Value *V); +#endif // NDEBUG + +public: + ChangeBase(Value *V, ChangeID ID, CheckpointTracker *CE); + /// Reverts the change. + virtual void revert() = 0; + /// If we decide to accept the current state we call this to finalize. + virtual void apply() = 0; + ChangeID getID() const { return ID; } + virtual ~ChangeBase() {} +#ifndef NDEBUG + virtual void dump(raw_ostream &OS) const = 0; + LLVM_DUMP_METHOD virtual void dump() const = 0; +#endif // NDEBUG +}; + +class SetName : public ChangeBase { + std::string OrigName; // The name of V before it was changed. + +public: + SetName(Value *V, CheckpointTracker *CT); + void revert() override; + void apply() override; + static bool classof(const ChangeBase *Other) { + return Other->getID() == ChangeID::SetNameID; + } + ~SetName() {} +#ifndef NDEBUG + void dump(raw_ostream &OS) const override; + LLVM_DUMP_METHOD void dump() const override; +#endif // NDEBUG +}; +} // namespace llvm + +#endif // LLVM_IR_CHECKPOINTCHANGES_H diff --git a/llvm/include/llvm/IR/CheckpointCommon.h b/llvm/include/llvm/IR/CheckpointCommon.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/CheckpointCommon.h @@ -0,0 +1,34 @@ +//===- CheckpointCommon.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_CHECKPOINTCOMMON_H +#define LLVM_IR_CHECKPOINTCOMMON_H + +#include +#include +#include + +namespace llvm { + +class Value; +class BasicBlock; +class Function; +class Module; + +/// A component can be one of these pointer types. This can be extended with +/// more pointer types. +using ChkpntComponent = std::variant; + +#ifndef NDEBUG +/// \Returns a dump of \p Component for debugging. +std::string dumpComponent(ChkpntComponent Component); +#endif + +} // namespace llvm + +#endif // LLVM_IR_CHECKPOINTCOMMON_H diff --git a/llvm/include/llvm/IR/CheckpointTracker.h b/llvm/include/llvm/IR/CheckpointTracker.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/CheckpointTracker.h @@ -0,0 +1,138 @@ +//===- CheckpointTracker.h --------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Note: This header file should not be included by checkpointing clients. +// Please include IR/Checkpoint.h instead. +// This should only be included by IR classes that need to notify the +// checkpointing tracker about changes in components being tracked. +// +// About +// ----- +// This declares CheckpointTracker, which is the main data structure that holds +// the state changes and is also in charge of applying/reverting them. +// +// Implementation +// -------------- +// The change objects are collected in the `Changes` vector in the order they +// take place. The objects get generated by functions of this class that get called +// by core IR functions.
For example, `llvm::Value::setName()` calls +// `CheckpointTracker::setName()` which creates a `SetName` change object and +// appends it to the vector of all changes. +// +// A change is tracked only if it belongs to one of the components being +// tracked. Tracking a component is done with `trackComponent()`. +// +// A call to `restoreComponents()` visits all changes in reverse order, calling +// `ChangeBase::revert()` for each one of them. +// A call to `acceptComponents()` visits all changes in order, calling +// `ChangeBase::apply()`. + +#ifndef LLVM_IR_CHECKPOINTTRACKER_H +#define LLVM_IR_CHECKPOINTTRACKER_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/CheckpointCommon.h" +#include "llvm/Support/Debug.h" +#include +#include + +namespace llvm { +class BasicBlock; +class ChangeBase; +class Instruction; +class CheckpointTracker; + +/// A simple guard class that sets the checkpointing state on construction and +/// restores the previous state on destruction. +class CheckpointGuard { + CheckpointTracker *Chkpnt; + bool LastState; + friend class CheckpointTracker; + /// Private by design. Use CheckpointTracker::disable() to get a guard. + CheckpointGuard(bool NewState, CheckpointTracker *Chkpnt); + +public: + ~CheckpointGuard(); +}; + +/// This is the main class for the checkpointing internals. This is where +/// the changes get recorded. +class CheckpointTracker { + /// The sequence of changes applied to the IR in the order they take place. + SmallVector, 32> Changes; + /// The set of components currently being tracked. + DenseSet ComponentsTracked; + + /// This is true while checkpointing is active. + bool Active = false; + friend class CheckpointGuard; // Needs access to `Active`. + + /// A limit to the number of changes we will record. Going over the limit + /// triggers an assertion failure.
This is useful for debugging, as it will catch cases where + /// we collect too many changes, which would suggest that the user forgot to + /// accept() or restore(). + /// This value can be overridden by the user using: + /// `Checkpoint::setMaxNumOfTrackedChanges()`. + uint32_t MaxNumChanges = 4096; + +#ifndef NDEBUG + /// Unique ID for each change object, for debugging. + DenseMap ChangeUids; + friend class ChangeBase; // Writes to ChangeUids. +#endif // NDEBUG + + friend class Checkpoint; // Calls trackComponent(). + friend class Value; // Calls setName(). + + /// Clears the recorded changes and deactivates checkpointing. + void clear(); + /// \Returns the parent component of \p V. Can handle detached instructions. + std::optional getParentComponent(Value *V) const; + + /// To be called when \p V is about to get its name updated. + void setName(Value *V); + + // Main API functions. These are called by `Checkpoint`. + + /// Start tracking IR changes for \p Component from this point on. + void trackComponent(ChkpntComponent Component); + /// Override the maximum number of tracked changes. + void setMaxNumOfTrackedChanges(uint32_t MaxNumOfTrackedChanges); + /// Accepts all changes and stops tracking all components. + void acceptComponents(); + /// Reverts all changes in reverse order and stops tracking all components. + void restoreComponents(); + +#ifndef NDEBUG + void dump(raw_ostream &OS) const; + LLVM_DUMP_METHOD void dump() const; +#endif // NDEBUG + +public: + CheckpointTracker(); + ~CheckpointTracker(); + /// Deactivates checkpointing as long as the returned guard is in-scope. + /// This is used when applying or reverting changes, to avoid re-tracking + /// the changes made while doing so. + CheckpointGuard disable(); + /// \Returns true if we are currently tracking any component. This needs to be + /// fast because it gets called repeatedly throughout IR functions. + inline bool isActive() const { return Active; } + /// \Returns true if we are actively tracking \p Component. + bool trackingComponent(const ChkpntComponent &Component) const; + /// \Returns true if there are no changes in the changes list.
+ bool empty() const; + /// \Returns the number of changes in the list. + uint32_t size() const; +}; +} // namespace llvm + +#endif // LLVM_IR_CHECKPOINTTRACKER_H diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h --- a/llvm/include/llvm/IR/LLVMContext.h +++ b/llvm/include/llvm/IR/LLVMContext.h @@ -15,6 +15,8 @@ #define LLVM_IR_LLVMCONTEXT_H #include "llvm-c/Types.h" +#include "llvm/IR/Checkpoint.h" +#include "llvm/IR/CheckpointTracker.h" #include "llvm/IR/DiagnosticHandler.h" #include "llvm/Support/CBindingWrapping.h" #include @@ -36,6 +38,7 @@ class StringRef; class Twine; class LLVMRemarkStreamer; +class CheckpointTracker; namespace remarks { class RemarkStreamer; @@ -320,6 +323,10 @@ /// Whether typed pointers are supported. If false, all pointers are opaque. bool supportsTypedPointers() const; + /// \Returns the checkpoint handle that allows us to save/restore the state of + /// IR components being tracked. + Checkpoint getCheckpoint(); + private: // Module needs access to the add/removeModule methods. friend class Module; @@ -330,6 +337,19 @@ /// removeModule - Unregister a module from this context. void removeModule(Module*); + + /// The checkpointing object that contains the state changes. IR member + /// functions that are modifying the IR update this object. This cannot be + /// used directly by the user: Use a handle provided by `getCheckpoint()` + /// instead. + CheckpointTracker ChkpntTracker; + friend class CheckpointTracker; + + /// \Returns the internal checkpointing object. To be used by the IR member + /// functions, not by the user. + CheckpointTracker &getChkpntTracker() { return ChkpntTracker; } + + friend class Value; // Calls getChkpntTracker() }; // Create wrappers for C Binding types (see CBindingWrapping.h).
diff --git a/llvm/include/llvm/IR/Value.h b/llvm/include/llvm/IR/Value.h --- a/llvm/include/llvm/IR/Value.h +++ b/llvm/include/llvm/IR/Value.h @@ -263,6 +263,7 @@ void setValueName(ValueName *VN); private: + friend class DestroyName; // Needs to call destroyValueName(). void destroyValueName(); enum class ReplaceMetadataUses { No, Yes }; void doRAUW(Value *New, ReplaceMetadataUses); diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt --- a/llvm/lib/IR/CMakeLists.txt +++ b/llvm/lib/IR/CMakeLists.txt @@ -6,6 +6,10 @@ AutoUpgrade.cpp BasicBlock.cpp BuiltinGCs.cpp + Checkpoint.cpp + CheckpointChanges.cpp + CheckpointCommon.cpp + CheckpointTracker.cpp Comdat.cpp ConstantFold.cpp ConstantRange.cpp diff --git a/llvm/lib/IR/Checkpoint.cpp b/llvm/lib/IR/Checkpoint.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/Checkpoint.cpp @@ -0,0 +1,40 @@ +//===- Checkpoint.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/Checkpoint.h" +#include "llvm/IR/CheckpointTracker.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +Checkpoint::~Checkpoint() { + // Tracking must be ended with accept() or restore() before the handle dies, + // otherwise the tracker stays active and keeps recording changes. + assert(!ChkpntTracker.isActive() && ChkpntTracker.empty() && + "Missing call to Checkpoint::accept() or Checkpoint::restore()"); +} + +void Checkpoint::setMaxNumOfTrackedChanges(uint32_t MaxNumOfTrackedChanges) { + ChkpntTracker.setMaxNumOfTrackedChanges(MaxNumOfTrackedChanges); +} + +void Checkpoint::track(BasicBlock *BB) { ChkpntTracker.trackComponent(BB); } + +void Checkpoint::restore() { ChkpntTracker.restoreComponents(); } + +void Checkpoint::accept() { ChkpntTracker.acceptComponents(); } + +bool Checkpoint::empty() const { return ChkpntTracker.empty(); } + +#ifndef NDEBUG +uint32_t Checkpoint::size() const { return ChkpntTracker.size(); } + +void Checkpoint::dump(raw_ostream &OS) const { ChkpntTracker.dump(OS); } + +void Checkpoint::dump() const { dump(dbgs()); } +#endif // NDEBUG diff --git a/llvm/lib/IR/CheckpointChanges.cpp b/llvm/lib/IR/CheckpointChanges.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/CheckpointChanges.cpp @@ -0,0 +1,65 @@ +//===- CheckpointChanges.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/CheckpointChanges.h" +#include "LLVMContextImpl.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Comdat.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/User.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace llvm; +using namespace std; + +ChangeBase::ChangeBase(Value *V, ChangeID ID, CheckpointTracker *CT) + : V(V), ID(ID) +#ifndef NDEBUG + , + Parent(CT) +#endif // NDEBUG +{ +#ifndef NDEBUG + assert(Parent->isActive() && "Need to call Checkpoint::track() first"); + assert(Parent->size() < Parent->MaxNumChanges && + "Tracking too many changes!"); + Parent->ChangeUids[this] = Parent->ChangeUids.size() + 1; // 1-based debug ID. +#endif +} + +#ifndef NDEBUG +uint32_t ChangeBase::getUid() const { return Parent->ChangeUids.lookup(this); } +void ChangeBase::dumpCommon(raw_ostream &OS) const { OS << getUid() << ".
"; } +#endif + +SetName::SetName(Value *Val, CheckpointTracker *CT) + : ChangeBase(Val, ChangeID::SetNameID, CT) { + OrigName = Val->getName().str(); +} + +void SetName::revert() { V->setName(OrigName); } + +void SetName::apply() {} + +#ifndef NDEBUG +void SetName::dump(raw_ostream &OS) const { + dumpCommon(OS); + OS << "SetName: " << V << " OrigName='" << OrigName << "'\n"; +} + +void SetName::dump() const { dump(dbgs()); } +#endif // NDEBUG diff --git a/llvm/lib/IR/CheckpointCommon.cpp b/llvm/lib/IR/CheckpointCommon.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/CheckpointCommon.cpp @@ -0,0 +1,34 @@ +//===- CheckpointCommon.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/CheckpointCommon.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include + +using namespace llvm; + +namespace llvm { + +#ifndef NDEBUG +std::string dumpComponent(ChkpntComponent Component) { + std::stringstream SS; + if (BasicBlock **BBPtr = std::get_if(&Component)) { + SS << "BasicBlock " << *BBPtr; + return SS.str(); + } + if (Function **FPtr = std::get_if(&Component)) { + SS << "Function " << *FPtr; + return SS.str(); + } + llvm_unreachable("Unimplemented Component Str"); +} +#endif +} // namespace llvm diff --git a/llvm/lib/IR/CheckpointTracker.cpp b/llvm/lib/IR/CheckpointTracker.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/CheckpointTracker.cpp @@ -0,0 +1,127 @@ +//===- CheckpointTracker.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/CheckpointTracker.h" +#include "LLVMContextImpl.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CheckpointChanges.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include + +using namespace llvm; +using namespace std; + +CheckpointGuard::CheckpointGuard(bool NewState, CheckpointTracker *Chkpnt) + : Chkpnt(Chkpnt), LastState(Chkpnt->isActive()) { + Chkpnt->Active = NewState; +} + +CheckpointGuard::~CheckpointGuard() { Chkpnt->Active = LastState; } + +void CheckpointTracker::clear() { + Changes.clear(); + ComponentsTracked.clear(); + Active = false; +#ifndef NDEBUG + ChangeUids.clear(); +#endif // NDEBUG +} + +std::optional +CheckpointTracker::getParentComponent(Value *V) const { + if (isa(V)) { + if (auto *Parent = cast(V)->getParent()) + return Parent; + } + if (isa(V)) { + if (auto *Parent = cast(V)->getParent()) + return Parent; + // Otherwise the block is detached, so it has no parent component. + } + return std::nullopt; +} + +void CheckpointTracker::setName(Value *V) { + // Record the change only if V, or the component containing V, is tracked. + // Values without a parent component (e.g. detached instructions) are skipped. + std::optional Parent = getParentComponent(V); + // If changing name of a BB, we consider its parent to be the BB itself.
+ if (isa(V) && trackingComponent(cast(V))) + Parent = cast(V); + if (Parent && trackingComponent(*Parent)) + Changes.push_back(make_unique(V, this)); +} + +void CheckpointTracker::trackComponent(ChkpntComponent Component) { + assert(!trackingComponent(Component) && "Already tracking component"); + ComponentsTracked.insert(Component); + Active = true; +} + +void CheckpointTracker::setMaxNumOfTrackedChanges( + uint32_t MaxNumOfTrackedChanges) { + MaxNumChanges = MaxNumOfTrackedChanges; +} + +void CheckpointTracker::acceptComponents() { + { + // Deactivate tracking temporarily to make sure that any functions used by + // `apply()` are not being tracked. + auto DisableTrackingGuard = disable(); + for (auto &ChangePtr : Changes) + ChangePtr->apply(); + } + // Drop all changes and stop tracking all components. + // This also resets the debug-only change IDs. + clear(); +} + +void CheckpointTracker::restoreComponents() { + assert(Active && "Trying to restore() without having started tracking"); + { + // Deactivate tracking temporarily to make sure that any functions used by + // `revert()` are not being tracked. + auto DisableTrackingGuard = disable(); + // Revert the changes in the reverse order in which they were recorded.
+ for (auto &ChangePtr : reverse(Changes)) + ChangePtr->revert(); + } + // Drop all changes and stop tracking all components. + // This also resets the debug-only change IDs. + clear(); +} + +CheckpointTracker::CheckpointTracker() {} + +CheckpointTracker::~CheckpointTracker() {} + +bool CheckpointTracker::trackingComponent( + const ChkpntComponent &Component) const { + return ComponentsTracked.count(Component); +} +bool CheckpointTracker::empty() const { return Changes.empty(); } + +uint32_t CheckpointTracker::size() const { return Changes.size(); } + +CheckpointGuard CheckpointTracker::disable() { + return CheckpointGuard(false, this); +} + +#ifndef NDEBUG +void CheckpointTracker::dump(raw_ostream &OS) const { + for (const auto &ChangePtr : Changes) + ChangePtr->dump(OS); +} +void CheckpointTracker::dump() const { dump(dbgs()); } +#endif // NDEBUG diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp --- a/llvm/lib/IR/LLVMContext.cpp +++ b/llvm/lib/IR/LLVMContext.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" +#include "llvm/IR/Checkpoint.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMRemarkStreamer.h" @@ -376,3 +377,5 @@ bool LLVMContext::supportsTypedPointers() const { return !pImpl->getOpaquePointers(); } + +Checkpoint LLVMContext::getCheckpoint() { return Checkpoint(ChkpntTracker); } diff --git a/llvm/lib/IR/Value.cpp b/llvm/lib/IR/Value.cpp --- a/llvm/lib/IR/Value.cpp +++ b/llvm/lib/IR/Value.cpp @@ -14,6 +14,7 @@ #include "LLVMContextImpl.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallString.h" +#include "llvm/IR/CheckpointTracker.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -373,6 +374,9 @@ } void Value::setName(const Twine &NewName) { + CheckpointTracker &ChkpntTracker = getContext().getChkpntTracker(); + if (LLVM_UNLIKELY(ChkpntTracker.isActive())) + ChkpntTracker.setName(this); setNameImpl(NewName); if
(Function *F = dyn_cast(this)) F->recalculateIntrinsicID(); diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -13,6 +13,7 @@ AsmWriterTest.cpp AttributesTest.cpp BasicBlockTest.cpp + CheckpointTest.cpp CFGBuilder.cpp ConstantRangeTest.cpp ConstantsTest.cpp diff --git a/llvm/unittests/IR/CheckpointTest.cpp b/llvm/unittests/IR/CheckpointTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/IR/CheckpointTest.cpp @@ -0,0 +1,242 @@ +//===- llvm/unittest/IR/CheckpointTest.cpp - Checkpoint unit tests --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/Checkpoint.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "gtest/gtest.h" + +namespace llvm { +namespace { + +static std::unique_ptr parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + std::unique_ptr Mod = parseAssemblyString(IR, Err, C); + if (!Mod) + Err.print("CheckpointTest", errs()); + return Mod; +} + +// \Returns the basic block of \p F named \p Name. Asserts if there is none. +static BasicBlock *getBBWithName(Function *F, StringRef Name) { + auto It = find_if( + *F, [&Name](const BasicBlock &BB) { return BB.getName() == Name; }); + assert(It != F->end() && "Not found!"); + return &*It; +} + +TEST(CheckpointTest, HandleOutOfScope) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = &*F->begin(); + auto *Instr =
&*std::next(BB0->begin(), 0); + +#ifndef NDEBUG + EXPECT_DEATH( + { + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0); + Instr->setName("new"); + }, + ".*"); +#endif +} + +TEST(CheckpointTest, MaxNumOfTrackedChanges) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = &*F->begin(); + auto *Instr = &*std::next(BB0->begin(), 0); + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0); + Chkpnt.setMaxNumOfTrackedChanges(1); + Instr->setName("change1"); +#ifndef NDEBUG + // This should crash as we exceeded the maximum number of tracked changes. + EXPECT_DEATH({ Instr->setName("change2"); }, ".*"); +#endif + Chkpnt.accept(); +} + +TEST(CheckpointTest, BB_SetNameInstr) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = getBBWithName(F, "bb0"); + auto *Instr = &*std::next(BB0->begin(), 0); + + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0); + Instr->setName("new"); + EXPECT_NE(Instr->getName(), "instr"); + EXPECT_FALSE(Chkpnt.empty()); + Chkpnt.restore(); + + EXPECT_EQ(Instr->getName(), "instr"); +} + +TEST(CheckpointTest, MultipleBBs_SetNameInstr) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr1 = add i32 %a, %b + br label %bb1 +bb1: + %instr2 = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = getBBWithName(F, "bb0"); + BasicBlock *BB1 = getBBWithName(F, "bb1"); + auto *Instr1 = &*std::next(BB0->begin(), 0); + auto *Instr2 = &*std::next(BB1->begin(), 0); + + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0, BB1); + Instr1->setName("new1"); + Instr2->setName("new2"); + EXPECT_NE(Instr1->getName(), "instr1"); + EXPECT_NE(Instr2->getName(), "instr2");
+ EXPECT_FALSE(Chkpnt.empty()); + Chkpnt.restore(); + EXPECT_EQ(Instr1->getName(), "instr1"); + EXPECT_EQ(Instr2->getName(), "instr2"); +} + +TEST(CheckpointTest, MultipleBBs_ExcercizeRestoreAccept_SetNameInstr) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr1 = add i32 %a, %b + br label %bb1 +bb1: + %instr2 = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = getBBWithName(F, "bb0"); + BasicBlock *BB1 = getBBWithName(F, "bb1"); + auto *Instr1 = &*std::next(BB0->begin(), 0); + auto *Instr2 = &*std::next(BB1->begin(), 0); + + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0, BB1); + Instr1->setName("new1"); + Instr2->setName("new2"); + EXPECT_NE(Instr1->getName(), "instr1"); + EXPECT_NE(Instr2->getName(), "instr2"); + EXPECT_FALSE(Chkpnt.empty()); + Chkpnt.restore(); + EXPECT_EQ(Instr1->getName(), "instr1"); + EXPECT_EQ(Instr2->getName(), "instr2"); + + Chkpnt.track(BB0, BB1); + Instr1->setName("new1"); + Instr2->setName("new2"); + Chkpnt.accept(); + EXPECT_EQ(Instr1->getName(), "new1"); + EXPECT_EQ(Instr2->getName(), "new2"); +} + +// Same as BB_SetNameInstr, but we accept the changes.
+TEST(CheckpointTest, BB_SetNameInstr_Accept) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = getBBWithName(F, "bb0"); + auto *Instr = &*std::next(BB0->begin(), 0); + + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0); + Instr->setName("new"); + EXPECT_FALSE(Chkpnt.empty()); + Chkpnt.accept(); + EXPECT_EQ(Instr->getName(), "new"); +} + +TEST(CheckpointTest, BB_SetNameBB) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = &*F->begin(); + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0); + BB0->setName("NEWNAME"); + EXPECT_NE(BB0->getName(), "bb0"); + EXPECT_FALSE(Chkpnt.empty()); + Chkpnt.restore(); + EXPECT_EQ(BB0->getName(), "bb0"); +} + +TEST(CheckpointTest, BB_SetNameFn) { + LLVMContext C; + std::unique_ptr M = parseIR(C, R"( +define void @foo(i32 %a, i32 %b) { +bb0: + %instr = add i32 %a, %b + ret void +} +)"); + Function *F = &*M->begin(); + BasicBlock *BB0 = getBBWithName(F, "bb0"); + + Checkpoint Chkpnt = C.getCheckpoint(); + Chkpnt.track(BB0); + F->setName("bar"); + EXPECT_NE(F->getName(), "foo"); + // We don't save the state of F because we are only tracking BB0. + EXPECT_TRUE(Chkpnt.empty()); + Chkpnt.restore(); + // We did not save the state of F so rollback should not touch F's name. + EXPECT_EQ(F->getName(), "bar"); +} + +} // namespace +} // namespace llvm