Index: mlir/include/mlir/Reducer/CMakeLists.txt =================================================================== --- mlir/include/mlir/Reducer/CMakeLists.txt +++ mlir/include/mlir/Reducer/CMakeLists.txt @@ -1,5 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Reducer) # -name Reducer names the GEN_PASS_REGISTRATION hook registerReducerPasses() add_public_tablegen_target(MLIRReducerIncGen) add_mlir_doc(Passes -gen-pass-doc ReducerPasses ./) Index: mlir/include/mlir/Reducer/OptReductionPass.h =================================================================== --- mlir/include/mlir/Reducer/OptReductionPass.h +++ mlir/include/mlir/Reducer/OptReductionPass.h @@ -28,23 +28,12 @@ class OptReductionPass : public OptReductionBase { public: - OptReductionPass(const Tester &test, MLIRContext *context, - std::unique_ptr optPass); + OptReductionPass() = default; // Configured via the opt-pass, test and test-arg options declared in Passes.td. - OptReductionPass(const OptReductionPass &srcPass); + OptReductionPass(const OptReductionPass &srcPass) = default; /// Runs the pass instance in the pass pipeline. void runOnOperation() override; - -private: - // Points to the context to be used in the pass manager. - MLIRContext *context; - - // This is used to test the interesting behavior of the transformed module. - const Tester &test; - - // Points to the mlir-opt pass to be called. - std::unique_ptr optPass; }; } // end namespace mlir Index: mlir/include/mlir/Reducer/Passes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Reducer/Passes.h @@ -0,0 +1,26 @@ +//===- Passes.h - Reducer Pass Construction and Registration ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_REDUCER_PASSES_H +#define MLIR_REDUCER_PASSES_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Reducer/OptReductionPass.h" +#include "mlir/Reducer/ReductionTreePass.h" + +namespace mlir { +/// Creates the general `reduction-tree` pass declared in Passes.td. +std::unique_ptr createReductionTreePass(); +/// Creates the `opt-reduction-pass` wrapper for optimization passes. +std::unique_ptr createOptReductionPass(); + +/// Generate the code for registering reducer passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Reducer/Passes.h.inc" + +} // namespace mlir +#endif // MLIR_REDUCER_PASSES_H Index: mlir/include/mlir/Reducer/Passes.td =================================================================== --- mlir/include/mlir/Reducer/Passes.td +++ mlir/include/mlir/Reducer/Passes.td @@ -17,10 +17,34 @@ def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { let summary = "A general reduction tree pass for the MLIR Reduce Tool"; + + let constructor = "mlir::createReductionTreePass()"; + + let options = [ + Option<"OpReducerName", "op-reducer", "std::string", /* default */"", + "The OpReducer to reduce the module">, + Option<"TraversalModeId", "traversal-mode", "unsigned", + /* default */"0", "The graph traversal mode">, + Option<"TesterName", "test", "std::string", /* default */"", + "The filename of the tester">, + ListOption<"TesterArgs", "test-arg", "std::string", "The arguments of the tester", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ]; } def OptReduction : Pass<"opt-reduction-pass", "ModuleOp"> { let summary = "A reduction pass wrapper for optimization passes"; + + let constructor = "mlir::createOptReductionPass()"; + + let options = [ + Option<"OptPass", "opt-pass", "std::string", /* default */"", + "The optimization pass will be run dynamically in OptReductionPass">, + Option<"TesterName", "test", "std::string", /* default */"", + "The filename of the tester">, + ListOption<"TesterArgs", "test-arg", "std::string", "The arguments of the tester", + "llvm::cl::ZeroOrMore, 
llvm::cl::MiscFlags::CommaSeparated">, + ]; } #endif // MLIR_REDUCER_PASSES Index: mlir/include/mlir/Reducer/ReductionTreePass.h =================================================================== --- mlir/include/mlir/Reducer/ReductionTreePass.h +++ mlir/include/mlir/Reducer/ReductionTreePass.h @@ -33,13 +33,8 @@ /// generated reduced variants. class ReductionTreePass : public ReductionTreeBase { public: - ReductionTreePass(const ReductionTreePass &pass) - : ReductionTreeBase(pass), opType(pass.opType), - mode(pass.mode), test(pass.test) {} - - ReductionTreePass(llvm::StringRef opType, TraversalMode mode, - const Tester &test) - : opType(opType), mode(mode), test(test) {} + ReductionTreePass() = default; // Configured via the op-reducer, traversal-mode, test and test-arg options declared in Passes.td. + ReductionTreePass(const ReductionTreePass &pass) = default; /// Runs the pass instance in the pass pipeline. void runOnOperation() override; @@ -48,15 +43,6 @@ template ModuleOp findOptimal(ModuleOp module, std::unique_ptr reducer, ReductionNode *node); - - /// The name of operation that we will try to remove. - llvm::StringRef opType; - - TraversalMode mode; - - /// This is used to test the interesting behavior of the reduction nodes in - /// the tree. - const Tester &test; }; } // end namespace mlir Index: mlir/test/mlir-reduce/dce-test.mlir =================================================================== --- mlir/test/mlir-reduce/dce-test.mlir +++ mlir/test/mlir-reduce/dce-test.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -test %S/failure-test.sh -pass-test DCE | FileCheck %s +// RUN: mlir-reduce %s -opt-reduction-pass='opt-pass=symbol-dce test=%S/failure-test.sh' | FileCheck %s // This input should be reduced by the pass pipeline so that only // the @simple1 function remains as the other functions should be // removed by the dead code elimination pass.
Index: mlir/test/mlir-reduce/multiple-function.mlir =================================================================== --- mlir/test/mlir-reduce/multiple-function.mlir +++ mlir/test/mlir-reduce/multiple-function.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -test %S/failure-test.sh -pass-test function-reducer | FileCheck %s +// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s // This input should be reduced by the pass pipeline so that only // the @simple5 function remains as this is the shortest function // containing the interesting behavior. Index: mlir/test/mlir-reduce/simple-test.mlir =================================================================== --- mlir/test/mlir-reduce/simple-test.mlir +++ mlir/test/mlir-reduce/simple-test.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -test %S/test.sh -pass-test function +// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh' func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { cond_br %arg0, ^bb1, ^bb2 Index: mlir/tools/mlir-reduce/OptReductionPass.cpp =================================================================== --- mlir/tools/mlir-reduce/OptReductionPass.cpp +++ mlir/tools/mlir-reduce/OptReductionPass.cpp @@ -13,33 +13,39 @@ //===----------------------------------------------------------------------===// #include "mlir/Reducer/OptReductionPass.h" +#include "mlir/Pass/PassRegistry.h" // parsePassPipeline +#include "mlir/Reducer/Tester.h" +#include "mlir/Support/LogicalResult.h" #define DEBUG_TYPE "mlir-reduce" -using namespace mlir; - -OptReductionPass::OptReductionPass(const Tester &test, MLIRContext *context, - std::unique_ptr optPass) - : context(context), test(test), optPass(std::move(optPass)) {} - -OptReductionPass::OptReductionPass(const OptReductionPass &srcPass) - : OptReductionBase(srcPass), test(srcPass.test), - optPass(srcPass.optPass.get()) {}
+namespace mlir { /// Runs the pass instance in the pass pipeline. void OptReductionPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: "); - LLVM_DEBUG(llvm::dbgs() << optPass.get()->getName() << "\nTesting:\n"); + + const Tester test(TesterName, TesterArgs); ModuleOp module = this->getOperation(); ModuleOp moduleVariant = module.clone(); - PassManager pmTransform(context); - pmTransform.addPass(std::move(optPass)); + // Parse the `opt-pass` pipeline that is applied to the module variant. + PassManager passManager(module.getContext()); + if (failed(parsePassPipeline(OptPass, passManager))) { + module.emitError() << "failed to parse pass pipeline"; + return signalPassFailure(); + } std::pair original = test.isInteresting(module); + if (original.first != Tester::Interestingness::True) { + module.emitError() << "the original input is not interesting"; + return signalPassFailure(); + } - if (failed(pmTransform.run(moduleVariant))) + if (failed(passManager.run(moduleVariant))) { + LLVM_DEBUG(llvm::dbgs() << "\nFailed to run pass pipeline"); return; + } std::pair reduced = test.isInteresting(moduleVariant); @@ -58,3 +64,9 @@ LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n"); } + +std::unique_ptr createOptReductionPass() { + return std::make_unique(); +} + +} // end namespace mlir Index: mlir/tools/mlir-reduce/ReductionTreePass.cpp =================================================================== --- mlir/tools/mlir-reduce/ReductionTreePass.cpp +++ mlir/tools/mlir-reduce/ReductionTreePass.cpp @@ -34,7 +34,7 @@ void ReductionTreePass::runOnOperation() { ModuleOp module = this->getOperation(); - std::unique_ptr reducer = getOpReducer(opType); + std::unique_ptr reducer = getOpReducer(OpReducerName); std::vector> ranges = { {0, reducer->getNumTargetOps(module)}}; @@ -44,7 +44,7 @@ new (root) ReductionNode(nullptr, ranges, allocator); ModuleOp golden = module; - switch (mode) { + switch (TraversalModeId) { case TraversalMode::SinglePath: golden = findOptimal>( module, std::move(reducer), root); @@ -65,8 +65,15 @@ ModuleOp
ReductionTreePass::findOptimal(ModuleOp module, std::unique_ptr reducer, ReductionNode *root) { + const Tester test(TesterName, TesterArgs); std::pair initStatus = test.isInteresting(module); + if (initStatus.first != Tester::Interestingness::True) { + module.emitError() << "the original input is not interesting"; + signalPassFailure(); + return module; + } + root->update(initStatus); ReductionNode *smallestNode = root; @@ -98,4 +105,8 @@ return golden; } +std::unique_ptr createReductionTreePass() { + return std::make_unique(); +} + } // end namespace mlir Index: mlir/tools/mlir-reduce/mlir-reduce.cpp =================================================================== --- mlir/tools/mlir-reduce/mlir-reduce.cpp +++ mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -16,10 +16,12 @@ #include #include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Reducer/OptReductionPass.h" +#include "mlir/Reducer/Passes.h" #include "mlir/Reducer/Passes/OpReducer.h" #include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionTreePass.h" @@ -42,23 +44,11 @@ llvm::cl::Required, llvm::cl::desc("")); -static llvm::cl::opt - testFilename("test", llvm::cl::Required, llvm::cl::desc("Testing script")); - -static llvm::cl::list - testArguments("test-args", llvm::cl::ZeroOrMore, - llvm::cl::desc("Testing script arguments")); - static llvm::cl::opt outputFilename("o", llvm::cl::desc("Output filename for the reduced test case"), llvm::cl::init("-")); -// TODO: Use PassPipelineCLParser to define pass pieplines in the command line. -static llvm::cl::opt - passTestSpecifier("pass-test", - llvm::cl::desc("Indicate a specific pass to be tested")); - // Parse and verify the input MLIR file.
static LogicalResult loadModule(MLIRContext &context, OwningModuleRef &module, StringRef inputFilename) { @@ -75,16 +65,15 @@ registerMLIRContextCLOptions(); registerPassManagerCLOptions(); + registerAllPasses(); + registerReducerPasses(); + PassPipelineCLParser parser("", "Reduction Passes to Run"); llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR test case reduction tool.\n"); std::string errorMessage; - auto testscript = openInputFile(testFilename, &errorMessage); - if (!testscript) - llvm::report_fatal_error(errorMessage); - auto output = openOutputFile(outputFilename, &errorMessage); if (!output) llvm::report_fatal_error(errorMessage); @@ -100,29 +89,15 @@ if (failed(loadModule(context, moduleRef, inputFilename))) llvm::report_fatal_error("Input test case can't be parsed"); - // Initialize test environment. - const Tester test(testFilename, testArguments); - - if (test.isInteresting(inputFilename) != Tester::Interestingness::True) - llvm::report_fatal_error( - "Input test case does not exhibit interesting behavior"); + auto errorHandler = [&](const Twine &msg) { + emitError(UnknownLoc::get(&context)) << msg; + return failure(); + }; // Reduction pass pipeline. PassManager pm(&context); - - if (passTestSpecifier == "DCE") { - - // Opt Reduction Pass with SymbolDCEPass as opt pass. - pm.addPass(std::make_unique(test, &context, - createSymbolDCEPass())); - - } else if (passTestSpecifier == "function-reducer") { - - // Reduction tree pass with Reducer variant generation and single path - // traversal. - pm.addPass(std::make_unique( - FuncOp::getOperationName(), TraversalMode::SinglePath, test)); - } + if (failed(parser.addToPipeline(pm, errorHandler))) + llvm::report_fatal_error("Failed to add pipeline"); ModuleOp m = moduleRef.get().clone();