diff --git a/llvm/include/llvm/Support/ToolOutputFile.h b/llvm/include/llvm/Support/ToolOutputFile.h --- a/llvm/include/llvm/Support/ToolOutputFile.h +++ b/llvm/include/llvm/Support/ToolOutputFile.h @@ -35,6 +35,7 @@ /// The flag which indicates whether we should not delete the file. bool Keep; + StringRef getFilename() const { return Filename; } explicit CleanupInstaller(StringRef Filename); ~CleanupInstaller(); } Installer; @@ -57,6 +58,9 @@ /// Return the contained raw_fd_ostream. raw_fd_ostream &os() { return *OS; } + /// Return the filename this output file was initialized with. + StringRef getFilename() const { return Installer.getFilename(); } + /// Indicate that the tool's job wrt this output file has been successful and /// the file should not be deleted. void keep() { Installer.Keep = true; } diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -1145,8 +1145,7 @@ reproducible may have the form: ```mlir -// configuration: -pass-pipeline='func(cse,canonicalize),inline' -// note: verifyPasses=false +// configuration: -pass-pipeline='func(cse,canonicalize),inline' -verify-each module { func @foo() { @@ -1159,6 +1158,10 @@ `-run-reproducer` flag. This will result in parsing the first line configuration of the reproducer and adding those to the command line options. +Beyond specifying a filename, one can also register a `ReproducerStreamFactory` +function that would be invoked in the case of a crash and the reproducer written +to its stream. + ### Local Reproducer Generation An additional flag may be passed to diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -192,6 +192,29 @@ void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer = false); + /// Streams on which to output crash reproducer. 
+ struct ReproducerStream { + virtual ~ReproducerStream() = default; + + /// Description of the reproducer stream. + virtual StringRef description() = 0; + + /// Stream on which to output reproducer. + virtual raw_ostream &os() = 0; + }; + + /// Method type for constructing ReproducerStream. + using ReproducerStreamFactory = + std::function(std::string &error)>; + + /// Enable support for the pass manager to generate a reproducer on the event + /// of a crash or a pass failure. `factory` is used to construct the streams + /// to write the generated reproducer to. If `genLocalReproducer` is true, the + /// pass manager will attempt to generate a local reproducer that contains the + /// smallest pipeline. + void enableCrashReproducerGeneration(ReproducerStreamFactory factory, + bool genLocalReproducer = false); + /// Runs the verifier after each individual pass. void enableVerifier(bool enabled = true); @@ -349,8 +372,8 @@ /// A manager for pass instrumentations. std::unique_ptr instrumentor; - /// An optional filename to use when generating a crash reproducer if valid. - Optional crashReproducerFileName; + /// An optional factory to use when generating a crash reproducer if valid. + ReproducerStreamFactory crashReproducerStreamFactory; /// Flag that specifies if pass timing is enabled. bool passTiming : 1; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -645,7 +645,8 @@ /// reproducers when a signal is raised, such as a segfault. struct RecoveryReproducerContext { RecoveryReproducerContext(MutableArrayRef> passes, - Operation *op, StringRef filename, + Operation *op, + PassManager::ReproducerStreamFactory &crashStream, bool disableThreads, bool verifyPasses); ~RecoveryReproducerContext(); @@ -665,8 +666,9 @@ /// The MLIR operation representing the IR before the crash. Operation *preCrashOperation; - /// The filename to use when generating the reproducer. 
- StringRef filename; + /// The factory for the reproducer output stream to use when generating the + /// reproducer. + PassManager::ReproducerStreamFactory &crashStreamFactory; /// Various pass manager and context flags. bool disableThreads; @@ -681,6 +683,24 @@ llvm::SmallSetVector> reproducerSet; }; + +/// Instance of ReproducerStream backed by file. +struct FileReproducerStream : public PassManager::ReproducerStream { + explicit FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile) + : outputFile(std::move(outputFile)) {} + ~FileReproducerStream() override; + + /// Description of the reproducer stream. + StringRef description() override; + + /// Stream on which to output reproducer. + raw_ostream &os() override; + +private: + /// ToolOutputFile corresponding to opened `filename`. + std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr; +}; + } // end anonymous namespace llvm::ManagedStatic> @@ -690,8 +710,9 @@ RecoveryReproducerContext::RecoveryReproducerContext( MutableArrayRef> passes, Operation *op, - StringRef filename, bool disableThreads, bool verifyPasses) - : preCrashOperation(op->clone()), filename(filename), + PassManager::ReproducerStreamFactory &crashStreamFactory, + bool disableThreads, bool verifyPasses) + : preCrashOperation(op->clone()), crashStreamFactory(crashStreamFactory), disableThreads(disableThreads), verifyPasses(verifyPasses) { // Grab the textual pipeline being executed.. { @@ -717,25 +738,42 @@ llvm::CrashRecoveryContext::Disable(); } +/// Description of the reproducer stream. +StringRef FileReproducerStream::description() { + return outputFile->getFilename(); +} + +/// Stream on which to output reproducer. 
+raw_ostream &FileReproducerStream::os() { return outputFile->os(); } + +FileReproducerStream::~FileReproducerStream() { outputFile->keep(); } + LogicalResult RecoveryReproducerContext::generate(std::string &error) { - std::unique_ptr outputFile = - mlir::openOutputFile(filename, &error); - if (!outputFile) + std::unique_ptr crashStream = + crashStreamFactory(error); + if (!crashStream) return failure(); - auto &outputOS = outputFile->os(); // Output the current pass manager configuration. - outputOS << "// configuration: -pass-pipeline='" << pipeline << "'"; + auto &os = crashStream->os(); + os << "// configuration: -pass-pipeline='" << pipeline << "'"; if (disableThreads) - outputOS << " -mlir-disable-threading"; - - // TODO: Should this also be configured with a pass manager flag? - outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false") - << "\n"; + os << " -mlir-disable-threading"; + if (verifyPasses) + os << " -verify-each"; + os << '\n'; // Output the .mlir module. 
- preCrashOperation->print(outputOS); - outputFile->keep(); + preCrashOperation->print(os); + + bool shouldPrintOnOp = + preCrashOperation->getContext()->shouldPrintOpOnDiagnostic(); + preCrashOperation->getContext()->printOpOnDiagnostic(false); + preCrashOperation->emitError() + << "A failure has been detected while processing the MLIR module, a " + "reproducer has been generated in '" + << crashStream->description() << "'"; + preCrashOperation->getContext()->printOpOnDiagnostic(shouldPrintOnOp); return success(); } @@ -778,7 +816,7 @@ LogicalResult PassManager::runWithCrashRecovery(MutableArrayRef> passes, Operation *op, AnalysisManager am) { - RecoveryReproducerContext context(passes, op, *crashReproducerFileName, + RecoveryReproducerContext context(passes, op, crashReproducerStreamFactory, !getContext()->isMultithreadingEnabled(), verifyPasses); @@ -798,13 +836,6 @@ std::string error; if (failed(context.generate(error))) return op->emitError(": ") << error; - bool shouldPrintOnOp = op->getContext()->shouldPrintOpOnDiagnostic(); - op->getContext()->printOpOnDiagnostic(false); - op->emitError() - << "A failure has been detected while processing the MLIR module, a " - "reproducer has been generated in '" - << *crashReproducerFileName << "'"; - op->getContext()->printOpOnDiagnostic(shouldPrintOnOp); return failure(); } @@ -848,7 +879,7 @@ // If reproducer generation is enabled, run the pass manager with crash // handling enabled. LogicalResult result = - crashReproducerFileName + crashReproducerStreamFactory ? runWithCrashRecovery(op, am) : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses, impl->initializationGeneration); @@ -869,7 +900,30 @@ /// pipeline. void PassManager::enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer) { - crashReproducerFileName = std::string(outputFile); + // Capture the filename by value in case outputFile is out of scope when + // invoked. 
+ std::string filename = outputFile.str(); + enableCrashReproducerGeneration( + [filename](std::string &error) -> std::unique_ptr { + std::unique_ptr outputFile = + mlir::openOutputFile(filename, &error); + if (!outputFile) { + error = "Failed to create reproducer stream: " + error; + return nullptr; + } + return std::make_unique(std::move(outputFile)); + }, + genLocalReproducer); +} + +/// Enable support for the pass manager to generate a reproducer on the event +/// of a crash or a pass failure. `factory` is used to construct the streams +/// to write the generated reproducer to. If `genLocalReproducer` is true, the +/// pass manager will attempt to generate a local reproducer that contains the +/// smallest pipeline. +void PassManager::enableCrashReproducerGeneration( + ReproducerStreamFactory factory, bool genLocalReproducer) { + crashReproducerStreamFactory = factory; localReproducer = genLocalReproducer; }