diff --git a/clang/include/clang/Interpreter/Interpreter.h b/clang/include/clang/Interpreter/Interpreter.h
--- a/clang/include/clang/Interpreter/Interpreter.h
+++ b/clang/include/clang/Interpreter/Interpreter.h
@@ -60,14 +60,15 @@
   const llvm::orc::LLJIT *getExecutionEngine() const;
   llvm::Expected<PartialTranslationUnit &> Parse(llvm::StringRef Code);
   llvm::Error Execute(PartialTranslationUnit &T);
-  llvm::Error ParseAndExecute(llvm::StringRef Code) {
-    auto PTU = Parse(Code);
-    if (!PTU)
-      return PTU.takeError();
-    if (PTU->TheModule)
-      return Execute(*PTU);
-    return llvm::Error::success();
-  }
+  /// Parses \p Code and, if that produced an llvm::Module, executes it.
+  llvm::Error ParseAndExecute(llvm::StringRef Code);
+
+  /// \returns the last PartialTranslationUnit; at least one must exist.
+  PartialTranslationUnit &getLastPTU();
+
+  /// Undoes the last \p N inputs: drops their PartialTranslationUnits and
+  /// removes any code that was added to the JIT for them.
+  llvm::Error Undo(unsigned N = 1);
 
   /// \returns the \c JITTargetAddress of a \c GlobalDecl. This interface uses
   /// the CodeGenModule's internal mangling cache to avoid recomputing the
diff --git a/clang/lib/Interpreter/IncrementalExecutor.h b/clang/lib/Interpreter/IncrementalExecutor.h
--- a/clang/lib/Interpreter/IncrementalExecutor.h
+++ b/clang/lib/Interpreter/IncrementalExecutor.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_LIB_INTERPRETER_INCREMENTALEXECUTOR_H
 #define LLVM_CLANG_LIB_INTERPRETER_INCREMENTALEXECUTOR_H
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
@@ -29,11 +30,17 @@
 } // namespace llvm
 
 namespace clang {
+
+struct PartialTranslationUnit;
+
 class IncrementalExecutor {
   using CtorDtorIterator = llvm::orc::CtorDtorIterator;
   std::unique_ptr<llvm::orc::LLJIT> Jit;
   llvm::orc::ThreadSafeContext &TSCtx;
 
+  llvm::DenseMap<const PartialTranslationUnit *, llvm::orc::ResourceTrackerSP>
+      ResourceTrackers;
+
 public:
   enum SymbolNameKind { IRName, LinkerName };
 
@@ -41,7 +48,10 @@
                       const llvm::Triple &Triple);
   ~IncrementalExecutor();
 
-  llvm::Error addModule(std::unique_ptr<llvm::Module> M);
+  /// Adds the llvm::Module owned by \p PTU to the JIT.
+  llvm::Error addModule(PartialTranslationUnit *PTU);
+  /// Removes the code that addModule() added for \p PTU, if any.
+  llvm::Error removeModule(PartialTranslationUnit *PTU);
   llvm::Error runCtors() const;
   llvm::Expected<llvm::JITTargetAddress>
   getSymbolAddress(llvm::StringRef Name, SymbolNameKind NameKind) const;
diff --git a/clang/lib/Interpreter/IncrementalExecutor.cpp b/clang/lib/Interpreter/IncrementalExecutor.cpp
--- a/clang/lib/Interpreter/IncrementalExecutor.cpp
+++ b/clang/lib/Interpreter/IncrementalExecutor.cpp
@@ -12,6 +12,7 @@
 
 #include "IncrementalExecutor.h"
 
+#include "clang/Interpreter/PartialTranslationUnit.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
@@ -52,8 +53,32 @@
 
 IncrementalExecutor::~IncrementalExecutor() {}
 
-llvm::Error IncrementalExecutor::addModule(std::unique_ptr<llvm::Module> M) {
-  return Jit->addIRModule(llvm::orc::ThreadSafeModule(std::move(M), TSCtx));
+/// Adds the llvm::Module owned by \p PTU to the JIT under a fresh resource
+/// tracker so it can later be removed by removeModule(). The module is moved
+/// out of \c PTU->TheModule.
+llvm::Error IncrementalExecutor::addModule(PartialTranslationUnit *PTU) {
+  llvm::orc::ResourceTrackerSP RT =
+      Jit->getMainJITDylib().createResourceTracker();
+  if (llvm::Error Err =
+          Jit->addIRModule(RT, {std::move(PTU->TheModule), TSCtx}))
+    return Err;
+  // Record the tracker only once the module has actually been added.
+  ResourceTrackers[PTU] = RT;
+  return llvm::Error::success();
+}
+
+/// Removes the code that addModule() added for \p PTU. Succeeds without doing
+/// anything if no code was added for \p PTU.
+llvm::Error IncrementalExecutor::removeModule(PartialTranslationUnit *PTU) {
+  // Use find() rather than operator[]: an unknown PTU must not leave a
+  // default-constructed (null) tracker behind in the map.
+  auto It = ResourceTrackers.find(PTU);
+  if (It == ResourceTrackers.end())
+    return llvm::Error::success();
+
+  llvm::orc::ResourceTrackerSP RT = std::move(It->second);
+  ResourceTrackers.erase(It);
+  return RT->remove();
 }
 
 llvm::Error IncrementalExecutor::runCtors() const {
diff --git a/clang/lib/Interpreter/IncrementalParser.h b/clang/lib/Interpreter/IncrementalParser.h
--- a/clang/lib/Interpreter/IncrementalParser.h
+++ b/clang/lib/Interpreter/IncrementalParser.h
@@ -72,6 +72,8 @@
   ///\returns the mangled name of a \c GD.
llvm::StringRef GetMangledName(GlobalDecl GD) const; + std::list &getPTUs() { return PTUs; } + private: llvm::Expected ParseOrWrapTopLevelDecl(); }; diff --git a/clang/lib/Interpreter/Interpreter.cpp b/clang/lib/Interpreter/Interpreter.cpp --- a/clang/lib/Interpreter/Interpreter.cpp +++ b/clang/lib/Interpreter/Interpreter.cpp @@ -218,8 +218,7 @@ if (Err) return Err; } - // FIXME: Add a callback to retain the llvm::Module once the JIT is done. - if (auto Err = IncrExecutor->addModule(std::move(T.TheModule))) + if (auto Err = IncrExecutor->addModule(&T)) return Err; if (auto Err = IncrExecutor->runCtors()) @@ -228,6 +227,16 @@ return llvm::Error::success(); } +llvm::Error Interpreter::ParseAndExecute(llvm::StringRef Code) { + auto PTU = Parse(Code); + if (!PTU) + return PTU.takeError(); + if (PTU->TheModule) { + return Execute(*PTU); + } + return llvm::Error::success(); +} + llvm::Expected Interpreter::getSymbolAddress(GlobalDecl GD) const { if (!IncrExecutor) @@ -257,3 +266,13 @@ return IncrExecutor->getSymbolAddress(Name, IncrementalExecutor::LinkerName); } + +clang::PartialTranslationUnit &Interpreter::getLastPTU() { + return IncrParser->getPTUs().back(); +} + +llvm::Error Interpreter::Undo(unsigned N) { + llvm::Error Err = IncrExecutor->removeModule(&IncrParser->getPTUs().back()); + IncrParser->getPTUs().pop_back(); + return Err; +} diff --git a/clang/tools/clang-repl/ClangRepl.cpp b/clang/tools/clang-repl/ClangRepl.cpp --- a/clang/tools/clang-repl/ClangRepl.cpp +++ b/clang/tools/clang-repl/ClangRepl.cpp @@ -95,6 +95,11 @@ while (llvm::Optional Line = LE.readLine()) { if (*Line == "quit") break; + if (*Line == "undo") { + Interp->Undo(); + continue; + } + if (auto Err = Interp->ParseAndExecute(*Line)) llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: "); } diff --git a/clang/unittests/Interpreter/InterpreterTest.cpp b/clang/unittests/Interpreter/InterpreterTest.cpp --- a/clang/unittests/Interpreter/InterpreterTest.cpp +++ 
b/clang/unittests/Interpreter/InterpreterTest.cpp @@ -248,4 +248,23 @@ EXPECT_EQ(42, fn(NewA)); } +TEST(InterpreterTest, UndoBasic) { + Args ExtraArgs = {"-Xclang", "-diagnostic-log-file", "-Xclang", "-"}; + + // Create the diagnostic engine with unowned consumer. + std::string DiagnosticOutput; + llvm::raw_string_ostream DiagnosticsOS(DiagnosticOutput); + auto DiagPrinter = std::make_unique( + DiagnosticsOS, new DiagnosticOptions()); + + auto Interp = createInterpreter(ExtraArgs, DiagPrinter.get()); + auto R1 = Interp->Parse("int x = 42;"); + EXPECT_TRUE(!!R1); + + llvm::cantFail(Interp->Undo()); + + auto R2 = Interp->Parse("int x = 24;"); + EXPECT_TRUE(!!R2); +} + } // end anonymous namespace