diff --git a/mlir/lib/Reducer/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp --- a/mlir/lib/Reducer/OptReductionPass.cpp +++ b/mlir/lib/Reducer/OptReductionPass.cpp @@ -44,19 +44,19 @@ PassManager passManager(module.getContext()); if (failed(parsePassPipeline(optPass, passManager))) { - LLVM_DEBUG(llvm::dbgs() << "\nFailed to parse pass pipeline"); - return; + module.emitError() << "\nFailed to parse pass pipeline"; + return signalPassFailure(); } std::pair original = test.isInteresting(module); if (original.first != Tester::Interestingness::True) { - LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested"); - return; + module.emitError() << "\nThe original input is not interested"; + return signalPassFailure(); } if (failed(passManager.run(moduleVariant))) { - LLVM_DEBUG(llvm::dbgs() << "\nFailed to run pass pipeline"); - return; + module.emitError() << "\nFailed to run pass pipeline"; + return signalPassFailure(); } std::pair reduced = diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -76,15 +76,15 @@ /// alternative way to remove operations, which is using `eraseOpNotInRange` to /// erase the operations not in the range specified by ReductionNode. template -static void findOptimal(ModuleOp module, Region ®ion, - const FrozenRewritePatternSet &patterns, - const Tester &test, bool eraseOpNotInRange) { +static LogicalResult findOptimal(ModuleOp module, Region ®ion, + const FrozenRewritePatternSet &patterns, + const Tester &test, bool eraseOpNotInRange) { std::pair initStatus = test.isInteresting(module); // While exploring the reduction tree, we always branch from an interesting // node. Thus the root node must be interesting. if (initStatus.first != Tester::Interestingness::True) - return; + return module.emitWarning() << "Uninterested module will not be reduced"; llvm::SpecificBumpPtrAllocator allocator; @@ -137,23 +137,27 @@ if (test.isInteresting(module).second != smallestNode->getSize()) llvm::report_fatal_error( "Reduced module doesn't have consistent size with smallestNode"); + return success(); } template -static void findOptimal(ModuleOp module, Region ®ion, - const FrozenRewritePatternSet &patterns, - const Tester &test) { +static LogicalResult findOptimal(ModuleOp module, Region ®ion, + const FrozenRewritePatternSet &patterns, + const Tester &test) { // We separate the reduction process into 2 steps, the first one is to erase // redundant operations and the second one is to apply the reducer patterns. // In the first phase, we don't apply any patterns so that we only select the // range of operations to keep to the module stay interesting. - findOptimal(module, region, /*patterns=*/{}, test, - /*eraseOpNotInRange=*/true); + if (failed(findOptimal(module, region, /*patterns=*/{}, test, + /*eraseOpNotInRange=*/true))) + return failure(); // In the second phase, we suppose that no operation is redundant, so we try // to rewrite the operation into simpler form. - findOptimal(module, region, patterns, test, - /*eraseOpNotInRange=*/false); + if (failed(findOptimal(module, region, patterns, test, + /*eraseOpNotInRange=*/false))) + return failure(); + return success(); } namespace { @@ -192,7 +196,7 @@ void runOnOperation() override; private: - void reduceOp(ModuleOp module, Region ®ion); + LogicalResult reduceOp(ModuleOp module, Region ®ion); FrozenRewritePatternSet reducerPatterns; }; @@ -221,7 +225,8 @@ for (Region ®ion : op->getRegions()) if (!region.empty()) - reduceOp(module, region); + if (failed(reduceOp(module, region))) + return signalPassFailure(); for (Region ®ion : op->getRegions()) for (Operation &op : region.getOps()) @@ -230,15 +235,14 @@ } while (!workList.empty()); } -void ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { +LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { Tester test(testerName, testerArgs); switch (traversalModeId) { case TraversalMode::SinglePath: - findOptimal>( + return findOptimal>( module, region, reducerPatterns, test); - break; default: - llvm_unreachable("Unsupported mode"); + return module.emitError() << "Unsupported traversal mode detected"; } } diff --git a/mlir/lib/Tools/mlir-reduce/CMakeLists.txt b/mlir/lib/Tools/mlir-reduce/CMakeLists.txt --- a/mlir/lib/Tools/mlir-reduce/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-reduce/CMakeLists.txt @@ -1,11 +1,3 @@ -set(LLVM_OPTIONAL_SOURCES - MlirReduceMain.cpp -) - -set(LLVM_LINK_COMPONENTS - Support - ) - add_mlir_library(MLIRReduceLib MlirReduceMain.cpp @@ -15,4 +7,5 @@ MLIRPass MLIRReduce MLIRSupport + MLIRTransforms ) diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp --- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp +++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp @@ -22,6 +22,7 @@ #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/ToolOutputFile.h" @@ -49,8 +50,7 @@ llvm::InitLLVM y(argc, argv); registerReducerPasses(); - registerMLIRContextCLOptions(); - registerPassManagerCLOptions(); + registerSymbolDCEPass(); PassPipelineCLParser parser("", "Reduction Passes to Run"); llvm::cl::ParseCommandLineOptions(argc, argv, diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -1,20 +1,12 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) - -if(MLIR_INCLUDE_TESTS) - set(test_libs - MLIRTestDialect - ) -endif() set(LIBS ${dialect_libs} - ${conversion_libs} - ${test_libs} MLIRDialect MLIRIR MLIRPass MLIRReduceLib + MLIRTestDialect ) add_llvm_tool(mlir-reduce diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp --- a/mlir/tools/mlir-reduce/mlir-reduce.cpp +++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -16,7 +16,6 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-reduce/MlirReduceMain.h" using namespace mlir; @@ -30,8 +29,6 @@ } // namespace mlir int main(int argc, char **argv) { - registerAllPasses(); - DialectRegistry registry; registerAllDialects(registry); #ifdef MLIR_INCLUDE_TESTS