diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -192,6 +192,38 @@
   void enableCrashReproducerGeneration(StringRef outputFile,
                                        bool genLocalReproducer = false);
 
+  /// Streams on which to output crash reproducer.
+  struct ReproducerStream {
+    virtual ~ReproducerStream() = default;
+
+    /// Description of the reproducer stream (used in diagnostics).
+    virtual std::string description() = 0;
+
+    /// Create the underlying stream. Returns failure if it cannot be created.
+    virtual LogicalResult create() = 0;
+
+    /// Stream on which to output configuration.
+    virtual raw_ostream &configurationStream() = 0;
+
+    /// Stream on which to output the op on which pipeline failed.
+    virtual raw_ostream &operationStream() = 0;
+
+    /// Called upon successful write to stream(s).
+    virtual void close() = 0;
+  };
+
+  /// Method type for constructing ReproducerStream.
+  using ReproducerStreamFactory =
+      std::function<std::unique_ptr<ReproducerStream>()>;
+
+  /// Enable support for the pass manager to generate a reproducer on the event
+  /// of a crash or a pass failure. `factory` is used to construct the streams
+  /// to write the generated reproducer to. If `genLocalReproducer` is true, the
+  /// pass manager will attempt to generate a local reproducer that contains the
+  /// smallest pipeline.
+  void enableCrashReproducerGeneration(ReproducerStreamFactory factory,
+                                       bool genLocalReproducer = false);
+
   /// Runs the verifier after each individual pass.
   void enableVerifier(bool enabled = true);
 
@@ -349,8 +381,8 @@
   /// A manager for pass instrumentations.
   std::unique_ptr<PassInstrumentor> instrumentor;
 
-  /// An optional filename to use when generating a crash reproducer if valid.
-  Optional<std::string> crashReproducerFileName;
+  /// An optional factory to use when generating a crash reproducer if valid.
+  ReproducerStreamFactory crashReproducerStreamFactory;
 
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -645,7 +645,8 @@
 /// reproducers when a signal is raised, such as a segfault.
 struct RecoveryReproducerContext {
   RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
-                            Operation *op, StringRef filename,
+                            Operation *op,
+                            PassManager::ReproducerStreamFactory &crashStream,
                             bool disableThreads, bool verifyPasses);
   ~RecoveryReproducerContext();
 
@@ -665,8 +666,9 @@
   /// The MLIR operation representing the IR before the crash.
   Operation *preCrashOperation;
 
-  /// The filename to use when generating the reproducer.
-  StringRef filename;
+  /// The factory for the reproducer output stream to use when generating the
+  /// reproducer.
+  PassManager::ReproducerStreamFactory &crashStreamFactory;
 
   /// Various pass manager and context flags.
   bool disableThreads;
@@ -681,6 +683,31 @@
       llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
       reproducerSet;
 };
+
+/// Instance of ReproducerStream backed by file.
+struct FileReproducerStream : public PassManager::ReproducerStream {
+  /// Description of the reproducer stream.
+  std::string description() override;
+
+  /// Create the underlying stream. Returns failure if it cannot be created.
+  LogicalResult create() override;
+
+  /// Stream on which to output configuration.
+  raw_ostream &configurationStream() override;
+
+  /// Stream on which to output the op on which pipeline failed.
+  raw_ostream &operationStream() override;
+
+  /// Called upon successful write to stream.
+  void close() override;
+
+  /// Name of the file to output to.
+  std::string filename;
+
+  /// ToolOutputFile corresponding to opened `filename`.
+  std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
+};
+
 } // end anonymous namespace
 
 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
@@ -690,8 +717,9 @@
 
 RecoveryReproducerContext::RecoveryReproducerContext(
     MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
-    StringRef filename, bool disableThreads, bool verifyPasses)
-    : preCrashOperation(op->clone()), filename(filename),
+    PassManager::ReproducerStreamFactory &crashStreamFactory,
+    bool disableThreads, bool verifyPasses)
+    : preCrashOperation(op->clone()), crashStreamFactory(crashStreamFactory),
       disableThreads(disableThreads), verifyPasses(verifyPasses) {
   // Grab the textual pipeline being executed..
   {
@@ -717,25 +745,67 @@
   llvm::CrashRecoveryContext::Disable();
 }
 
+/// Create the underlying stream. Returns failure if it cannot be created.
+LogicalResult FileReproducerStream::create() {
+  std::string error;
+  outputFile = mlir::openOutputFile(filename, &error);
+  if (!outputFile) {
+    llvm::errs() << "Failed to create reproducer stream: " << error << '\n';
+    return failure();
+  }
+  return success();
+}
+
+/// Description of the reproducer stream.
+std::string FileReproducerStream::description() { return filename; }
+
+/// Stream on which to output configuration.
+raw_ostream &FileReproducerStream::configurationStream() {
+  return outputFile->os() << "// configuration: ";
+}
+
+/// Stream on which to output op on which pipeline failed.
+raw_ostream &FileReproducerStream::operationStream() {
+  return outputFile->os();
+}
+
+/// Called upon successful write to stream.
+void FileReproducerStream::close() { outputFile->keep(); }
+
 LogicalResult RecoveryReproducerContext::generate(std::string &error) {
-  std::unique_ptr<llvm::ToolOutputFile> outputFile =
-      mlir::openOutputFile(filename, &error);
-  if (!outputFile)
+  // Populate `error` on failure: the caller reports it in a diagnostic.
+  auto crashStream = crashStreamFactory();
+  if (!crashStream) {
+    error = "failed to construct crash reproducer stream";
+    return failure();
+  }
+  if (failed(crashStream->create())) {
+    error = "failed to create crash reproducer stream '" +
+            crashStream->description() + "'";
     return failure();
-  auto &outputOS = outputFile->os();
+  }
 
   // Output the current pass manager configuration.
-  outputOS << "// configuration: -pass-pipeline='" << pipeline << "'";
+  auto &os = crashStream->configurationStream();
+  os << "-pass-pipeline='" << pipeline << "'";
   if (disableThreads)
-    outputOS << " -mlir-disable-threading";
-
-  // TODO: Should this also be configured with a pass manager flag?
-  outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false")
-           << "\n";
+    os << " -mlir-disable-threading";
+  if (verifyPasses)
+    os << " -verify-each";
+  os << '\n';
 
   // Output the .mlir module.
-  preCrashOperation->print(outputOS);
-  outputFile->keep();
+  preCrashOperation->print(crashStream->operationStream());
+  crashStream->close();
+
+  bool shouldPrintOnOp =
+      preCrashOperation->getContext()->shouldPrintOpOnDiagnostic();
+  preCrashOperation->getContext()->printOpOnDiagnostic(false);
+  preCrashOperation->emitError()
+      << "A failure has been detected while processing the MLIR module, a "
+         "reproducer has been generated in '"
+      << crashStream->description() << "'";
+  preCrashOperation->getContext()->printOpOnDiagnostic(shouldPrintOnOp);
   return success();
 }
 
@@ -778,7 +848,7 @@
 LogicalResult
 PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
                                   Operation *op, AnalysisManager am) {
-  RecoveryReproducerContext context(passes, op, *crashReproducerFileName,
+  RecoveryReproducerContext context(passes, op, crashReproducerStreamFactory,
                                     !getContext()->isMultithreadingEnabled(),
                                     verifyPasses);
 
@@ -798,13 +868,6 @@
   std::string error;
   if (failed(context.generate(error)))
     return op->emitError(": ") << error;
-  bool shouldPrintOnOp = op->getContext()->shouldPrintOpOnDiagnostic();
-  op->getContext()->printOpOnDiagnostic(false);
-  op->emitError()
-      << "A failure has been detected while processing the MLIR module, a "
-         "reproducer has been generated in '"
-      << *crashReproducerFileName << "'";
-  op->getContext()->printOpOnDiagnostic(shouldPrintOnOp);
   return failure();
 }
 
@@ -848,7 +911,7 @@
   // If reproducer generation is enabled, run the pass manager with crash
   // handling enabled.
   LogicalResult result =
-      crashReproducerFileName
+      crashReproducerStreamFactory
           ? runWithCrashRecovery(op, am)
           : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
                                            impl->initializationGeneration);
@@ -869,7 +932,26 @@
 /// pipeline.
 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
                                                   bool genLocalReproducer) {
-  crashReproducerFileName = std::string(outputFile);
+  // Capture the filename by value in case outputFile is out of scope when
+  // invoked.
+  std::string filename = outputFile.str();
+  enableCrashReproducerGeneration(
+      [filename]() {
+        auto fileReproducer = std::make_unique<FileReproducerStream>();
+        fileReproducer->filename = filename;
+        return fileReproducer;
+      },
+      genLocalReproducer);
+}
+
+/// Enable support for the pass manager to generate a reproducer on the event
+/// of a crash or a pass failure. `factory` is used to construct the streams
+/// to write the generated reproducer to. If `genLocalReproducer` is true, the
+/// pass manager will attempt to generate a local reproducer that contains the
+/// smallest pipeline.
+void PassManager::enableCrashReproducerGeneration(
+    ReproducerStreamFactory factory, bool genLocalReproducer) {
+  crashReproducerStreamFactory = std::move(factory);
   localReproducer = genLocalReproducer;
 }
 