diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -1213,11 +1213,14 @@ when the original input relies on components (like dialects or passes) that may not always be available. +Note: Local reproducer generation requires that multi-threading is +disabled (`-mlir-disable-threading`). + For example, if the failure in the previous example came from `canonicalize`, the following reproducer will be generated: ```mlir -// configuration: -pass-pipeline='func(canonicalize)' -verify-each +// configuration: -pass-pipeline='func(canonicalize)' -verify-each -mlir-disable-threading module { func @foo() { diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -37,6 +37,7 @@ namespace detail { struct OpPassManagerImpl; class OpToOpPassAdaptor; +class PassCrashReproducerGenerator; struct PassExecutionState; } // end namespace detail @@ -373,12 +374,11 @@ /// Dump the statistics of the passes within this pass manager. void dumpStatistics(); - /// Run the pass manager with crash recover enabled. + /// Run the pass manager with crash recovery enabled. LogicalResult runWithCrashRecovery(Operation *op, AnalysisManager am); - /// Run the given passes with crash recover enabled. - LogicalResult - runWithCrashRecovery(MutableArrayRef> passes, - Operation *op, AnalysisManager am); + + /// Run the passes of the pass manager, and return the result. + LogicalResult runPasses(Operation *op, AnalysisManager am); /// Context this PassManager was initialized with. MLIRContext *context; @@ -389,8 +389,9 @@ /// A manager for pass instrumentations. std::unique_ptr instrumentor; - /// An optional factory to use when generating a crash reproducer if valid.
- ReproducerStreamFactory crashReproducerStreamFactory; + /// An optional crash reproducer generator, if this pass manager is setup to + /// generate reproducers. + std::unique_ptr crashReproGenerator; /// A hash key used to detect when reinitialization is necessary. llvm::hash_code initializationKey; @@ -398,9 +399,6 @@ /// Flag that specifies if pass timing is enabled. bool passTiming : 1; - /// Flag that specifies if the generated crash reproducer should be local. - bool localReproducer : 1; - /// A flag that indicates if the IR should be verified in between passes. bool verifyPasses : 1; }; diff --git a/mlir/lib/Pass/CMakeLists.txt b/mlir/lib/Pass/CMakeLists.txt --- a/mlir/lib/Pass/CMakeLists.txt +++ b/mlir/lib/Pass/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library(MLIRPass IRPrinting.cpp Pass.cpp + PassCrashRecovery.cpp PassManagerOptions.cpp PassRegistry.cpp PassStatistics.cpp diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -102,10 +102,6 @@ /// recursively through the pipeline graph. void coalesceAdjacentAdaptorPasses(); - /// Split all of AdaptorPasses such that each adaptor only contains one leaf - /// pass. - void splitAdaptorPasses(); - /// Return the operation name of this pass manager as an identifier. Identifier getOpName(MLIRContext &context) { if (!identifier) @@ -213,27 +209,6 @@ llvm::erase_if(passes, std::logical_not>()); } -void OpPassManagerImpl::splitAdaptorPasses() { - std::vector> oldPasses; - std::swap(passes, oldPasses); - - for (std::unique_ptr &pass : oldPasses) { - // If this pass isn't an adaptor, move it directly to the new pass list. - auto *currentAdaptor = dyn_cast(pass.get()); - if (!currentAdaptor) { - addPass(std::move(pass)); - continue; - } - - // Otherwise, split the adaptors of each manager within the adaptor. 
- for (OpPassManager &adaptorPM : currentAdaptor->getPassManagers()) { - adaptorPM.getImpl().splitAdaptorPasses(); - for (std::unique_ptr &nestedPass : adaptorPM.getImpl().passes) - nest(adaptorPM.getOpName()).addPass(std::move(nestedPass)); - } - } -} - //===----------------------------------------------------------------------===// // OpPassManager //===----------------------------------------------------------------------===// @@ -645,210 +620,6 @@ signalPassFailure(); } -//===----------------------------------------------------------------------===// -// PassCrashReproducer -//===----------------------------------------------------------------------===// - -namespace { -/// This class contains all of the context for generating a recovery reproducer. -/// Each recovery context is registered globally to allow for generating -/// reproducers when a signal is raised, such as a segfault. -struct RecoveryReproducerContext { - RecoveryReproducerContext(MutableArrayRef> passes, - Operation *op, - PassManager::ReproducerStreamFactory &crashStream, - bool disableThreads, bool verifyPasses); - ~RecoveryReproducerContext(); - - /// Generate a reproducer with the current context. - LogicalResult generate(std::string &error); - -private: - /// This function is invoked in the event of a crash. - static void crashHandler(void *); - - /// Register a signal handler to run in the event of a crash. - static void registerSignalHandler(); - - /// The textual description of the currently executing pipeline. - std::string pipeline; - - /// The MLIR operation representing the IR before the crash. - Operation *preCrashOperation; - - /// The factory for the reproducer output stream to use when generating the - /// reproducer. - PassManager::ReproducerStreamFactory &crashStreamFactory; - - /// Various pass manager and context flags. - bool disableThreads; - bool verifyPasses; - - /// The current set of active reproducer contexts. This is used in the event - /// of a crash. 
This is not thread_local as the pass manager may produce any - /// number of child threads. This uses a set to allow for multiple MLIR pass - /// managers to be running at the same time. - static llvm::ManagedStatic> reproducerMutex; - static llvm::ManagedStatic< - llvm::SmallSetVector> - reproducerSet; -}; - -/// Instance of ReproducerStream backed by file. -struct FileReproducerStream : public PassManager::ReproducerStream { - FileReproducerStream(std::unique_ptr outputFile) - : outputFile(std::move(outputFile)) {} - ~FileReproducerStream() override; - - /// Description of the reproducer stream. - StringRef description() override; - - /// Stream on which to output reprooducer. - raw_ostream &os() override; - -private: - /// ToolOutputFile corresponding to opened `filename`. - std::unique_ptr outputFile = nullptr; -}; - -} // end anonymous namespace - -llvm::ManagedStatic> - RecoveryReproducerContext::reproducerMutex; -llvm::ManagedStatic> - RecoveryReproducerContext::reproducerSet; - -RecoveryReproducerContext::RecoveryReproducerContext( - MutableArrayRef> passes, Operation *op, - PassManager::ReproducerStreamFactory &crashStreamFactory, - bool disableThreads, bool verifyPasses) - : preCrashOperation(op->clone()), crashStreamFactory(crashStreamFactory), - disableThreads(disableThreads), verifyPasses(verifyPasses) { - // Grab the textual pipeline being executed.. - { - llvm::raw_string_ostream pipelineOS(pipeline); - ::printAsTextualPipeline(passes, pipelineOS); - } - - // Make sure that the handler is registered, and update the current context. - llvm::sys::SmartScopedLock producerLock(*reproducerMutex); - if (reproducerSet->empty()) - llvm::CrashRecoveryContext::Enable(); - registerSignalHandler(); - reproducerSet->insert(this); -} - -RecoveryReproducerContext::~RecoveryReproducerContext() { - // Erase the cloned preCrash IR that we cached. 
- preCrashOperation->erase(); - - llvm::sys::SmartScopedLock producerLock(*reproducerMutex); - reproducerSet->remove(this); - if (reproducerSet->empty()) - llvm::CrashRecoveryContext::Disable(); -} - -/// Description of the reproducer stream. -StringRef FileReproducerStream::description() { - return outputFile->getFilename(); -} - -/// Stream on which to output reproducer. -raw_ostream &FileReproducerStream::os() { return outputFile->os(); } - -FileReproducerStream::~FileReproducerStream() { outputFile->keep(); } - -LogicalResult RecoveryReproducerContext::generate(std::string &error) { - std::unique_ptr crashStream = - crashStreamFactory(error); - if (!crashStream) - return failure(); - - // Output the current pass manager configuration. - auto &os = crashStream->os(); - os << "// configuration: -pass-pipeline='" << pipeline << "'"; - if (disableThreads) - os << " -mlir-disable-threading"; - if (verifyPasses) - os << " -verify-each"; - os << '\n'; - - // Output the .mlir module. - preCrashOperation->print(os); - - bool shouldPrintOnOp = - preCrashOperation->getContext()->shouldPrintOpOnDiagnostic(); - preCrashOperation->getContext()->printOpOnDiagnostic(false); - preCrashOperation->emitError() - << "A failure has been detected while processing the MLIR module, a " - "reproducer has been generated in '" - << crashStream->description() << "'"; - preCrashOperation->getContext()->printOpOnDiagnostic(shouldPrintOnOp); - return success(); -} - -void RecoveryReproducerContext::crashHandler(void *) { - // Walk the current stack of contexts and generate a reproducer for each one. - // We can't know for certain which one was the cause, so we need to generate - // a reproducer for all of them. - std::string ignored; - for (RecoveryReproducerContext *context : *reproducerSet) - (void)context->generate(ignored); -} - -void RecoveryReproducerContext::registerSignalHandler() { - // Ensure that the handler is only registered once. 
- static bool registered = - (llvm::sys::AddSignalHandler(crashHandler, nullptr), false); - (void)registered; -} - -/// Run the pass manager with crash recover enabled. -LogicalResult PassManager::runWithCrashRecovery(Operation *op, - AnalysisManager am) { - // If this isn't a local producer, run all of the passes in recovery mode. - if (!localReproducer) - return runWithCrashRecovery(impl->passes, op, am); - - // Split the passes within adaptors to ensure that each pass can be run in - // isolation. - impl->splitAdaptorPasses(); - - // If this is a local producer, run each of the passes individually. - MutableArrayRef> passes = impl->passes; - for (std::unique_ptr &pass : passes) - if (failed(runWithCrashRecovery(pass, op, am))) - return failure(); - return success(); -} - -/// Run the given passes with crash recover enabled. -LogicalResult -PassManager::runWithCrashRecovery(MutableArrayRef> passes, - Operation *op, AnalysisManager am) { - RecoveryReproducerContext context(passes, op, crashReproducerStreamFactory, - !getContext()->isMultithreadingEnabled(), - verifyPasses); - - // Safely invoke the passes within a recovery context. 
- LogicalResult passManagerResult = failure(); - llvm::CrashRecoveryContext recoveryContext; - recoveryContext.RunSafelyOnThread([&] { - for (std::unique_ptr &pass : passes) - if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses, - impl->initializationGeneration))) - return; - passManagerResult = success(); - }); - if (succeeded(passManagerResult)) - return success(); - - std::string error; - if (failed(context.generate(error))) - return op->emitError(": ") << error; - return failure(); -} - //===----------------------------------------------------------------------===// // PassManager //===----------------------------------------------------------------------===// @@ -857,7 +628,7 @@ StringRef operationName) : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx), initializationKey(DenseMapInfo::getTombstoneKey()), - passTiming(false), localReproducer(false), verifyPasses(true) {} + passTiming(false), verifyPasses(true) {} PassManager::~PassManager() {} @@ -898,10 +669,7 @@ // If reproducer generation is enabled, run the pass manager with crash // handling enabled. LogicalResult result = - crashReproducerStreamFactory - ? runWithCrashRecovery(op, am) - : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses, - impl->initializationGeneration); + crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am); // Notify the context that the run is done. context->exitMultiThreadedExecution(); @@ -912,40 +680,6 @@ return result; } -/// Enable support for the pass manager to generate a reproducer on the event -/// of a crash or a pass failure. `outputFile` is a .mlir filename used to write -/// the generated reproducer. If `genLocalReproducer` is true, the pass manager -/// will attempt to generate a local reproducer that contains the smallest -/// pipeline. 
-void PassManager::enableCrashReproducerGeneration(StringRef outputFile, - bool genLocalReproducer) { - // Capture the filename by value in case outputFile is out of scope when - // invoked. - std::string filename = outputFile.str(); - enableCrashReproducerGeneration( - [filename](std::string &error) -> std::unique_ptr { - std::unique_ptr outputFile = - mlir::openOutputFile(filename, &error); - if (!outputFile) { - error = "Failed to create reproducer stream: " + error; - return nullptr; - } - return std::make_unique(std::move(outputFile)); - }, - genLocalReproducer); -} - -/// Enable support for the pass manager to generate a reproducer on the event -/// of a crash or a pass failure. `factory` is used to construct the streams -/// to write the generated reproducer to. If `genLocalReproducer` is true, the -/// pass manager will attempt to generate a local reproducer that contains the -/// smallest pipeline. -void PassManager::enableCrashReproducerGeneration( - ReproducerStreamFactory factory, bool genLocalReproducer) { - crashReproducerStreamFactory = factory; - localReproducer = genLocalReproducer; -} - /// Add the provided instrumentation to the pass manager. 
void PassManager::addInstrumentation(std::unique_ptr pi) { if (!instrumentor) @@ -954,6 +688,11 @@ instrumentor->addInstrumentation(std::move(pi)); } +LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) { + return OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses, + impl->initializationGeneration); +} + //===----------------------------------------------------------------------===// // AnalysisManager //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/PassCrashRecovery.cpp b/mlir/lib/Pass/PassCrashRecovery.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Pass/PassCrashRecovery.cpp @@ -0,0 +1,441 @@ +//===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/FileUtilities.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/CrashRecoveryContext.h" +#include "llvm/Support/Mutex.h" +#include "llvm/Support/Parallel.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/Threading.h" +#include "llvm/Support/ToolOutputFile.h" + +using namespace mlir; +using namespace mlir::detail; + +//===----------------------------------------------------------------------===// +// RecoveryReproducerContext +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace detail { +/// This class contains all of the context for generating a 
recovery reproducer. +/// Each recovery context is registered globally to allow for generating +/// reproducers when a signal is raised, such as a segfault. +struct RecoveryReproducerContext { + RecoveryReproducerContext(std::string passPipelineStr, Operation *op, + PassManager::ReproducerStreamFactory &streamFactory, + bool verifyPasses); + ~RecoveryReproducerContext(); + + /// Generate a reproducer with the current context. A short, human-readable + /// summary of the outcome (where the reproducer was written, or why the + /// output stream could not be created) is appended to `description`. + void generate(std::string &description); + + /// Disable this reproducer context. This prevents the context from generating + /// a reproducer in the event of a crash. + void disable(); + + /// Enable a previously disabled reproducer context. + void enable(); + +private: + /// This function is invoked in the event of a crash. + static void crashHandler(void *); + + /// Register a signal handler to run in the event of a crash. + static void registerSignalHandler(); + + /// The textual description of the currently executing pipeline. + std::string pipeline; + + /// A clone of the operation taken when this context was created, i.e. the IR + /// before the crash. This context owns the clone. + Operation *preCrashOperation; + + /// The factory for the reproducer output stream to use when generating the + /// reproducer. NOTE(review): stored by reference, so the factory must + /// outlive this context. + PassManager::ReproducerStreamFactory &streamFactory; + + /// Flags echoed into the reproducer's `// configuration:` line. + bool disableThreads; + bool verifyPasses; + + /// The current set of active reproducer contexts. This is used in the event + /// of a crash. This is not thread_local as the pass manager may produce any + /// number of child threads. This uses a set to allow for multiple MLIR pass + /// managers to be running at the same time.
+ static llvm::ManagedStatic> reproducerMutex; + static llvm::ManagedStatic< + llvm::SmallSetVector> + reproducerSet; +}; +} // namespace detail +} // namespace mlir + +llvm::ManagedStatic> + RecoveryReproducerContext::reproducerMutex; +llvm::ManagedStatic> + RecoveryReproducerContext::reproducerSet; + +RecoveryReproducerContext::RecoveryReproducerContext( + std::string passPipelineStr, Operation *op, + PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses) + : pipeline(std::move(passPipelineStr)), preCrashOperation(op->clone()), + streamFactory(streamFactory), + disableThreads(!op->getContext()->isMultithreadingEnabled()), + verifyPasses(verifyPasses) { + enable(); +} + +RecoveryReproducerContext::~RecoveryReproducerContext() { + // Erase the cloned preCrash IR that we cached. + preCrashOperation->erase(); + disable(); +} + +void RecoveryReproducerContext::generate(std::string &description) { + llvm::raw_string_ostream descOS(description); + + // Try to create a new output stream for this crash reproducer. + std::string error; + std::unique_ptr stream = streamFactory(error); + if (!stream) { + descOS << "failed to create output stream: " << error; + return; + } + descOS << "reproducer generated at `" << stream->description() << "`"; + + // Output the current pass manager configuration to the crash stream. + auto &os = stream->os(); + os << "// configuration: -pass-pipeline='" << pipeline << "'"; + if (disableThreads) + os << " -mlir-disable-threading"; + if (verifyPasses) + os << " -verify-each"; + os << '\n'; + + // Output the .mlir module. 
+ preCrashOperation->print(os); +} + +void RecoveryReproducerContext::disable() { + llvm::sys::SmartScopedLock lock(*reproducerMutex); + reproducerSet->remove(this); + if (reproducerSet->empty()) + llvm::CrashRecoveryContext::Disable(); +} + +void RecoveryReproducerContext::enable() { + llvm::sys::SmartScopedLock lock(*reproducerMutex); + if (reproducerSet->empty()) + llvm::CrashRecoveryContext::Enable(); + registerSignalHandler(); + reproducerSet->insert(this); +} + +void RecoveryReproducerContext::crashHandler(void *) { + // Walk the current stack of contexts and generate a reproducer for each one. + // We can't know for certain which one was the cause, so we need to generate + // a reproducer for all of them. + for (RecoveryReproducerContext *context : *reproducerSet) { + std::string description; + context->generate(description); + + // Emit an error using information only available within the context. + context->preCrashOperation->getContext()->printOpOnDiagnostic(false); + context->preCrashOperation->emitError() + << "A failure has been detected while processing the MLIR module:" + << description; + } +} + +void RecoveryReproducerContext::registerSignalHandler() { + // Ensure that the handler is only registered once. + static bool registered = + (llvm::sys::AddSignalHandler(crashHandler, nullptr), false); + (void)registered; +} + +//===----------------------------------------------------------------------===// +// PassCrashReproducerGenerator +//===----------------------------------------------------------------------===// + +struct PassCrashReproducerGenerator::Impl { + Impl(PassManager::ReproducerStreamFactory &streamFactory, + bool localReproducer) + : streamFactory(streamFactory), localReproducer(localReproducer) {} + + /// The factory to use when generating a crash reproducer. + PassManager::ReproducerStreamFactory streamFactory; + + /// Flag indicating if reproducer generation should be localized to the + /// failing pass. 
+ bool localReproducer; + + /// A record of all of the currently active reproducer contexts. + SmallVector> activeContexts; + + /// The set of all currently running passes, paired with the operation each + /// one runs on. This is populated in both modes: when `localReproducer` is + /// true it is kept in lock-step with `activeContexts` (one context per pass). + SetVector> runningPasses; + + /// Various pass manager flags that get emitted when generating a reproducer. + bool pmFlagVerifyPasses; +}; + +PassCrashReproducerGenerator::PassCrashReproducerGenerator( + PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer) + : impl(std::make_unique(streamFactory, localReproducer)) {} +PassCrashReproducerGenerator::~PassCrashReproducerGenerator() {} + +void PassCrashReproducerGenerator::initialize( + iterator_range passes, Operation *op, + bool pmFlagVerifyPasses) { + assert((!impl->localReproducer || + !op->getContext()->isMultithreadingEnabled()) && + "expected multi-threading to be disabled when generating a local " + "reproducer"); + + llvm::CrashRecoveryContext::Enable(); + impl->pmFlagVerifyPasses = pmFlagVerifyPasses; + + // Drop any state left behind by a previous run of the pass manager. A failed + // pass never reaches `removeLastReproducerFor`, so its entries would + // otherwise leak into (and break the bookkeeping of) this run. + impl->activeContexts.clear(); + impl->runningPasses.clear(); + + // If we aren't generating a local reproducer, prepare a reproducer for the + // given top-level operation. + if (!impl->localReproducer) + prepareReproducerFor(passes, op); +} + +static void +formatPassOpReproducerMessage(Diagnostic &os, + std::pair passOpPair) { + os << "`" << passOpPair.first->getName() << "` on " + << "'" << passOpPair.second->getName() << "' operation"; + if (SymbolOpInterface symbol = dyn_cast(passOpPair.second)) + os << ": @" << symbol.getName(); +} + +void PassCrashReproducerGenerator::finalize(Operation *rootOp, + LogicalResult executionResult) { + // If the pass manager execution succeeded, we don't generate any reproducers.
+ if (succeeded(executionResult)) + return impl->activeContexts.clear(); + + MLIRContext *context = rootOp->getContext(); + bool shouldPrintOnOp = context->shouldPrintOpOnDiagnostic(); + context->printOpOnDiagnostic(false); + InFlightDiagnostic diag = rootOp->emitError() + << "Failures have been detected while " + "processing an MLIR pass pipeline"; + context->printOpOnDiagnostic(shouldPrintOnOp); + + // If we are generating a global reproducer, we include all of the running + // passes in the error message for the only active context. + if (!impl->localReproducer) { + assert(impl->activeContexts.size() == 1 && "expected one active context"); + + // Generate the reproducer. + std::string description; + impl->activeContexts.front()->generate(description); + + // Emit an error to the user. + Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing ["; + llvm::interleaveComma(impl->runningPasses, note, + [&](const std::pair &value) { + formatPassOpReproducerMessage(note, value); + }); + note << "]: " << description; + return; + } + + // If we were generating a local reproducer, we generate a reproducer for the + // most recently executing pass using the matching entry from `runningPasses` + // to generate a localized diagnostic message. + assert(impl->activeContexts.size() == impl->runningPasses.size() && + "expected running passes to match active contexts"); + + // Generate the reproducer. + RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back(); + std::string description; + reproducerContext.generate(description); + + // Emit an error to the user. + Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing "; + formatPassOpReproducerMessage(note, impl->runningPasses.back()); + note << ": " << description; + + impl->activeContexts.clear(); +} + +void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass, + Operation *op) { + // If not tracking local reproducers, we simply remember that this pass is + // running. 
+ impl->runningPasses.insert(std::make_pair(pass, op)); + if (!impl->localReproducer) + return; + + // Disable the current pass recovery context, if there is one. This may happen + // in the case of dynamic pass pipelines. + if (!impl->activeContexts.empty()) + impl->activeContexts.back()->disable(); + + // Collect all of the parent scopes of this operation. Note that this walks + // `op` up to its top-level ancestor: the reproducer captures the IR from the + // root, with one pipeline scope per operation on the path below the root + // (including the original `op` itself). + SmallVector scopes; + while (Operation *parentOp = op->getParentOp()) { + scopes.push_back(op->getName()); + op = parentOp; + } + + // Emit a pass pipeline string for the current pass running on the current + // operation type, e.g. `a(b(pass))` when the pass runs on a `b` operation + // nested in an `a` operation below the root. + std::string passStr; + llvm::raw_string_ostream passOS(passStr); + for (OperationName scope : llvm::reverse(scopes)) + passOS << scope << "("; + pass->printAsTextualPipeline(passOS); + for (unsigned i = 0, e = scopes.size(); i < e; ++i) + passOS << ")"; + + impl->activeContexts.push_back(std::make_unique( + passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); +} +void PassCrashReproducerGenerator::prepareReproducerFor( + iterator_range passes, Operation *op) { + // Create a single context whose pipeline is the comma-separated list of all + // `passes`, anchored at `op`. + std::string passStr; + llvm::raw_string_ostream passOS(passStr); + llvm::interleaveComma( + passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); }); + + impl->activeContexts.push_back(std::make_unique( + passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); +} + +void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass, + Operation *op) { + // Always forget that `pass` is running on `op`, but only pop the active + // context if we are tracking local reproducers (the only mode that pushes a + // context per pass). + impl->runningPasses.remove(std::make_pair(pass, op)); + if (impl->localReproducer) { + impl->activeContexts.pop_back(); + + // Re-enable the previous pass recovery context, if there was one. This may + // happen in the case of dynamic pass pipelines.
+ if (!impl->activeContexts.empty()) + impl->activeContexts.back()->enable(); + } +} + +//===----------------------------------------------------------------------===// +// CrashReproducerInstrumentation +//===----------------------------------------------------------------------===// + +namespace { +/// A pass instrumentation that keeps a PassCrashReproducerGenerator informed +/// of which passes are currently running. There is intentionally no +/// `runAfterPassFailed` override: the state of a failed pass is left in place +/// so that `PassCrashReproducerGenerator::finalize` can report it. +struct CrashReproducerInstrumentation : public PassInstrumentation { + explicit CrashReproducerInstrumentation( + PassCrashReproducerGenerator &generator) + : generator(generator) {} + ~CrashReproducerInstrumentation() override = default; + + /// A callback to run before a pass is executed. + void runBeforePass(Pass *pass, Operation *op) override { + if (!isa(pass)) + generator.prepareReproducerFor(pass, op); + } + + /// A callback to run after a pass is successfully executed. This function + /// takes a pointer to the pass to be executed, as well as the current + /// operation being operated on. + void runAfterPass(Pass *pass, Operation *op) override { + if (!isa(pass)) + generator.removeLastReproducerFor(pass, op); + } + +private: + /// The generator used to create crash reproducers. + PassCrashReproducerGenerator &generator; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// FileReproducerStream +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents a default instance of PassManager::ReproducerStream +/// that is backed by a file. +struct FileReproducerStream : public PassManager::ReproducerStream { + explicit FileReproducerStream(std::unique_ptr outputFile) + : outputFile(std::move(outputFile)) {} + /// Mark the output file as kept so it is not deleted on destruction. + ~FileReproducerStream() override { outputFile->keep(); } + + /// Returns a description of the reproducer stream. + StringRef description() override { return outputFile->getFilename(); } + + /// Returns the stream on which to output the reproducer.
+ raw_ostream &os() override { return outputFile->os(); } + +private: + /// The opened reproducer output file. + std::unique_ptr outputFile = nullptr; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// PassManager +//===----------------------------------------------------------------------===// + +/// Run the pass manager with crash recovery enabled. The passes are executed +/// inside an llvm::CrashRecoveryContext, so a crash is turned into a failure +/// result; the reproducer generator's `finalize` then generates a reproducer +/// if the run failed. +LogicalResult PassManager::runWithCrashRecovery(Operation *op, + AnalysisManager am) { + crashReproGenerator->initialize(getPasses(), op, verifyPasses); + + // Safely invoke the passes within a recovery context. If the passes crash, + // `passManagerResult` keeps its initial failure value. + LogicalResult passManagerResult = failure(); + llvm::CrashRecoveryContext recoveryContext; + recoveryContext.RunSafelyOnThread( + [&] { passManagerResult = runPasses(op, am); }); + crashReproGenerator->finalize(op, passManagerResult); + return passManagerResult; +} + +/// Enable support for the pass manager to generate a reproducer in the event +/// of a crash or a pass failure. `outputFile` is a .mlir filename used to write +/// the generated reproducer. If `genLocalReproducer` is true, the reproducer is +/// localized to the failing pass (this requires multi-threading to be +/// disabled). +void PassManager::enableCrashReproducerGeneration(StringRef outputFile, + bool genLocalReproducer) { + // Capture the filename by value in case outputFile is out of scope when + // invoked.
+ std::string filename = outputFile.str(); + enableCrashReproducerGeneration( + [filename](std::string &error) -> std::unique_ptr { + std::unique_ptr outputFile = + mlir::openOutputFile(filename, &error); + if (!outputFile) { + error = "Failed to create reproducer stream: " + error; + return nullptr; + } + return std::make_unique(std::move(outputFile)); + }, + genLocalReproducer); +} + +/// Enable support for the pass manager to generate a reproducer in the event +/// of a crash or a pass failure. `factory` is used to construct the streams +/// to write the generated reproducers to. If `genLocalReproducer` is true, the +/// reproducer is localized to the failing pass; this requires multi-threading +/// to be disabled on the context. This may only be called once per pass +/// manager. +void PassManager::enableCrashReproducerGeneration( + ReproducerStreamFactory factory, bool genLocalReproducer) { + assert(!crashReproGenerator && + "crash reproducer has already been initialized"); + if (genLocalReproducer && getContext()->isMultithreadingEnabled()) + llvm::report_fatal_error( + "Local crash reproduction can't be setup on a " + "pass-manager without disabling multi-threading first."); + + crashReproGenerator = std::make_unique( + factory, genLocalReproducer); + addInstrumentation( + std::make_unique(*crashReproGenerator)); +} diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -82,6 +82,43 @@ friend class mlir::PassManager; }; +//===----------------------------------------------------------------------===// +// PassCrashReproducerGenerator +//===----------------------------------------------------------------------===// + +class PassCrashReproducerGenerator { +public: + PassCrashReproducerGenerator( + PassManager::ReproducerStreamFactory &streamFactory, + bool localReproducer); + ~PassCrashReproducerGenerator(); + + /// Initialize the generator in preparation for reproducer generation. The + /// generator should be reinitialized before each run of the pass manager. + void initialize(iterator_range passes, + Operation *op, bool pmFlagVerifyPasses); + /// Finalize the current run of the generator, generating any necessary + /// reproducers if the provided execution result is a failure.
+ void finalize(Operation *rootOp, LogicalResult executionResult); + + /// Prepare a new reproducer for the given pass, operating on `op`. When not + /// generating local reproducers, this only records that `pass` is running. + void prepareReproducerFor(Pass *pass, Operation *op); + + /// Prepare a single new reproducer covering all of the given passes, + /// operating on `op`. + void prepareReproducerFor(iterator_range passes, + Operation *op); + + /// Remove the last recorded reproducer anchored at the given pass and + /// operation. + void removeLastReproducerFor(Pass *pass, Operation *op); + +private: + struct Impl; + + /// The internal implementation of the crash reproducer. + std::unique_ptr impl; +}; + } // end namespace detail } // end namespace mlir #endif // MLIR_PASS_PASSDETAIL_H_ diff --git a/mlir/test/Pass/crash-recovery.mlir b/mlir/test/Pass/crash-recovery.mlir --- a/mlir/test/Pass/crash-recovery.mlir +++ b/mlir/test/Pass/crash-recovery.mlir @@ -1,26 +1,33 @@ -// RUN: mlir-opt %s -pass-pipeline='func(test-function-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics +// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics // RUN: cat %t | FileCheck -check-prefix=REPRO %s -// RUN: mlir-opt %s -pass-pipeline='func(test-function-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer +// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer -mlir-disable-threading // RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL %s -// Check that we correctly handle verifiers passes with local reproducer, this use to crash. -// RUN: mlir-opt %s -test-function-pass -test-function-pass -test-module-pass -pass-pipeline-crash-reproducer=%t -pass-pipeline-local-reproducer +// Check that we correctly handle verifier passes with the local reproducer; this used to crash.
+// RUN: mlir-opt %s -test-module-pass -test-module-pass -test-module-pass -pass-pipeline-crash-reproducer=%t -pass-pipeline-local-reproducer -mlir-disable-threading +// RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL %s + +// Check that local reproducers will also traverse dynamic pass pipelines. +// RUN: mlir-opt %s -pass-pipeline='test-module-pass,test-dynamic-pipeline{op-name=inner_mod1 run-on-nested-operations=1 dynamic-pipeline=test-pass-crash}' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer --mlir-disable-threading +// RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL_DYNAMIC %s -// expected-error@+1 {{A failure has been detected while processing the MLIR module}} -module { - func @foo() { - return - } +// expected-error@below {{Failures have been detected while processing an MLIR pass pipeline}} +// expected-note@below {{Pipeline failed while executing}} +module @inner_mod1 { + module @foo {} } -// REPRO: configuration: -pass-pipeline='func(test-function-pass, test-pass-crash)' +// REPRO: configuration: -pass-pipeline='module(test-module-pass, test-pass-crash)' + +// REPRO: module @inner_mod1 +// REPRO: module @foo { + +// REPRO_LOCAL: configuration: -pass-pipeline='module(test-pass-crash)' -// REPRO: module -// REPRO: func @foo() { -// REPRO-NEXT: return +// REPRO_LOCAL: module @inner_mod1 +// REPRO_LOCAL: module @foo { -// REPRO_LOCAL: configuration: -pass-pipeline='func(test-pass-crash)' +// REPRO_LOCAL_DYNAMIC: configuration: -pass-pipeline='module(test-pass-crash)' -// REPRO_LOCAL: module -// REPRO_LOCAL: func @foo() { -// REPRO_LOCAL-NEXT: return +// REPRO_LOCAL_DYNAMIC: module @inner_mod1 +// REPRO_LOCAL_DYNAMIC: module @foo {