diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Module.h" #include "mlir/Support/FileUtilities.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/CrashRecoveryContext.h" #include "llvm/Support/Mutex.h" @@ -545,6 +546,115 @@ // PassCrashReproducer //===----------------------------------------------------------------------===// +namespace { +/// This class contains all of the context for generating a recovery reproducer. +/// Each recovery context is registered globally to allow for generating +/// reproducers when a signal is raised, such as a segfault. +struct RecoveryReproducerContext { + RecoveryReproducerContext(MutableArrayRef> passes, + ModuleOp module, StringRef filename, + bool disableThreads, bool verifyPasses); + ~RecoveryReproducerContext(); + + /// Generate a reproducer file from this context. Returns failure() and sets `error` if the output file cannot be opened. + LogicalResult generate(std::string &error); + +private: + /// Crash callback: writes a reproducer for every active context. The void* cookie is unused (registered as nullptr). + static void crashHandler(void *); + + /// Register a signal handler to run in the event of a crash. Only the first call registers; later calls do nothing. + static void registerSignalHandler(); + + /// The textual description of the currently executing pipeline. + std::string pipeline; + + /// A clone of the input module, taken in the constructor before any pass runs, i.e. the IR before the crash. + OwningModuleRef module; + + /// The filename to use when generating the reproducer. Non-owning (StringRef): the referenced string must outlive this context. + StringRef filename; + + /// Pass manager flags, recorded in the header of the generated reproducer. + bool disableThreads; + bool verifyPasses; + + /// The current set of active reproducer contexts. This is used in the event + /// of a crash. This is not thread_local as the pass manager may produce any + /// number of child threads. This uses a set to allow for multiple MLIR pass + /// managers to be running at the same time. Insertions and removals are guarded by reproducerMutex. 
+ static llvm::ManagedStatic> reproducerMutex; + static llvm::ManagedStatic< + llvm::SmallSetVector> + reproducerSet; +}; +} // end anonymous namespace + +llvm::ManagedStatic> + RecoveryReproducerContext::reproducerMutex; +llvm::ManagedStatic> + RecoveryReproducerContext::reproducerSet; + +RecoveryReproducerContext::RecoveryReproducerContext( + MutableArrayRef> passes, ModuleOp module, + StringRef filename, bool disableThreads, bool verifyPasses) + : module(module.clone()), filename(filename), + disableThreads(disableThreads), verifyPasses(verifyPasses) { + // Grab the textual pipeline being executed now, before any pass runs, just in case the passes become compromised. + { + llvm::raw_string_ostream pipelineOS(pipeline); + ::printAsTextualPipeline(passes, pipelineOS); + } + + // Make sure that the handler is registered, and add this context to the active set while holding the lock. + llvm::sys::SmartScopedLock producerLock(*reproducerMutex); + registerSignalHandler(); + reproducerSet->insert(this); +} + +RecoveryReproducerContext::~RecoveryReproducerContext() { + llvm::sys::SmartScopedLock producerLock(*reproducerMutex); + reproducerSet->remove(this); // Deregister so the crash handler no longer visits this context. +} + +LogicalResult RecoveryReproducerContext::generate(std::string &error) { + std::unique_ptr outputFile = + mlir::openOutputFile(filename, &error); + if (!outputFile) + return failure(); + auto &outputOS = outputFile->os(); + + // Output the current pass manager configuration. + outputOS << "// configuration: -pass-pipeline='" << pipeline << "'"; + if (disableThreads) + outputOS << " -disable-pass-threading"; + + // TODO: Should this also be configured with a pass manager flag? + outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false") + << "\n"; + + // Output the .mlir module. + module->print(outputOS); + outputFile->keep(); + return success(); +} + +void RecoveryReproducerContext::crashHandler(void *) { + // Walk the current set of active contexts and generate a reproducer for each one. 
+ // We can't know for certain which one was the cause, so we need to generate + // a reproducer for all of them. Errors from generate() are discarded. NOTE(review): the set is read here without holding reproducerMutex (the ctor/dtor do take it); confirm a concurrent insert/remove cannot race with a crash. + std::string ignored; + for (RecoveryReproducerContext *context : *reproducerSet) + context->generate(ignored); +} + +void RecoveryReproducerContext::registerSignalHandler() { + // Ensure that the handler is only registered once: a function-local static is initialized only on the first call. Its stored value is unused. + static bool registered = + (llvm::sys::AddSignalHandler(crashHandler, nullptr), false); + (void)registered; +} + /// Run the pass manager with crash recover enabled. LogicalResult PassManager::runWithCrashRecovery(ModuleOp module, AnalysisManager am) { @@ -572,21 +682,11 @@ LogicalResult PassManager::runWithCrashRecovery(MutableArrayRef> passes, ModuleOp module, AnalysisManager am) { - /// Enable crash recovery. - llvm::CrashRecoveryContext::Enable(); - - // Grab the textual pipeline being executed first, just in case the passes - // become compromised. - std::string pipeline; - { - llvm::raw_string_ostream pipelineOS(pipeline); - ::printAsTextualPipeline(passes, pipelineOS); - } - - // Clone the initial module before running it through the pass pipeline. - OwningModuleRef reproducerModule = module.clone(); + RecoveryReproducerContext context(passes, module, *crashReproducerFileName, + impl->disableThreads, impl->verifyPasses); // The context stays registered for crash-time reproducers until it is destroyed. // Safely invoke the passes within a recovery context. + llvm::CrashRecoveryContext::Enable(); LogicalResult passManagerResult = failure(); llvm::CrashRecoveryContext recoveryContext; recoveryContext.RunSafelyOnThread([&] { @@ -600,26 +700,9 @@ return success(); std::string error; - std::unique_ptr outputFile = - mlir::openOutputFile(*crashReproducerFileName, &error); - if (!outputFile) + if (failed(context.generate(error))) return module.emitError(": ") << error; - auto &outputOS = outputFile->os(); - - // Output the current pass manager configuration. 
- outputOS << "// configuration: -pass-pipeline='" << pipeline << "'"; - if (impl->disableThreads) - outputOS << " -disable-pass-threading"; - - // TODO: Should this also be configured with a pass manager flag? - outputOS << "\n// note: verifyPasses=" - << (impl->verifyPasses ? "true" : "false") << "\n"; - - // Output the .mlir module. - reproducerModule->print(outputOS); - outputFile->keep(); - - return reproducerModule->emitError() + return module.emitError() << "A failure has been detected while processing the MLIR module, a " "reproducer has been generated in '" << *crashReproducerFileName << "'";