diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -1126,6 +1126,21 @@ } ``` +* `print-ir-after-failure` + * Only print IR after a pass failure. + * This option should *not* be used with the other `print-ir-after` flags + above. + +```shell +$ mlir-opt foo.mlir -pass-pipeline='func(cse,bad-pass)' -print-ir-after-failure + +*** IR Dump After BadPass Failed *** +func @simple_constant() -> (i32, i32) { + %c1_i32 = constant 1 : i32 + return %c1_i32, %c1_i32 : i32, i32 +} +``` + * `print-ir-module-scope` * Always print the top-level module operation, regardless of pass type or operation nesting level. diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -242,10 +242,15 @@ /// pass, in the case of a non-failure, we should first check if any /// potential mutations were made. This allows for reducing the number of /// logs that don't contain meaningful changes. + /// * 'printAfterOnlyOnFailure' signals that when printing the IR after a + /// pass, we only print in the case of a failure. + /// - This option should *not* be used with the other `printAfter` flags + /// above. /// * 'opPrintingFlags' sets up the printing flags to use when printing the /// IR. explicit IRPrinterConfig( bool printModuleScope = false, bool printAfterOnlyOnChange = false, + bool printAfterOnlyOnFailure = false, OpPrintingFlags opPrintingFlags = OpPrintingFlags()); virtual ~IRPrinterConfig(); @@ -270,6 +275,12 @@ /// "changed". bool shouldPrintAfterOnlyOnChange() const { return printAfterOnlyOnChange; } + /// Returns true if the IR should only be printed after a pass if the pass + /// "failed". + bool shouldPrintAfterOnlyOnFailure() const { + return printAfterOnlyOnFailure; + } + /// Returns the printing flags to be used to print the IR. 
OpPrintingFlags getOpPrintingFlags() const { return opPrintingFlags; } @@ -281,6 +292,10 @@ /// a change is detected. bool printAfterOnlyOnChange; + /// A flag that indicates that the IR after a pass should only be printed if + /// the pass failed. + bool printAfterOnlyOnFailure; + /// Flags to control printing behavior. OpPrintingFlags opPrintingFlags; }; @@ -299,16 +314,20 @@ /// * 'printAfterOnlyOnChange' signals that when printing the IR after a /// pass, in the case of a non-failure, we should first check if any /// potential mutations were made. + /// * 'printAfterOnlyOnFailure' signals that when printing the IR after a + /// pass, we only print in the case of a failure. + /// - This option should *not* be used with the other `printAfter` flags + /// above. + /// * 'out' corresponds to the stream to output the printed IR to. /// * 'opPrintingFlags' sets up the printing flags to use when printing the /// IR. - /// * 'out' corresponds to the stream to output the printed IR to. void enableIRPrinting( std::function shouldPrintBeforePass = [](Pass *, Operation *) { return true; }, std::function shouldPrintAfterPass = [](Pass *, Operation *) { return true; }, bool printModuleScope = true, bool printAfterOnlyOnChange = true, - raw_ostream &out = llvm::errs(), + bool printAfterOnlyOnFailure = false, raw_ostream &out = llvm::errs(), OpPrintingFlags opPrintingFlags = OpPrintingFlags()); //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -134,6 +134,11 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { if (isa(pass)) return; + + // Check to see if we are only printing on failure. + if (config->shouldPrintAfterOnlyOnFailure()) + return; + // If the config asked to detect changes, compare the current fingerprint with // the previous. 
if (config->shouldPrintAfterOnlyOnChange()) { @@ -177,9 +182,11 @@ /// Initialize the configuration. PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags) : printModuleScope(printModuleScope), printAfterOnlyOnChange(printAfterOnlyOnChange), + printAfterOnlyOnFailure(printAfterOnlyOnFailure), opPrintingFlags(opPrintingFlags) {} PassManager::IRPrinterConfig::~IRPrinterConfig() {} @@ -212,9 +219,10 @@ std::function shouldPrintBeforePass, std::function shouldPrintAfterPass, bool printModuleScope, bool printAfterOnlyOnChange, - OpPrintingFlags opPrintingFlags, raw_ostream &out) + bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, + raw_ostream &out) : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, - opPrintingFlags), + printAfterOnlyOnFailure, opPrintingFlags), shouldPrintBeforePass(shouldPrintBeforePass), shouldPrintAfterPass(shouldPrintAfterPass), out(out) { assert((shouldPrintBeforePass || shouldPrintAfterPass) && @@ -257,9 +265,11 @@ void PassManager::enableIRPrinting( std::function shouldPrintBeforePass, std::function shouldPrintAfterPass, - bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out, + bool printModuleScope, bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure, raw_ostream &out, OpPrintingFlags opPrintingFlags) { enableIRPrinting(std::make_unique( std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), - printModuleScope, printAfterOnlyOnChange, opPrintingFlags, out)); + printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, + opPrintingFlags, out)); } diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -48,6 +48,11 @@ llvm::cl::desc( "When printing the IR after a pass, only print if the IR changed"), llvm::cl::init(false)}; + llvm::cl::opt 
printAfterFailure{ + "print-ir-after-failure", + llvm::cl::desc( + "When printing the IR after a pass, only print if the pass failed"), + llvm::cl::init(false)}; llvm::cl::opt printModuleScope{ "print-ir-module-scope", llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} " @@ -96,8 +101,9 @@ } // Handle print-after. - if (printAfterAll) { - // If we are printing after all, then just return true for the filter. + if (printAfterAll || printAfterFailure) { + // If we are printing after all or failure, then just return true for the + // filter. shouldPrintAfterPass = [](Pass *, Operation *) { return true; }; } else if (printAfter.hasAnyOccurrences()) { // Otherwise if there are specific passes to print after, then check to see @@ -114,7 +120,8 @@ // Otherwise, add the IR printing instrumentation. pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, - printModuleScope, printAfterChange, llvm::errs()); + printModuleScope, printAfterChange, printAfterFailure, + llvm::errs()); } void mlir::registerPassManagerCLOptions() { diff --git a/mlir/test/Pass/ir-printing.mlir b/mlir/test/Pass/ir-printing.mlir --- a/mlir/test/Pass/ir-printing.mlir +++ b/mlir/test/Pass/ir-printing.mlir @@ -4,6 +4,7 @@ // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s // RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s +// RUN: not mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,test-pass-failure)' -print-ir-after-failure -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_FAILURE %s func @foo() { %0 = constant 0 : i32 @@ 
-60,3 +61,6 @@ // AFTER_ALL_CHANGE-NOT: *** IR Dump After{{.*}}CSE *** // We expect that only 'foo' changed during CSE, and the second run of CSE did // nothing. + +// AFTER_FAILURE-NOT: *** IR Dump After{{.*}}CSE +// AFTER_FAILURE: *** IR Dump After{{.*}}TestFailurePass Failed *** diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -58,6 +58,12 @@ void runOnOperation() final { abort(); } }; +/// A test pass that always fails to enable testing the failure recovery +/// mechanisms of the pass manager. +class TestFailurePass : public PassWrapper> { + void runOnOperation() final { signalPassFailure(); } +}; + /// A test pass that contains a statistic. struct TestStatisticPass : public PassWrapper> { @@ -103,6 +109,8 @@ PassRegistration( "test-pass-crash", "Test a pass in the pass manager that always crashes"); + PassRegistration( + "test-pass-failure", "Test a pass in the pass manager that always fails"); PassRegistration unusedStatP("test-stats-pass", "Test pass statistics");