diff --git a/mlir/include/mlir/Reducer/OptReductionPass.h b/mlir/include/mlir/Reducer/OptReductionPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/OptReductionPass.h @@ -0,0 +1,66 @@ +//===- OptReductionPass.h - Optimization Reduction Pass Wrapper -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Opt Reduction Pass Wrapper. It creates a pass to run +// any optimization pass within it and only replaces the output module with the +// transformed version if it is smaller and interesting. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_OPTREDUCTIONPASS_H +#define MLIR_REDUCER_OPTREDUCTIONPASS_H + +#include <memory> + +#include "PassDetail.h" +#include "ReductionNode.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Reducer/Passes/FuncReducer.h" +#include "mlir/Reducer/ReductionTreePass.h" +#include "mlir/Reducer/Tester.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { + +/// Pass wrapper that runs an optimization pass as a reduction pipeline step. +/// The wrapped pass is applied to a clone of the input module; the module of +/// the reduction node picked afterwards is passed to updateGoldenModule. +class OptReductionPass : public OptReductionBase<OptReductionPass> { +public: + /// \p test is the interestingness tester, \p context is used to create the + /// internal pass manager, and \p optPass is the optimization pass to wrap. + OptReductionPass(Tester test, MLIRContext *context, + std::unique_ptr<Pass> optPass); + + OptReductionPass(const OptReductionPass &srcPass); + + /// Runs the pass instance in the pass pipeline. + void runOnOperation() override; + +private: + // Points to the reduction node containing the input module to this pass. + std::unique_ptr<ReductionNode> root; + + // Points to the context to be used in the pass manager. + MLIRContext *context; + + // This is used to test the interesting behavior of the transformed module. + Tester test; + + // Points to the optimization pass to be used to transform the module.
+ std::unique_ptr<Pass> optPass; + + /// Runs the optimization pass on a clone of the module. Returns nullptr if + /// the pass fails, otherwise the transformed variant's or the input's node. + ReductionNode *singleTransform(); +}; + +} // end namespace mlir + +#endif // MLIR_REDUCER_OPTREDUCTIONPASS_H diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/Passes.td @@ -0,0 +1,28 @@ +//===-- Passes.td - MLIR Reduce pass definition file -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the passes for the MLIR Reduce Tool. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_PASSES +#define MLIR_REDUCER_PASSES + +include "mlir/Pass/PassBase.td" + +def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { + let summary = "A general reduction tree pass for the MLIR Reduce Tool"; + let constructor = "mlir::createReductionTreePass()"; +} + +def OptReduction : Pass<"opt-reduction-pass", "ModuleOp"> { + let summary = "A reduction pass wrapper for mlir-opt passes"; + let constructor = "mlir::createOptReductionPass()"; +} + +#endif // MLIR_REDUCER_PASSES diff --git a/mlir/test/mlir-reduce/dce-test.mlir b/mlir/test/mlir-reduce/dce-test.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-reduce/dce-test.mlir @@ -0,0 +1,17 @@ +// UNSUPPORTED: -windows- +// RUN: mlir-reduce %s -test %S/failure-test.sh | FileCheck %s +// This input should be reduced by the pass pipeline so that only +// the @simple1 function remains as the other functions should be +// removed by the dead code elimination pass.
+// CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + +// CHECK-NOT: func @dead_private_function +func @dead_private_function() attributes { sym_visibility = "private" } + +// CHECK-NOT: func @dead_nested_function +func @dead_nested_function() attributes { sym_visibility = "nested" } + +func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + "test.crashOp" () : () -> () + return +} diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -32,10 +32,20 @@ ) add_llvm_tool(mlir-reduce + OptReductionPass.cpp + Passes/FuncReducer.cpp + ReductionNode.cpp + ReductionTreePass.cpp mlir-reduce.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Reducer + + DEPENDS + MLIRReducerIncGen ) target_link_libraries(mlir-reduce PRIVATE ${LIBS}) llvm_update_compile_flags(mlir-reduce) -mlir_check_all_link_libraries(mlir-reduce) \ No newline at end of file +mlir_check_all_link_libraries(mlir-reduce) diff --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/tools/mlir-reduce/OptReductionPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-reduce/OptReductionPass.cpp @@ -0,0 +1,62 @@ +//===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Opt Reduction Pass class. It creates a pass to run +// any optimization pass within it and only replaces the output module with the +// transformed version if it is smaller and interesting.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Reducer/OptReductionPass.h" + +using namespace mlir; + +OptReductionPass::OptReductionPass(Tester test, MLIRContext *context, + std::unique_ptr<Pass> optPass) + : context(context), test(test), optPass(std::move(optPass)) {} + +// NOTE(review): optPass below aliases srcPass.optPass, so both objects own the +// same pass and may both delete it; the wrapped pass needs a real copy here. +OptReductionPass::OptReductionPass(const OptReductionPass &srcPass) + : root(srcPass.root ? new ReductionNode(srcPass.root->getModule().clone()) + : nullptr), + context(srcPass.context), test(srcPass.test), + optPass(srcPass.optPass.get()) {} + +/// Runs the wrapped pass and hands the chosen module to updateGoldenModule. +void OptReductionPass::runOnOperation() { + ModuleOp module = this->getOperation(); + this->root = std::make_unique<ReductionNode>(module); + ReductionTreeUtils utils; + if (ReductionNode *transformed = singleTransform()) + utils.updateGoldenModule(module, transformed->getModule().clone()); +} + +/// Runs the optimization pass on a clone of the module. Returns nullptr if +/// the pass fails, otherwise the transformed variant's or the input's node.
+ReductionNode *OptReductionPass::singleTransform() { + ReductionNode *initNode = root.get(); + ModuleOp moduleVariant = initNode->getModule().clone(); + // Apply the wrapped pass to the clone through a dedicated pass manager. + PassManager pmTransform(context); + pmTransform.addPass(std::move(optPass)); + if (failed(pmTransform.run(moduleVariant))) { + // The clone is not owned by any node yet; erase it so it does not leak. + moduleVariant.erase(); + return nullptr; + } + // Prefer the first organized variant unless it is larger than the input. + initNode->linkVariant(new ReductionNode(moduleVariant)); + initNode->organizeVariants(test); + if (!initNode->variantsEmpty()) { + ReductionNode *transformedNode = initNode->getVariant(0); + if (transformedNode->getSize() <= initNode->getSize()) + return transformedNode; + } + return initNode; +} diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp --- a/mlir/tools/mlir-reduce/mlir-reduce.cpp +++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -19,6 +19,9 @@ #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Reducer/OptReductionPass.h" +#include "mlir/Reducer/ReductionNode.h" +#include "mlir/Reducer/ReductionTreePass.h" #include "mlir/Reducer/Tester.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" @@ -84,14 +87,30 @@ // Initialize test environment. Tester test(testFilename, testArguments); - test.setMostReduced(moduleRef.get()); if (!test.isInteresting(inputFilename)) llvm::report_fatal_error( "Input test case does not exhibit interesting behavior"); - test.getMostReduced().print(output->os()); + // Reduction pass pipeline. + PassManager pm(&context); + + // Reduction tree pass with FuncReducer variant generation and single path + // traversal. + pm.addPass( + std::make_unique>(test)); + + // Opt Reduction Pass with SymbolDCEPass as opt pass.
+ pm.addPass(std::make_unique(test, &context, + createSymbolDCEPass())); + + ModuleOp m = moduleRef.get().clone(); + + if (failed(pm.run(m))) + llvm::report_fatal_error("Error running the reduction pass pipeline"); + + m.print(output->os()); output->keep(); return 0; -} \ No newline at end of file +}