diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td --- a/mlir/include/mlir/Reducer/Passes.td +++ b/mlir/include/mlir/Reducer/Passes.td @@ -24,14 +24,12 @@ ]; } -def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { +def ReductionTree : Pass<"reduction-tree"> { let summary = "A general reduction tree pass for the MLIR Reduce Tool"; let constructor = "mlir::createReductionTreePass()"; let options = [ - Option<"opReducerName", "op-reducer", "std::string", /* default */"", - "The OpReducer to reduce the module">, Option<"traversalModeId", "traversal-mode", "unsigned", /* default */"0", "The graph traversal mode">, ] # CommonReductionPassOptions.options; diff --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h deleted file mode 100644 --- a/mlir/include/mlir/Reducer/Passes/OpReducer.h +++ /dev/null @@ -1,76 +0,0 @@ -//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the OpReducer class. It defines a variant generator method -// with the purpose of producing different variants by eliminating a -// parameterizable type of operations from the parent module. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H -#define MLIR_REDUCER_PASSES_OPREDUCER_H - -#include - -#include "mlir/Reducer/ReductionNode.h" -#include "mlir/Reducer/Tester.h" - -namespace mlir { - -class OpReducer { -public: - virtual ~OpReducer() = default; - /// According to rangeToKeep, try to reduce the given module.
 We implicitly - /// number each interesting operation and rangeToKeep indicates that if an - /// operation's number falls into certain range, then we will not try to - /// reduce that operation. - virtual void reduce(ModuleOp module, - ArrayRef rangeToKeep) = 0; - /// Return the number of certain kind of operations that we would like to - /// reduce. This can be used to build a range map to exclude uninterested - /// operations. - virtual int getNumTargetOps(ModuleOp module) const = 0; -}; - -/// Reducer is a helper class to remove potential uninteresting operations from -/// module. -template -class Reducer : public OpReducer { -public: - ~Reducer() override = default; - - int getNumTargetOps(ModuleOp module) const override { - return std::distance(module.getOps().begin(), - module.getOps().end()); - } - - void reduce(ModuleOp module, - ArrayRef rangeToKeep) override { - std::vector opsToRemove; - size_t keepIndex = 0; - - for (auto op : enumerate(module.getOps())) { - int index = op.index(); - if (keepIndex < rangeToKeep.size() && - index == rangeToKeep[keepIndex].second) - ++keepIndex; - if (keepIndex == rangeToKeep.size() || - index < rangeToKeep[keepIndex].first) - opsToRemove.push_back(op.value()); - } - - for (Operation *o : opsToRemove) { - o->dropAllUses(); - o->erase(); - } - } -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/Reducer/ReducerRegistry.h b/mlir/include/mlir/Reducer/ReducerRegistry.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/ReducerRegistry.h @@ -0,0 +1,24 @@ +//===- ReducerRegistry.h - Reducer Callback Registration --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_REDUCERREGISTRY_H +#define MLIR_REDUCER_REDUCERREGISTRY_H + +#include + +namespace mlir { + +class RewritePatternSet; +/// Callback that adds reduction patterns to the given pattern set. +using ReducerCollectorFunction = std::function; +/// Registers a callback whose patterns ReductionTreePass will try to apply. +void registerReducerFunc(const ReducerCollectorFunction &function); + +} // namespace mlir + +#endif // MLIR_REDUCER_REDUCERREGISTRY_H diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -26,6 +26,9 @@ namespace mlir { +class ModuleOp; +class Region; + /// Defines the traversal method options to be used in the reduction tree /// traversal. enum TraversalMode { SinglePath, Backtrack, MultiPath }; @@ -46,11 +49,24 @@ + /// Each node owns the module it cloned in init(); release it together + /// with the node. + ~ReductionNode() { + if (module) + module->destroy(); + } + ReductionNode *getParent() const; + ModuleOp getModule(); + + Region &getRegion(); + size_t getSize() const; /// Returns true if the module exhibits the interesting behavior. Tester::Interestingness isInteresting() const; + const std::vector getStartRanges() const; + std::vector getRanges() const; std::vector &getVariants(); @@ -62,6 +78,8 @@ void update(std::pair result); private: + void init(ModuleOp parentModule, Region &parentRegion); + /// A custom BFS iterator. The difference between /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. /// We may explore more neighbors at certain node if we didn't find interested @@ -111,6 +129,14 @@ std::queue visitQueue; }; + /// This node's own clone of the top-level module, created by init(). + ModuleOp module; + + /// The region being reduced, located inside `module`. + Region *region; + + ReductionNode *parent; + /// The size of module after applying the range constraints. size_t size; @@ -118,7 +144,10 @@ /// interesting behavior. Tester::Interestingness interesting; - ReductionNode *parent; + /// The ranges this node was created with, i.e. how it is reduced from its
+ /// parent. Unlike `ranges`, which update() may reset, these never change and + /// can be used to replay the whole reduction path. + const std::vector startRanges; /// We will only keep the operation with index falls into the ranges. /// For example, number each function in a certain module and then we will @@ -131,6 +160,9 @@ std::vector variants; llvm::SpecificBumpPtrAllocator &allocator; + + // Grant access to 'init'. + friend class ReductionTreePass; }; // Specialized iterator for SinglePath traversal diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h --- a/mlir/include/mlir/Reducer/ReductionTreePass.h +++ b/mlir/include/mlir/Reducer/ReductionTreePass.h @@ -19,30 +19,31 @@ #include -#include "PassDetail.h" -#include "ReductionNode.h" -#include "mlir/Reducer/Passes/OpReducer.h" +#include "mlir/Reducer/PassDetail.h" #include "mlir/Reducer/Tester.h" #define DEBUG_TYPE "mlir-reduce" namespace mlir { +class FrozenRewritePatternSet; + /// This class defines the Reduction Tree Pass. It provides a framework to /// to implement a reduction pass using a tree structure to keep track of the /// generated reduced variants. class ReductionTreePass : public ReductionTreeBase { public: ReductionTreePass() = default; - ReductionTreePass(const ReductionTreePass &pass) = default; + ReductionTreePass(const ReductionTreePass &pass); /// Runs the pass instance in the pass pipeline.
void runOnOperation() override; private: + void ReduceOp(Region &region, const FrozenRewritePatternSet &patterns); + template - ModuleOp findOptimal(ModuleOp module, std::unique_ptr reducer, - ReductionNode *node); + void findOptimal(Region &region, const FrozenRewritePatternSet &patterns); }; } // end namespace mlir diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt --- a/mlir/lib/Reducer/CMakeLists.txt +++ b/mlir/lib/Reducer/CMakeLists.txt @@ -1,7 +1,13 @@ add_mlir_library(MLIRReduce + OptReductionPass.cpp + ReductionNode.cpp + ReductionTreePass.cpp Tester.cpp LINK_LIBS PUBLIC MLIRIR + MLIRPass + MLIRRewrite + MLIRTransformUtils ) mlir_check_all_link_libraries(MLIRReduce) diff --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp rename from mlir/tools/mlir-reduce/OptReductionPass.cpp rename to mlir/lib/Reducer/OptReductionPass.cpp diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp rename from mlir/tools/mlir-reduce/ReductionNode.cpp rename to mlir/lib/Reducer/ReductionNode.cpp --- a/mlir/tools/mlir-reduce/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Reducer/ReductionNode.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "llvm/ADT/STLExtras.h" #include @@ -23,13 +24,29 @@ using namespace mlir; ReductionNode::ReductionNode( - ReductionNode *parent, std::vector ranges, + ReductionNode *parentNode, std::vector ranges, llvm::SpecificBumpPtrAllocator &allocator) - : size(std::numeric_limits::max()), - interesting(Tester::Interestingness::Untested), - /// Root node will have the parent pointer point to themselves. - parent(parent == nullptr ? this : parent), ranges(ranges), - allocator(allocator) {} + /// The root node's parent pointer points to itself. + : parent(parentNode == nullptr ?
 this : parentNode), + size(std::numeric_limits::max()), + interesting(Tester::Interestingness::Untested), startRanges(ranges), + ranges(ranges), allocator(allocator) { + if (parent != this) + init(parent->getModule(), parent->getRegion()); +} + +void ReductionNode::init(ModuleOp parentModule, Region &targetRegion) { + // The mapper lets us find the clone of `targetRegion` in the new module. + BlockAndValueMapping mapper; + module = cast(parentModule.getOperation()->clone(mapper)); + // Use the first block to locate the cloned region. + Block *block = mapper.lookup(&*targetRegion.begin()); + region = block->getParent(); +} + +ModuleOp ReductionNode::getModule() { return module; } + +Region &ReductionNode::getRegion() { return *region; } /// Returns the size in bytes of the module. size_t ReductionNode::getSize() const { return size; } @@ -41,20 +58,28 @@ return interesting; } +const std::vector ReductionNode::getStartRanges() const { + return startRanges; +} + std::vector ReductionNode::getRanges() const { return ranges; } std::vector &ReductionNode::getVariants() { return variants; } -#include - /// If we haven't explored any variants from this node, we will create N /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the /// max element in `ranges` and create 2 new variants for each call. std::vector ReductionNode::generateNewVariants() { std::vector newNodes; + auto createNewNode = [this](const std::vector &ranges) { + ReductionNode *newNode = allocator.Allocate(); + new (newNode) ReductionNode(this, ranges, allocator); + return newNode; + }; + // If we haven't created new variant, then we can create varients by removing // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can // produce variants with range {{1, 3}} and {{4, 9}}.
@@ -62,8 +87,7 @@
     for (const Range &range : ranges) {
       std::vector<Range> subRanges = ranges;
       llvm::erase_value(subRanges, range);
-      ReductionNode *newNode = allocator.Allocate();
-      new (newNode) ReductionNode(this, subRanges, allocator);
+      ReductionNode *newNode = createNewNode(subRanges);
       newNodes.push_back(newNode);
       variants.push_back(newNode);
     }
@@ -86,12 +110,6 @@
   if (maxElement->second - maxElement->first == 1)
     return {};
 
-  auto createNewNode = [this](const std::vector<Range> &ranges) {
-    ReductionNode *newNode = allocator.Allocate();
-    new (newNode) ReductionNode(this, ranges, allocator);
-    return newNode;
-  };
-
   Range maxRange = *maxElement;
   std::vector<Range> subRanges = ranges;
   auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
@@ -112,6 +130,14 @@
 
 void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
   std::tie(interesting, size) = result;
+  // After applying the reduction, the number of operations in the region may
+  // have changed. Non-interesting nodes won't be explored further, so it is
+  // safe to leave their ranges stale.
+  if (interesting == Tester::Interestingness::True) {
+    // The module may have been updated. Reset the range.
+    ranges.clear();
+    ranges.push_back({0, std::distance(region->op_begin(), region->op_end())});
+  }
 }
 
 std::vector<ReductionNode *>
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -0,0 +1,212 @@
+//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Reduction Tree Pass class. It provides a framework for
+// the implementation of different reduction passes in the MLIR Reduce tool. It
+// allows for custom specification of the variant generation behavior. It
+// implements methods that define the different possible traversals of the
+// reduction tree.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Reducer/ReductionTreePass.h"
+#include "mlir/Reducer/Passes.h"
+#include "mlir/Reducer/ReducerRegistry.h"
+#include "mlir/Reducer/ReductionNode.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ManagedStatic.h"
+
+#include <stack>
+
+using namespace mlir;
+
+/// Callbacks registered through registerReducerFunc(). Each one populates a
+/// pattern set with rewrites that this pass tries on the kept operations.
+static llvm::ManagedStatic<std::vector<ReducerCollectorFunction>>
+    reducerCollector;
+
+void mlir::registerReducerFunc(const ReducerCollectorFunction &function) {
+  reducerCollector->push_back(function);
+}
+
+/// Applies one reduction step to `region`. Operations are numbered in
+/// iteration order; those whose index falls into one of the half-open
+/// [first, second) ranges of `rangeToKeep` (assumed sorted and disjoint) are
+/// kept and rewritten with `patterns`, all others are erased.
+static void applyReductionStep(Region &region,
+                               const FrozenRewritePatternSet &patterns,
+                               ArrayRef<ReductionNode::Range> rangeToKeep) {
+  // Partition first: applyOpPatternsAndFold() may replace or erase the op it
+  // is given, which would invalidate the iterator if we rewrote while walking
+  // `region`.
+  std::vector<Operation *> opsToRewrite;
+  std::vector<Operation *> opsToRemove;
+  size_t keepIndex = 0;
+  for (auto op : llvm::enumerate(region.getOps())) {
+    int index = op.index();
+    if (keepIndex < rangeToKeep.size() &&
+        index == rangeToKeep[keepIndex].second)
+      ++keepIndex;
+    if (keepIndex == rangeToKeep.size() ||
+        index < rangeToKeep[keepIndex].first)
+      opsToRemove.push_back(&op.value());
+    else
+      opsToRewrite.push_back(&op.value());
+  }
+
+  // Try to rewrite the kept operations into a smaller form. Whether a pattern
+  // applied does not matter here, so the result is ignored.
+  for (Operation *op : opsToRewrite)
+    (void)applyOpPatternsAndFold(op, patterns);
+
+  // Operations outside `rangeToKeep` are dropped without further
+  // verification; the tester decides whether the result is still interesting.
+  for (Operation *op : opsToRemove) {
+    op->dropAllUses();
+    op->erase();
+  }
+}
+
+/// Copies the option values explicitly: runOnOperation() clones this pass for
+/// the nested pipelines, which should run with the same configuration.
+ReductionTreePass::ReductionTreePass(const ReductionTreePass &rhs) {
+  traversalModeId = rhs.traversalModeId;
+  testerName = rhs.testerName;
+  testerArgs = rhs.testerArgs;
+}
+
+void ReductionTreePass::runOnOperation() {
+  MLIRContext *context = getOperation()->getContext();
+
+  // Step 1: Try to rewrite op into a simpler form.
+  RewritePatternSet patterns(context);
+  for (auto &collect : *reducerCollector)
+    collect(patterns);
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+  for (Region &region : getOperation()->getRegions())
+    ReduceOp(region, frozenPatterns);
+
+  // Step 2: Try to reduce the nested Regions.
+  for (Region &region : getOperation()->getRegions()) {
+    for (Operation &op : region.getOps()) {
+      if (op.getNumRegions() == 0)
+        continue;
+      OpPassManager nest(op.getName().getStringRef(),
+                         OpPassManager::Nesting::Implicit);
+      nest.addPass(std::make_unique<ReductionTreePass>(*this));
+      if (failed(runPipeline(nest, &op)))
+        return signalPassFailure();
+    }
+  }
+}
+
+void ReductionTreePass::ReduceOp(Region &region,
+                                 const FrozenRewritePatternSet &patterns) {
+  switch (traversalModeId) {
+  case TraversalMode::SinglePath:
+    findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(region,
+                                                                    patterns);
+    break;
+  default:
+    llvm_unreachable("Unsupported mode");
+  }
+}
+
+template <typename IteratorType>
+void ReductionTreePass::findOptimal(Region &region,
+                                    const FrozenRewritePatternSet &patterns) {
+  // A region without blocks (e.g. the body of a function declaration) holds no
+  // operations, and ReductionNode::init() needs a block to locate the cloned
+  // region, so there is nothing to do.
+  if (region.empty())
+    return;
+
+  // The tester judges the whole top-level module, so locate it first.
+  Region *topRegion = &region;
+  while (topRegion->getParentRegion() != nullptr)
+    topRegion = topRegion->getParentRegion();
+  ModuleOp module = cast<ModuleOp>(topRegion->getParentOp());
+
+  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
+  std::vector<ReductionNode::Range> ranges{
+      {0, std::distance(region.op_begin(), region.op_end())}};
+
+  ReductionNode *root = allocator.Allocate();
+  new (root) ReductionNode(nullptr, ranges, allocator);
+  root->init(module, region);
+
+  Tester test(testerName, testerArgs);
+  std::pair<Tester::Interestingness, size_t> initStatus =
+      test.isInteresting(module);
+
+  // Bail out if the input doesn't exhibit the interesting behavior at all.
+  if (initStatus.first != Tester::Interestingness::True)
+    return;
+
+  root->update(initStatus);
+
+  ReductionNode *smallestNode = root;
+
+  IteratorType iter(root);
+
+  while (iter != IteratorType::end()) {
+    ReductionNode &currentNode = *iter;
+    Region &curRegion = currentNode.getRegion();
+
+    applyReductionStep(curRegion, patterns, currentNode.getRanges());
+
+    std::pair<Tester::Interestingness, size_t> result =
+        test.isInteresting(currentNode.getModule());
+    currentNode.update(result);
+
+    if (result.first == Tester::Interestingness::True &&
+        result.second < smallestNode->getSize())
+      smallestNode = &currentNode;
+
+    ++iter;
+  }
+
+  // We have found the best path to reduce `region`; retrieve it. The root
+  // must be part of it: the BFS traversal starts at the root, so its own step
+  // (keep and rewrite every operation) was applied to its clone, and update()
+  // recomputed the ranges its descendants derive from on that rewritten
+  // region. The root is its own parent, so the walk stops there.
+  std::stack<ReductionNode *> trace;
+  ReductionNode *curNode = smallestNode;
+  trace.push(curNode);
+  while (curNode != root) {
+    curNode = curNode->getParent();
+    trace.push(curNode);
+  }
+
+  // Replay the path on the original region, starting from the root.
+  while (!trace.empty()) {
+    ReductionNode *top = trace.top();
+    trace.pop();
+    applyReductionStep(region, patterns, top->getStartRanges());
+  }
+
+  // Query the (possibly external) tester only once for the final result.
+  std::pair<Tester::Interestingness, size_t> finalStatus =
+      test.isInteresting(module);
+  if (finalStatus.first != Tester::Interestingness::True)
+    llvm::report_fatal_error("Reduced module is not interesting");
+  if (finalStatus.second != smallestNode->getSize())
+    llvm::report_fatal_error(
+        "Reduced module doesn't have consistent size with smallestNode");
+}
+
+std::unique_ptr<Pass> mlir::createReductionTreePass() {
+  return std::make_unique<ReductionTreePass>();
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -60,6 +60,7 @@
   MLIRInferTypeOpInterface
   MLIRLinalgTransforms
   MLIRPass
+  MLIRReduce
   MLIRStandard
   MLIRStandardOpsTransforms
   MLIRTransformUtils
diff --git
a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -47,6 +47,7 @@ namespace mlir { namespace test { void registerTestDialect(DialectRegistry &registry); +void registerTestReducerCollector(); } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2083,4 +2083,19 @@ let results = (outs AnyType:$res); } +//===----------------------------------------------------------------------===// +// Test Reducer Patterns +//===----------------------------------------------------------------------===// + +def OpCrashLong : TEST_Op<"op_crash_long"> { + let arguments = (ins I32, I32, I32); + let results = (outs I32); +} + +def OpCrashShort : TEST_Op<"op_crash_short"> { + let results = (outs I32); +} + +def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>; + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" +#include "mlir/Reducer/ReducerRegistry.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -49,6 +50,10 @@ #include "TestPatterns.inc" } // end anonymous namespace +void mlir::test::registerTestReducerCollector() { + registerReducerFunc(populateWithGenerated); +} + //===----------------------------------------------------------------------===// // Canonicalizer Driver.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Reducer/MLIRTestReducer.cpp b/mlir/test/lib/Reducer/MLIRTestReducer.cpp --- a/mlir/test/lib/Reducer/MLIRTestReducer.cpp +++ b/mlir/test/lib/Reducer/MLIRTestReducer.cpp @@ -38,7 +38,7 @@ op.walk([&](Operation *op) { StringRef opName = op->getName().getStringRef(); - if (opName == "test.crashOp") { + if (opName.find("op_crash") != StringRef::npos) { llvm::errs() << "MLIR Reducer Test generated failure: Found " "\"crashOp\" operation\n"; exit(1); diff --git a/mlir/test/mlir-reduce/crashop-reduction.mlir b/mlir/test/mlir-reduce/crashop-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-reduce/crashop-reduction.mlir @@ -0,0 +1,20 @@ +// UNSUPPORTED: system-windows +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short". + +// CHECK-NOT: func @simple1() { +func @simple1() { + return +} + +// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { +func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: %0 = "test.op_crash_short"() : () -> i32 + %0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32 + return +} + +// CHECK-NOT: func @simple5() { +func @simple5() { + return +} diff --git a/mlir/test/mlir-reduce/dce-test.mlir b/mlir/test/mlir-reduce/dce-test.mlir --- a/mlir/test/mlir-reduce/dce-test.mlir +++ b/mlir/test/mlir-reduce/dce-test.mlir @@ -12,6 +12,6 @@ // CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } diff --git a/mlir/test/mlir-reduce/multiple-function.mlir b/mlir/test/mlir-reduce/multiple-function.mlir --- a/mlir/test/mlir-reduce/multiple-function.mlir +++
b/mlir/test/mlir-reduce/multiple-function.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s // This input should be reduced by the pass pipeline so that only // the @simple5 function remains as this is the shortest function // containing the interesting behavior. @@ -16,7 +16,7 @@ // CHECK-LABEL: func @simple3() { func @simple3() { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } @@ -29,7 +29,7 @@ %0 = memref.alloc() : memref<2xf32> br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): - "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } diff --git a/mlir/test/mlir-reduce/simple-test.mlir b/mlir/test/mlir-reduce/simple-test.mlir --- a/mlir/test/mlir-reduce/simple-test.mlir +++ b/mlir/test/mlir-reduce/simple-test.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh' +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh' func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { cond_br %arg0, ^bb1, ^bb2 diff --git a/mlir/test/mlir-reduce/single-function.mlir b/mlir/test/mlir-reduce/single-function.mlir --- a/mlir/test/mlir-reduce/single-function.mlir +++ b/mlir/test/mlir-reduce/single-function.mlir @@ -2,6 +2,6 @@ // RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer func @test() { - "test.crashOp"() : () -> () + "test.op_crash"() : () -> () return } diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -44,9 +44,6 @@ ) add_llvm_tool(mlir-reduce - OptReductionPass.cpp -
ReductionNode.cpp - ReductionTreePass.cpp mlir-reduce.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp deleted file mode 100644 --- a/mlir/tools/mlir-reduce/ReductionTreePass.cpp +++ /dev/null @@ -1,107 +0,0 @@ -//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the Reduction Tree Pass class. It provides a framework for -// the implementation of different reduction passes in the MLIR Reduce tool. It -// allows for custom specification of the variant generation behavior. It -// implements methods that define the different possible traversals of the -// reduction tree.
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Passes.h" - -#include "llvm/Support/Allocator.h" - -using namespace mlir; - -static std::unique_ptr getOpReducer(llvm::StringRef opType) { - if (opType == ModuleOp::getOperationName()) - return std::make_unique>(); - else if (opType == FuncOp::getOperationName()) - return std::make_unique>(); - llvm_unreachable("Now only supports two built-in ops"); -} - -void ReductionTreePass::runOnOperation() { - ModuleOp module = this->getOperation(); - std::unique_ptr reducer = getOpReducer(opReducerName); - std::vector> ranges = { - {0, reducer->getNumTargetOps(module)}}; - - llvm::SpecificBumpPtrAllocator allocator; - - ReductionNode *root = allocator.Allocate(); - new (root) ReductionNode(nullptr, ranges, allocator); - - ModuleOp golden = module; - switch (traversalModeId) { - case TraversalMode::SinglePath: - golden = findOptimal>( - module, std::move(reducer), root); - break; - default: - llvm_unreachable("Unsupported mode"); - } - - if (golden != module) { - module.getBody()->clear(); - module.getBody()->getOperations().splice(module.getBody()->begin(), - golden.getBody()->getOperations()); - golden->destroy(); - } -} - -template -ModuleOp ReductionTreePass::findOptimal(ModuleOp module, - std::unique_ptr reducer, - ReductionNode *root) { - Tester test(testerName, testerArgs); - std::pair initStatus = - test.isInteresting(module); - - if (initStatus.first != Tester::Interestingness::True) { - LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested"); - return module; - } - - root->update(initStatus); - - ReductionNode *smallestNode = root; - ModuleOp golden = module; - - IteratorType iter(root); - - while (iter != IteratorType::end()) { - ModuleOp cloneModule = module.clone(); - - ReductionNode &currentNode = *iter; - reducer->reduce(cloneModule, currentNode.getRanges()); - - std::pair result = -
test.isInteresting(cloneModule); - currentNode.update(result); - - if (result.first == Tester::Interestingness::True && - result.second < smallestNode->getSize()) { - smallestNode = &currentNode; - golden = cloneModule; - } else { - cloneModule->destroy(); - } - - ++iter; - } - - return golden; -} - -std::unique_ptr mlir::createReductionTreePass() { - return std::make_unique(); -} diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp --- a/mlir/tools/mlir-reduce/mlir-reduce.cpp +++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -22,7 +22,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Reducer/OptReductionPass.h" #include "mlir/Reducer/Passes.h" -#include "mlir/Reducer/Passes/OpReducer.h" +#include "mlir/Reducer/ReducerRegistry.h" #include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionTreePass.h" #include "mlir/Reducer/Tester.h" @@ -37,6 +37,7 @@ namespace mlir { namespace test { void registerTestDialect(DialectRegistry &); +void registerTestReducerCollector(); } // namespace test } // namespace mlir @@ -82,6 +83,7 @@ registerAllDialects(registry); #ifdef MLIR_INCLUDE_TESTS mlir::test::registerTestDialect(registry); + mlir::test::registerTestReducerCollector(); #endif mlir::MLIRContext context(registry); @@ -95,6 +97,9 @@ // Reduction pass pipeline. PassManager pm(&context); + // We may generate invalid operations during reduction. Turn off the verifier + // and let the tester determine whether we need a verified output. + pm.enableVerifier(false); if (failed(parser.addToPipeline(pm, errorHandler))) llvm::report_fatal_error("Failed to add pipeline");