diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(Dialect) add_subdirectory(IR) add_subdirectory(Interfaces) +add_subdirectory(Reducer) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Reducer/CMakeLists.txt b/mlir/include/mlir/Reducer/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(MLIRReducerIncGen) + +add_mlir_doc(Passes -gen-pass-doc ReducerPasses ./) diff --git a/mlir/include/mlir/Reducer/PassDetail.h b/mlir/include/mlir/Reducer/PassDetail.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/PassDetail.h @@ -0,0 +1,21 @@ +//===- PassDetail.h - Reducer Pass class details ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_PASSDETAIL_H +#define MLIR_REDUCER_PASSDETAIL_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +#define GEN_PASS_CLASSES +#include "mlir/Reducer/Passes.h.inc" + +} // end namespace mlir + +#endif // MLIR_REDUCER_PASSDETAIL_H diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/Passes.td @@ -0,0 +1,23 @@ +//===-- Passes.td - MLIR Reduce pass definition file -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the passes for the MLIR Reduce Tool. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_PASSES +#define MLIR_REDUCER_PASSES + +include "mlir/Pass/PassBase.td" + +def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { + let summary = "A general reduction tree pass for the MLIR Reduce Tool"; + let constructor = "mlir::createReductionTreePass()"; +} + +#endif // MLIR_REDUCER_PASSES diff --git a/mlir/include/mlir/Reducer/Passes/FuncReducer.h b/mlir/include/mlir/Reducer/Passes/FuncReducer.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/Passes/FuncReducer.h @@ -0,0 +1,40 @@ +//===- FuncReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the FuncReducer class. It defines a variant generator +// method with the purpose of producing different variants by eliminating +// operations from the parent module. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_PASSES_FUNCREDUCER_H +#define MLIR_REDUCER_PASSES_FUNCREDUCER_H + +#include "mlir/Reducer/ReductionNode.h" +#include "mlir/Reducer/Tester.h" + +namespace mlir { + +/// The FuncReducer class defines a variant generator method that produces +/// multiple variants by eliminating different operations from the +/// parent module.
+class FuncReducer { +public: + /// Generate variants by removing operations from the module in the parent + /// Reduction Node and link the variants as children in the Reduction Tree + /// Pass. + void generateVariants(ReductionNode *parent, const Tester *test); + +private: + /// Iterate over the body of a module and return the number of operations. + static int countOps(ModuleOp module); +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -0,0 +1,88 @@ +//===- ReductionNode.h - Reduction Node Implementation ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the reduction nodes which are used to track of the metadata +// for a specific generated variant within a reduction pass and are the building +// blocks of the reduction tree structure. A reduction tree is used to keep +// track of the different generated variants throughout a reduction pass in the +// MLIR Reduce tool. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_REDUCTIONNODE_H +#define MLIR_REDUCER_REDUCTIONNODE_H + +#include + +#include "mlir/Reducer/Tester.h" +#include "llvm/Support/ToolOutputFile.h" + +namespace mlir { + +/// This class defines the ReductionNode which is used to wrap the module of +/// a generated variant and keep track of the necessary metadata for the +/// reduction pass. The nodes are linked together in a reduction tree stucture +/// which defines the relationship between all the different generated variants. 
+class ReductionNode { +public: + ReductionNode(ModuleOp module, ReductionNode *parent); + + /// Calculates and initializes the size and interesting values of the node. + void measureAndTest(const Tester *test); + + /// Returns the module. + ModuleOp getModule() const { return module; } + + /// Returns true if the size and interestingness have been calculated. + bool isEvaluated() const; + + /// Returns the size in bytes of the module. + int getSize() const; + + /// Returns true if the module exhibits the interesting behavior. + bool isInteresting() const; + + /// Returns the pointer to a child variant by index. + ReductionNode *getVariant(unsigned long index) const; + + /// Returns true if the vector containing the child variants is empty. + bool variantsEmpty() const; + + /// Sort the child variants and remove the uninteresting ones. + void organizeVariants(const Tester *test); + +private: + // This is the MLIR module of this variant. + ModuleOp module; + + // This is true if the module has been evaluated and it exhibits the + // interesting behavior. + bool interesting; + + // This indicates the number of characters in the printed module if the module + // has been evaluated. + int size; + + // This indicates if the module has been evalueated (measured and tested). + bool evaluated; + + // This points to the ReductionNode that was used as a starting point to + // create this variant. It is null if the reduction node is the root. + ReductionNode *parent; + + // This points to the child variants that were created using this node as a + // starting point. + std::vector> variants; + + /// Link a child variant node. 
+ void linkVariant(ReductionNode *newVariant); +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/ReductionTreePass.h @@ -0,0 +1,107 @@ +//===- ReductionTreePass.h - Reduction Tree Pass Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Reduction Tree Pass class. It provides a framework for +// the implementation of different reduction passes in the MLIR Reduce tool. It +// allows for custom specification of the variant generation behavior. It +// implements methods that define the different possible traversals of the +// reduction tree. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H +#define MLIR_REDUCER_REDUCTIONTREEPASS_H + +#include + +#include "PassDetail.h" +#include "ReductionNode.h" +#include "mlir/Reducer/Passes/FuncReducer.h" +#include "mlir/Reducer/Tester.h" + +namespace mlir { + +/// Defines the traversal method options to be used in the reduction tree +/// traversal. +enum TraversalMode { SinglePath, MultiPath, Concurrent, Backtrack }; + +// This class defines the non- templated utilities used by the ReductionTreePass +// class. +class ReductionTreeUtils { +public: + void updateGoldenModule(ModuleOp &golden, ModuleOp reduced); +}; + +/// This class defines the Reduction Tree Pass. It provides a framework to +/// to implement a reduction pass using a tree structure to keep track of the +/// generated reduced variants. 
+template +class ReductionTreePass + : public ReductionTreeBase> { +public: + ReductionTreePass(const Tester *test) : test(test) {} + + ReductionTreePass(const ReductionTreePass &pass) + : root(new ReductionNode(pass.root->getModule().clone(), nullptr)), + test(pass.test) {} + + /// Runs the pass instance in the pass pipeline. + void runOnOperation() override { + ModuleOp module = this->getOperation(); + this->root = std::make_unique(module, nullptr); + ReductionNode *reduced; + + switch (mode) { + case SinglePath: + reduced = singlePathTraversal(); + break; + default: + llvm::report_fatal_error("Traversal method not currently supported."); + break; + } + + ReductionTreeUtils utils; + utils.updateGoldenModule(module, reduced->getModule()); + } + +private: + // Points to the root node in this reduction tree. + std::unique_ptr root; + + // This object defines the variant generation at each level of the reduction + // tree. + Reducer reducer; + + // This is used to test the interesting behavior of the reduction nodes in the + // tree. + const Tester *test; + + /// Traverse the most reduced path in the reduction tree by generating the + /// variants at each level using the Reducer parameter's generateVariants + /// function. Stops when no new successful variants can be created at the + /// current level. + ReductionNode *singlePathTraversal() { + ReductionNode *currLevel = root.get(); + + while (true) { + reducer.generateVariants(currLevel, test); + currLevel->organizeVariants(test); + + if (currLevel->variantsEmpty()) + break; + + currLevel = currLevel->getVariant(0); + } + + return currLevel; + } +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h --- a/mlir/include/mlir/Reducer/Tester.h +++ b/mlir/include/mlir/Reducer/Tester.h @@ -9,8 +9,8 @@ // This file defines the Tester class used in the MLIR Reduce tool. 
// // A Tester object is passed as an argument to the reduction passes and it is -// used to keep track of the state of the reduction throughout the multiple -// passes. +// used to run the interestingness testing script on the different generated +// reduced variants of the test case. // //===----------------------------------------------------------------------===// @@ -27,9 +27,9 @@ namespace mlir { -/// This class is used to keep track of the state of the reduction. It contains -/// a method to run the interestingness testing script on MLIR test case files -/// and provides functionality to track the most reduced test case. +/// This class is used to keep track of the testing environment of the tool. It +/// contains a method to run the interestingness testing script on an MLIR test +/// case file. class Tester { public: Tester(StringRef testScript, ArrayRef testScriptArgs); @@ -37,23 +37,13 @@ /// Runs the interestingness testing script on a MLIR test case file. Returns /// true if the interesting behavior is present in the test case or false /// otherwise. - bool isInteresting(StringRef testCase); - - /// Returns the most reduced MLIR test case module. - ModuleOp getMostReduced() const { return mostReduced; } - - /// Updates the most reduced MLIR test case module. If a - /// generated variant is found to be successful and shorter than the - /// mostReduced module, the mostReduced module must be updated with the new - /// variant. - void setMostReduced(ModuleOp t) { mostReduced = t; } + bool isInteresting(StringRef testCase) const; private: StringRef testScript; ArrayRef testScriptArgs; - ModuleOp mostReduced; }; } // end namespace mlir -#endif \ No newline at end of file +#endif diff --git a/mlir/lib/Reducer/Tester.cpp b/mlir/lib/Reducer/Tester.cpp --- a/mlir/lib/Reducer/Tester.cpp +++ b/mlir/lib/Reducer/Tester.cpp @@ -9,8 +9,8 @@ // This file defines the Tester class used in the MLIR Reduce tool.
#!/bin/bash
# Interestingness test: looks for the keyword "failure" in the stderr of the
# optimization pass run on the test case given as $1.
#
# Exit status: 1 when mlir-opt fails AND its stderr mentions "failure" (the
# interesting behavior), 0 otherwise.

# Use a private temporary file instead of a predictable /tmp name, and remove
# it on every exit path.
stderr_file=$(mktemp) || exit 0
trap 'rm -f "$stderr_file"' EXIT

# Quote the test case path so that paths containing spaces work. The stdout of
# mlir-opt is not inspected, so it is discarded.
mlir-opt "$1" -test-mlir-reducer > /dev/null 2> "$stderr_file"

if [ $? -ne 0 ] && grep 'failure' "$stderr_file"; then
  # Interesting behavior.
  exit 1
else
  exit 0
fi
+// CHECK-LABEL: func @simple5(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + +func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 +^bb1: + br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = alloc() : memref<2xf32> + br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + return +} + +func @simple2(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 +^bb1: + br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = alloc() : memref<2xf32> + br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + return +} + +func @simple3(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 +^bb1: + br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = alloc() : memref<2xf32> + br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + return +} + +func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + cond_br %arg0, ^bb1, ^bb2 +^bb1: + br ^bb3(%arg1 : memref<2xf32>) +^bb2: + %0 = alloc() : memref<2xf32> + br ^bb3(%0 : memref<2xf32>) +^bb3(%1: memref<2xf32>): + return +} + +func @simple5(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { + "test.crashOp" () : () -> () + return +} diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -32,10 +32,19 @@ ) add_llvm_tool(mlir-reduce + Passes/FuncReducer.cpp + ReductionNode.cpp + ReductionTreePass.cpp mlir-reduce.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Reducer + + DEPENDS + MLIRReducerIncGen ) target_link_libraries(mlir-reduce PRIVATE ${LIBS}) llvm_update_compile_flags(mlir-reduce) -mlir_check_all_link_libraries(mlir-reduce) \ No newline at end of file +mlir_check_all_link_libraries(mlir-reduce) diff --git a/mlir/tools/mlir-reduce/Passes/FuncReducer.cpp b/mlir/tools/mlir-reduce/Passes/FuncReducer.cpp new file 
mode 100644 --- /dev/null +++ b/mlir/tools/mlir-reduce/Passes/FuncReducer.cpp @@ -0,0 +1,72 @@ +//===- FuncReducer.cpp - MLIR Reduce Operation Reducer Variant Generator --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the FuncReducer class. It defines a variant generator class +// to be used in a Reduction Tree Pass instantiation with the aim of reducing +// the number of function operations in an MLIR Module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Reducer/Passes/FuncReducer.h" +#include "mlir/IR/Function.h" + +using namespace mlir; + +/// Return the number of function operations in the module's body. +int FuncReducer::countOps(ModuleOp module) { + auto ops = module.getOps(); + return std::distance(ops.begin(), ops.end()); +} + +/// Generate variants by removing function operations from the module in the +/// parent and link the variants as childs in the Reduction Tree Pass. +void FuncReducer::generateVariants(ReductionNode *parent, const Tester *test) { + ModuleOp module = parent->getModule(); + int numVariants = 2; + int opCount = countOps(module); + int sectionSize = opCount / numVariants; + std::vector opsToRemove; + + if (opCount == 0) + return; + + // Create a variant by deleting all ops. + if (opCount == 1) { + opsToRemove.clear(); + ModuleOp moduleVariant = module.clone(); + + for (FuncOp op : moduleVariant.getOps()) + opsToRemove.push_back(op); + + for (Operation *o : opsToRemove) + o->erase(); + + new ReductionNode(moduleVariant, parent); + + return; + } + + // Create two variants by bisecting the module. 
+ for (int i = 0; i < numVariants; ++i) { + opsToRemove.clear(); + ModuleOp moduleVariant = module.clone(); + + for (auto op : enumerate(moduleVariant.getOps())) { + int index = op.index(); + if (index >= sectionSize * i && index < sectionSize * (i + 1)) + opsToRemove.push_back(op.value()); + } + + for (Operation *o : opsToRemove) + o->erase(); + + new ReductionNode(moduleVariant, parent); + } + + return; +} diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/tools/mlir-reduce/ReductionNode.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-reduce/ReductionNode.cpp @@ -0,0 +1,99 @@ +//===- ReductionNode.cpp - Reduction Node Implementation -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the reduction nodes which are used to track of the +// metadata for a specific generated variant within a reduction pass and are the +// building blocks of the reduction tree structure. A reduction tree is used to +// keep track of the different generated variants throughout a reduction pass in +// the MLIR Reduce tool. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Reducer/ReductionNode.h" + +using namespace mlir; + +/// Sets up the metadata and links the node to its parent. +ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent) + : module(module), evaluated(false), parent(parent) { + + if (parent != nullptr) { + parent->linkVariant(this); + } +} + +/// Calculates and updates the size and interesting values of the module. +void ReductionNode::measureAndTest(const Tester *test) { + SmallString<128> filepath; + int fd; + + // Print module to temprary file. 
+ std::error_code ec = + llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath); + + if (ec) + llvm::report_fatal_error("Error making unique filename: " + ec.message()); + + llvm::ToolOutputFile out(filepath, fd); + module.print(out.os()); + out.os().close(); + + if (out.os().has_error()) + llvm::report_fatal_error("Error emitting bitcode to file '" + filepath); + + size = out.os().tell(); + interesting = (*test).isInteresting(filepath); + evaluated = true; +} + +/// Returns true if the size and interestingness have been calculated. +bool ReductionNode::isEvaluated() const { return evaluated; } + +/// Returns the size in bytes of the module. +int ReductionNode::getSize() const { return size; } + +/// Returns true if the module exhibits the interesting behavior. +bool ReductionNode::isInteresting() const { return interesting; } + +/// Returns the pointers to the child variants. +ReductionNode *ReductionNode::getVariant(unsigned long index) const { + if (index < variants.size()) + return variants[index].get(); + + return nullptr; +} + +/// Returns true if the child variants vector is empty. +bool ReductionNode::variantsEmpty() const { return variants.empty(); } + +/// Link a child variant node. +void ReductionNode::linkVariant(ReductionNode *newVariant) { + std::unique_ptr ptrVariant(newVariant); + variants.push_back(std::move(ptrVariant)); +} + +/// Sort the child variants and remove the uninteresting ones. 
+void ReductionNode::organizeVariants(const Tester *test) { + for (auto &var : variants) + if (!var->isEvaluated()) + var->measureAndTest(test); + + // Remove uninteresting variants + auto it = variants.begin(); + while (it != variants.end()) + if (!(*it)->isInteresting()) { + variants.erase(it); + } else { + ++it; + } + + llvm::array_pod_sort(variants.begin(), variants.end(), + [](const auto *lhs, const auto *rhs) { + return (lhs->get()->getSize(), rhs->get()->getSize()); + }); +} diff --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-reduce/ReductionTreePass.cpp @@ -0,0 +1,39 @@ +//===- ReductionTreePass.cpp - Reduction Tree Pass Implementation ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Reduction Tree Pass. It provides a framework for +// the implementation of different reduction passes in the MLIR Reduce tool. It +// allows for custom specification of the variant generation behavior. It +// implements methods that define the different possible traversals of the +// reduction tree. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Reducer/ReductionTreePass.h" + +using namespace mlir; + +/// Update the golden module's content with that of the reduced module. +void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden, + ModuleOp reduced) { + std::vector opsToDelete; + + // Clear golden module body. + for (auto &op : golden) + if (!op.isKnownTerminator()) + opsToDelete.push_back(&op); + + for (auto *op : opsToDelete) + op->erase(); + + // Insert new operations into golden module. 
+ opsToDelete.clear(); + for (auto &op : reduced) + if (!op.isKnownTerminator()) + golden.push_back(op.clone()); +} diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp --- a/mlir/tools/mlir-reduce/mlir-reduce.cpp +++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -19,6 +19,8 @@ #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Reducer/ReductionNode.h" +#include "mlir/Reducer/ReductionTreePass.h" #include "mlir/Reducer/Tester.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" @@ -83,15 +85,28 @@ llvm::report_fatal_error("Input test case can't be parsed"); // Initialize test environment. - Tester test(testFilename, testArguments); - test.setMostReduced(moduleRef.get()); + const Tester test(testFilename, testArguments); + const Tester *testRef = &test; if (!test.isInteresting(inputFilename)) llvm::report_fatal_error( "Input test case does not exhibit interesting behavior"); - test.getMostReduced().print(output->os()); + // Reduction pass pipeline. + PassManager pm(&context); + + // Reduction tree pass with OpReducer variant generation and single path + // traversal. + pm.addPass( + std::make_unique>(testRef)); + + ModuleOp m = moduleRef.get().clone(); + + if (failed(pm.run(m))) + llvm::report_fatal_error("Error running the reduction pass pipeline"); + + m.print(output->os()); output->keep(); return 0; -} \ No newline at end of file +}