diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -9,7 +9,7 @@ #ifndef MLIR_PASS_ANALYSISMANAGER_H #define MLIR_PASS_ANALYSISMANAGER_H -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" #include "mlir/Pass/PassInstrumentation.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" @@ -177,8 +177,8 @@ bool wasInserted; std::tie(it, wasInserted) = analyses.try_emplace(id); - // If we don't have a cached analysis for this function, compute it directly - // and add it to the cache. + // If we don't have a cached analysis for this operation, compute it + // directly and add it to the cache. if (wasInserted) { if (pi) pi->runBeforeAnalysis(getAnalysisName(), id, ir); @@ -321,14 +321,14 @@ friend class ModuleAnalysisManager; }; -/// An analysis manager class specifically for the top-level module operation. -/// This class contains the memory allocations for all nested analysis managers, -/// and provides an anchor point. This is necessary because AnalysisManager is +/// An analysis manager class specifically for the top-level operation. This +/// class contains the memory allocations for all nested analysis managers, and +/// provides an anchor point. This is necessary because AnalysisManager is /// designed to be a thin wrapper around an existing analysis map instance. 
class ModuleAnalysisManager { public: - ModuleAnalysisManager(ModuleOp module, PassInstrumentor *passInstrumentor) - : analyses(module), passInstrumentor(passInstrumentor) {} + ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor) + : analyses(op), passInstrumentor(passInstrumentor) {} ModuleAnalysisManager(const ModuleAnalysisManager &) = delete; ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete; diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -28,7 +28,6 @@ class AnalysisManager; class Identifier; class MLIRContext; -class ModuleOp; class Operation; class Pass; class PassInstrumentation; @@ -158,12 +157,20 @@ /// The main pass manager and pipeline builder. class PassManager : public OpPassManager { public: - PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit); + /// Create a new pass manager under the given context with a specific nesting + /// style. The created pass manager can schedule operations that match + /// `operationName`. + PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit, + StringRef operationName = "module"); + PassManager(MLIRContext *ctx, StringRef operationName) + : PassManager(ctx, Nesting::Explicit, operationName) {} ~PassManager(); - /// Run the passes within this manager on the provided module. + /// Run the passes within this manager on the provided operation. The + /// specified operation must have the same name as the one provided to the + /// pass manager on construction. LLVM_NODISCARD - LogicalResult run(ModuleOp module); + LogicalResult run(Operation *op); /// Return an instance of the context. MLIRContext *getContext() const { return context; } @@ -318,11 +325,11 @@ void dumpStatistics(); /// Run the pass manager with crash recover enabled.
- LogicalResult runWithCrashRecovery(ModuleOp module, AnalysisManager am); + LogicalResult runWithCrashRecovery(Operation *op, AnalysisManager am); /// Run the given passes with crash recover enabled. LogicalResult runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes, - ModuleOp module, AnalysisManager am); + Operation *op, AnalysisManager am); /// Context this PassManager was initialized with. MLIRContext *context; diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" #include "llvm/Support/Format.h" #include "llvm/Support/FormatVariadic.h" @@ -97,14 +96,10 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, OpPrintingFlags flags) { - // Check to see if we are printing the top-level module. - auto module = dyn_cast<ModuleOp>(op); - if (module && !op->getBlock()) - return module.print(out << "\n", flags); - // Otherwise, check to see if we are not printing at module scope. if (!printModuleScope) - return op->print(out << "\n", flags.useLocalScope()); + return op->print(out << "\n", + op->getBlock() ? flags.useLocalScope() : flags); // Otherwise, we are printing at module scope. out << " ('" << op->getName() << "' operation"; @@ -113,17 +108,11 @@ out << ": @" << symbolName.getValue(); out << ")\n"; - // Find the top-level module operation. + // Find the top-level operation. auto *topLevelOp = op; while (auto *parentOp = topLevelOp->getParentOp()) topLevelOp = parentOp; - - // Check to see if the top-level operation is actually a module in the case of - // invalid-ir. - if (auto module = dyn_cast<ModuleOp>(topLevelOp)) - module.print(out, flags); - else - topLevelOp->print(out, flags); + topLevelOp->print(out, flags); } /// Instrumentation hooks.
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -12,7 +12,6 @@ #include "mlir/Pass/Pass.h" #include "PassDetail.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" @@ -528,9 +527,9 @@ asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(), mgrs); - // Run a prepass over the module to collect the operations to execute over. - // This ensures that an analysis manager exists for each operation, as well as - // providing a queue of operations to execute over. + // Run a prepass over the operation to collect the nested operations to + // execute over. This ensures that an analysis manager exists for each + // operation, as well as providing a queue of operations to execute over. std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs; for (auto &region : getOperation()->getRegions()) { for (auto &block : region) { @@ -614,7 +613,7 @@ /// reproducers when a signal is raised, such as a segfault. struct RecoveryReproducerContext { RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes, - ModuleOp module, StringRef filename, + Operation *op, StringRef filename, bool disableThreads, bool verifyPasses); ~RecoveryReproducerContext(); @@ -631,8 +630,8 @@ /// The textual description of the currently executing pipeline. std::string pipeline; - /// The MLIR module representing the IR before the crash. - OwningModuleRef module; + /// The MLIR operation representing the IR before the crash. + Operation *preCrashOperation; /// The filename to use when generating the reproducer.
StringRef filename; @@ -658,9 +657,9 @@ RecoveryReproducerContext::reproducerSet; RecoveryReproducerContext::RecoveryReproducerContext( - MutableArrayRef<std::unique_ptr<Pass>> passes, ModuleOp module, + MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op, StringRef filename, bool disableThreads, bool verifyPasses) - : module(module.clone()), filename(filename), + : preCrashOperation(op->clone()), filename(filename), disableThreads(disableThreads), verifyPasses(verifyPasses) { // Grab the textual pipeline being executed.. { @@ -677,6 +676,9 @@ } RecoveryReproducerContext::~RecoveryReproducerContext() { + // Erase the cloned preCrash IR that we cached. + preCrashOperation->erase(); + llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex); reproducerSet->remove(this); if (reproducerSet->empty()) @@ -700,7 +702,7 @@ << "\n"; // Output the .mlir module. - module->print(outputOS); + preCrashOperation->print(outputOS); outputFile->keep(); return success(); } @@ -722,11 +724,11 @@ } /// Run the pass manager with crash recover enabled. -LogicalResult PassManager::runWithCrashRecovery(ModuleOp module, +LogicalResult PassManager::runWithCrashRecovery(Operation *op, AnalysisManager am) { // If this isn't a local producer, run all of the passes in recovery mode. if (!localReproducer) - return runWithCrashRecovery(impl->passes, module, am); + return runWithCrashRecovery(impl->passes, op, am); // Split the passes within adaptors to ensure that each pass can be run in // isolation. @@ -735,7 +737,7 @@ // If this is a local producer, run each of the passes individually. MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes; for (std::unique_ptr<Pass> &pass : passes) - if (failed(runWithCrashRecovery(pass, module, am))) + if (failed(runWithCrashRecovery(pass, op, am))) return failure(); return success(); } @@ -743,8 +745,8 @@ /// Run the given passes with crash recover enabled.
LogicalResult PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes, - ModuleOp module, AnalysisManager am) { - RecoveryReproducerContext context(passes, module, *crashReproducerFileName, + Operation *op, AnalysisManager am) { + RecoveryReproducerContext context(passes, op, *crashReproducerFileName, !getContext()->isMultithreadingEnabled(), verifyPasses); @@ -753,7 +755,7 @@ llvm::CrashRecoveryContext recoveryContext; recoveryContext.RunSafelyOnThread([&] { for (std::unique_ptr<Pass> &pass : passes) - if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses))) + if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses))) return; passManagerResult = success(); }); @@ -762,8 +764,8 @@ std::string error; if (failed(context.generate(error))) - return module.emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error; - return module.emitError() + return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error; + return op->emitError() << "A failure has been detected while processing the MLIR module, a " "reproducer has been generated in '" << *crashReproducerFileName << "'"; @@ -773,18 +775,21 @@ // PassManager //===----------------------------------------------------------------------===// -PassManager::PassManager(MLIRContext *ctx, Nesting nesting) - : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), - nesting), - context(ctx), passTiming(false), localReproducer(false), - verifyPasses(true) {} +PassManager::PassManager(MLIRContext *ctx, Nesting nesting, + StringRef operationName) + : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx), + passTiming(false), localReproducer(false), verifyPasses(true) {} PassManager::~PassManager() {} void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; } -/// Run the passes within this manager on the provided module. +/// Run the passes within this manager on the provided operation.
+LogicalResult PassManager::run(Operation *op) { + MLIRContext *context = getContext(); + assert(op->getName().getIdentifier() == getOpName(*context) && + "operation has a different name than the PassManager"); + // Before running, make sure to coalesce any adjacent pass adaptors in the // pipeline. getImpl().coalesceAdjacentAdaptorPasses(); @@ -792,23 +797,23 @@ // Register all dialects for the current pipeline. DialectRegistry dependentDialects; getDependentDialects(dependentDialects); - dependentDialects.loadAll(module.getContext()); + dependentDialects.loadAll(context); - // Construct an analysis manager for the pipeline. - ModuleAnalysisManager am(module, instrumentor.get()); + // Construct a top level analysis manager for the pipeline. + ModuleAnalysisManager am(op, instrumentor.get()); // Notify the context that we start running a pipeline for book keeping. - module.getContext()->enterMultiThreadedExecution(); + context->enterMultiThreadedExecution(); // If reproducer generation is enabled, run the pass manager with crash // handling enabled. - LogicalResult result = crashReproducerFileName - ? runWithCrashRecovery(module, am) - : OpToOpPassAdaptor::runPipeline( - getPasses(), module, am, verifyPasses); + LogicalResult result = + crashReproducerFileName + ? runWithCrashRecovery(op, am) + : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses); // Notify the context that the run is done. - module.getContext()->exitMultiThreadedExecution(); + context->exitMultiThreadedExecution(); // Dump all of the pass statistics if necessary. 
if (passStatisticsMode) diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -50,7 +50,7 @@ llvm::cl::opt<bool> printModuleScope{ "print-ir-module-scope", llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} " - "always print the top-level module operation"), + "always print the top-level operation"), llvm::cl::init(false)}; /// Add an IR printing instrumentation if enabled by any 'print-ir' flags.