diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td --- a/mlir/include/mlir/Reducer/Passes.td +++ b/mlir/include/mlir/Reducer/Passes.td @@ -24,14 +24,12 @@ ]; } -def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { +def ReductionTree : Pass<"reduction-tree"> { let summary = "A general reduction tree pass for the MLIR Reduce Tool"; let constructor = "mlir::createReductionTreePass()"; let options = [ - Option<"opReducerName", "op-reducer", "std::string", /* default */"", - "The OpReducer to reduce the module">, Option<"traversalModeId", "traversal-mode", "unsigned", /* default */"0", "The graph traversal mode">, ] # CommonReductionPassOptions.options; diff --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h deleted file mode 100644 --- a/mlir/include/mlir/Reducer/Passes/OpReducer.h +++ /dev/null @@ -1,76 +0,0 @@ -//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the OpReducer class. It defines a variant generator method -// with the purpose of producing different variants by eliminating a -// parameterizable type of operations from the parent module. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H -#define MLIR_REDUCER_PASSES_OPREDUCER_H - -#include - -#include "mlir/Reducer/ReductionNode.h" -#include "mlir/Reducer/Tester.h" - -namespace mlir { - -class OpReducer { -public: - virtual ~OpReducer() = default; - /// According to rangeToKeep, try to reduce the given module. 
We implicitly - /// number each interesting operation and rangeToKeep indicates that if an - /// operation's number falls into certain range, then we will not try to - /// reduce that operation. - virtual void reduce(ModuleOp module, - ArrayRef rangeToKeep) = 0; - /// Return the number of certain kind of operations that we would like to - /// reduce. This can be used to build a range map to exclude uninterested - /// operations. - virtual int getNumTargetOps(ModuleOp module) const = 0; -}; - -/// Reducer is a helper class to remove potential uninteresting operations from -/// module. -template -class Reducer : public OpReducer { -public: - ~Reducer() override = default; - - int getNumTargetOps(ModuleOp module) const override { - return std::distance(module.getOps().begin(), - module.getOps().end()); - } - - void reduce(ModuleOp module, - ArrayRef rangeToKeep) override { - std::vector opsToRemove; - size_t keepIndex = 0; - - for (auto op : enumerate(module.getOps())) { - int index = op.index(); - if (keepIndex < rangeToKeep.size() && - index == rangeToKeep[keepIndex].second) - ++keepIndex; - if (keepIndex == rangeToKeep.size() || - index < rangeToKeep[keepIndex].first) - opsToRemove.push_back(op.value()); - } - - for (Operation *o : opsToRemove) { - o->dropAllUses(); - o->erase(); - } - } -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/Reducer/ReducerRegistry.h b/mlir/include/mlir/Reducer/ReducerRegistry.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/ReducerRegistry.h @@ -0,0 +1,27 @@ +//===- ReducerRegistry.h - Reducer Callback Registration --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_REDUCERREGISTRY_H +#define MLIR_REDUCER_REDUCERREGISTRY_H + +#include + +namespace mlir { + +/// Reducer is a kind of RewritePattern, which means we can use DRR to write the +/// reducer pattern. mlir-tblgen will generate a helper function +/// populateWithGenerated() to register all the patterns from certain file. +/// registerReducerFunc is used to collect all these helper functions so that we +/// are able to get all the user defined reducer patterns. +class RewritePatternSet; +using ReducerCollectorFunction = std::function; +void registerReducerFunc(const ReducerCollectorFunction &function); + +} // namespace mlir + +#endif // MLIR_REDUCER_REDUCERREGISTRY_H diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -21,19 +21,22 @@ #include #include "mlir/Reducer/Tester.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ToolOutputFile.h" namespace mlir { +class ModuleOp; +class Region; + /// Defines the traversal method options to be used in the reduction tree /// traversal. enum TraversalMode { SinglePath, Backtrack, MultiPath }; -/// This class defines the ReductionNode which is used to generate variant and -/// keep track of the necessary metadata for the reduction pass. The nodes are -/// linked together in a reduction tree structure which defines the relationship -/// between all the different generated variants. +/// A ReductionNode records a state of module reduction process and it generates +/// new nodes based on TraversalMode. All the nodes form a tree structure which +/// we can use it to retrieve a path to reduce a certain module. 
class ReductionNode { public: template @@ -46,22 +49,34 @@ ReductionNode *getParent() const; + ModuleOp getModule(); + + Region &getRegion(); + size_t getSize() const; /// Returns true if the module exhibits the interesting behavior. Tester::Interestingness isInteresting() const; - std::vector getRanges() const; + llvm::ArrayRef getStartRanges() const; + + llvm::ArrayRef getRanges() const; - std::vector &getVariants(); + llvm::ArrayRef getVariants(); /// Split the ranges and generate new variants. - std::vector generateNewVariants(); + llvm::ArrayRef generateNewVariants(); /// Update the interestingness result from tester. void update(std::pair result); private: + /// Each Reduction Node contains a copy of module for applying rewrite + /// patterns. In addition, we only apply rewrite patterns in a certain region. + /// In init(), we will duplicate the module from parent node and locate the + /// corresponding region. + void init(ModuleOp parentModule, Region &parentRegion); + /// A custom BFS iterator. The difference between /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. /// We may explore more neighbors at certain node if we didn't find interested @@ -87,8 +102,7 @@ BaseIterator &operator++() { ReductionNode *top = visitQueue.front(); visitQueue.pop(); - std::vector neighbors = getNeighbors(top); - for (ReductionNode *node : neighbors) + for (ReductionNode *node : getNeighbors(top)) visitQueue.push(node); return *this; } @@ -103,7 +117,7 @@ ReductionNode *operator->() const { return visitQueue.front(); } protected: - std::vector getNeighbors(ReductionNode *node) { + llvm::ArrayRef getNeighbors(ReductionNode *node) { return static_cast(this)->getNeighbors(node); } @@ -111,26 +125,48 @@ std::queue visitQueue; }; - /// The size of module after applying the range constraints. + /// ReductionTreePass will apply the reducer patterns on this instance. 
This + /// is a copy of top-level module, which means it's independent between + /// different ReductionNodes. + ModuleOp module; + + /// The region of certain operation in the module we're reducing + Region *region; + + ReductionNode *parent; + + /// The size of module after applying the reducer patterns with range + /// constraints. size_t size; /// This is true if the module has been evaluated and it exhibits the /// interesting behavior. Tester::Interestingness interesting; - ReductionNode *parent; - - /// We will only keep the operation with index falls into the ranges. - /// For example, number each function in a certain module and then we will - /// remove the functions with index outside the ranges and see if the - /// resulting module is still interesting. + /// `ranges` selects a subset of operations in the region. We implicitly number + /// each operation in the region and ReductionTreePass will apply reducer + /// patterns on the operations that fall into the `ranges`. We will generate new + /// ReductionNode with a subset of `ranges` to see if we can do further + /// reduction. We may split the element in the `ranges` so that we can have + /// more subset variants from `ranges`. + /// Note that after applying the reducer patterns the number of operations in + /// the region may have changed, so we need to update the `ranges` after that. std::vector ranges; + /// `startRanges` means how this ReductionNode is reduced from the parent + /// node. I.e., if we apply the same reducer patterns and `startRanges` + /// selection on the parent region, we will get the same module as this node. + /// This can be used to construct the reduction path from root. + const std::vector startRanges; + /// This points to the child variants that were created using this node as a /// starting point. std::vector variants; llvm::SpecificBumpPtrAllocator &allocator; + + // Grant access to 'init'.
+ friend class ReductionTreePass; }; // Specialized iterator for SinglePath traversal @@ -139,9 +175,9 @@ : public BaseIterator> { friend BaseIterator>; using BaseIterator::BaseIterator; - std::vector getNeighbors(ReductionNode *node); + llvm::ArrayRef getNeighbors(ReductionNode *node); }; } // end namespace mlir -#endif +#endif // MLIR_REDUCER_REDUCTIONNODE_H diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h --- a/mlir/include/mlir/Reducer/ReductionTreePass.h +++ b/mlir/include/mlir/Reducer/ReductionTreePass.h @@ -19,15 +19,16 @@ #include -#include "PassDetail.h" -#include "ReductionNode.h" -#include "mlir/Reducer/Passes/OpReducer.h" +#include "mlir/Reducer/PassDetail.h" #include "mlir/Reducer/Tester.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #define DEBUG_TYPE "mlir-reduce" namespace mlir { +class FrozenRewritePatternSet; + /// This class defines the Reduction Tree Pass. It provides a framework to /// to implement a reduction pass using a tree structure to keep track of the /// generated reduced variants. @@ -36,13 +37,32 @@ ReductionTreePass() = default; ReductionTreePass(const ReductionTreePass &pass) = default; + /// Initialize the reducer list. + LogicalResult initialize(MLIRContext *context) override; + /// Runs the pass instance in the pass pipeline. void runOnOperation() override; private: + void reduceOp(ModuleOp module, Region &region); + template - ModuleOp findOptimal(ModuleOp module, std::unique_ptr reducer, - ReductionNode *node); + void findOptimal(ModuleOp module, Region &region); + + /// We will apply the `rewritePatterns` to the operations in the ranges + /// specified by ReductionNode. Note that we are not able to remove an + /// operation without replacing it with another valid operation. However, + /// the validity of module reduction is based on the Tester provided by the + /// user and that means certain invalid modules may still be of interest to the user.
+ /// Thus we provide an alternative way to remove operations, which is using + /// `eraseOpNotInRange` to erase the operations not in the range specified by + /// ReductionNode. + template + void findOptimal(ModuleOp module, Region &region, + const FrozenRewritePatternSet &patterns, + bool eraseOpNotInRange = false); + + FrozenRewritePatternSet rewritePatterns; }; } // end namespace mlir diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt --- a/mlir/lib/Reducer/CMakeLists.txt +++ b/mlir/lib/Reducer/CMakeLists.txt @@ -1,7 +1,13 @@ add_mlir_library(MLIRReduce + OptReductionPass.cpp + ReductionNode.cpp + ReductionTreePass.cpp Tester.cpp LINK_LIBS PUBLIC MLIRIR + MLIRPass + MLIRRewrite + MLIRTransformUtils ) mlir_check_all_link_libraries(MLIRReduce) diff --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp rename from mlir/tools/mlir-reduce/OptReductionPass.cpp rename to mlir/lib/Reducer/OptReductionPass.cpp diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp rename from mlir/tools/mlir-reduce/ReductionNode.cpp rename to mlir/lib/Reducer/ReductionNode.cpp --- a/mlir/tools/mlir-reduce/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Reducer/ReductionNode.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "llvm/ADT/STLExtras.h" #include @@ -23,13 +24,29 @@ using namespace mlir; ReductionNode::ReductionNode( - ReductionNode *parent, std::vector ranges, + ReductionNode *parentNode, std::vector ranges, llvm::SpecificBumpPtrAllocator &allocator) - : size(std::numeric_limits::max()), - interesting(Tester::Interestingness::Untested), - /// Root node will have the parent pointer point to themselves. - parent(parent == nullptr ? this : parent), ranges(ranges), - allocator(allocator) {} + /// Root node will have the parent pointer point to themselves.
+ : parent(parentNode == nullptr ? this : parentNode), + size(std::numeric_limits::max()), + interesting(Tester::Interestingness::Untested), ranges(ranges), + startRanges(ranges), allocator(allocator) { + if (parent != this) + init(parent->getModule(), parent->getRegion()); +} + +void ReductionNode::init(ModuleOp parentModule, Region &targetRegion) { + // Use the mapper help us find the corresponding region after module clone. + BlockAndValueMapping mapper; + module = cast(parentModule.getOperation()->clone(mapper)); + // Use the first block of targetRegion to locate the cloned region. + Block *block = mapper.lookup(&*targetRegion.begin()); + region = block->getParent(); +} + +ModuleOp ReductionNode::getModule() { return module; } + +Region &ReductionNode::getRegion() { return *region; } /// Returns the size in bytes of the module. size_t ReductionNode::getSize() const { return size; } @@ -41,80 +58,88 @@ return interesting; } -std::vector ReductionNode::getRanges() const { - return ranges; +llvm::ArrayRef ReductionNode::getStartRanges() const { + return startRanges; } -std::vector &ReductionNode::getVariants() { return variants; } +llvm::ArrayRef ReductionNode::getRanges() const { + return ranges; +} -#include +llvm::ArrayRef ReductionNode::getVariants() { + return variants; +} /// If we haven't explored any variants from this node, we will create N /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the /// max element in `ranges` and create 2 new variants for each call. -std::vector ReductionNode::generateNewVariants() { - std::vector newNodes; +llvm::ArrayRef ReductionNode::generateNewVariants() { + int oldNumVariant = getVariants().size(); + + auto createNewNode = [this](llvm::ArrayRef ranges) { + return new (allocator.Allocate()) ReductionNode(this, ranges, allocator); + }; // If we haven't created new variant, then we can create varients by removing // each of them respectively. 
For example, given {{1, 3}, {4, 9}}, we can // produce variants with range {{1, 3}} and {{4, 9}}. - if (variants.size() == 0 && ranges.size() != 1) { - for (const Range &range : ranges) { - std::vector subRanges = ranges; + if (variants.size() == 0 && getRanges().size() != 1) { + for (const Range &range : getRanges()) { + std::vector subRanges = getRanges(); llvm::erase_value(subRanges, range); - ReductionNode *newNode = allocator.Allocate(); - new (newNode) ReductionNode(this, subRanges, allocator); - newNodes.push_back(newNode); - variants.push_back(newNode); + variants.push_back(createNewNode(subRanges)); } - return newNodes; + return llvm::ArrayRef( + getVariants().begin() + oldNumVariant, getVariants().end()); } // At here, we have created the type of variants mentioned above. We would // like to split the max range into 2 to create 2 new variants. Continue on // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The - // result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. + // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. auto maxElement = std::max_element( ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) { return (lhs.second - lhs.first) > (rhs.second - rhs.first); }); - // We can't split range with lenght 1, which means we can't produce new + // The length of range is less than 1, we can't split it to create new // variant. 
- if (maxElement->second - maxElement->first == 1) + if (maxElement->second - maxElement->first <= 1) return {}; - auto createNewNode = [this](const std::vector &ranges) { - ReductionNode *newNode = allocator.Allocate(); - new (newNode) ReductionNode(this, ranges, allocator); - return newNode; - }; - Range maxRange = *maxElement; - std::vector subRanges = ranges; + std::vector subRanges = getRanges(); auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin()); int half = (maxRange.first + maxRange.second) / 2; *subRangesIter = std::make_pair(maxRange.first, half); - newNodes.push_back(createNewNode(subRanges)); + variants.push_back(createNewNode(subRanges)); *subRangesIter = std::make_pair(half, maxRange.second); - newNodes.push_back(createNewNode(subRanges)); + variants.push_back(createNewNode(subRanges)); - variants.insert(variants.end(), newNodes.begin(), newNodes.end()); auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second)); it = ranges.insert(it, std::make_pair(maxRange.first, half)); // Remove the range that has been split. ranges.erase(it + 2); - return newNodes; + return llvm::ArrayRef(getVariants().begin() + oldNumVariant, + getVariants().end()); } void ReductionNode::update(std::pair result) { std::tie(interesting, size) = result; + // After applying reduction, the number of operation in the region may have + // changed. Non-interesting case won't be explored thus it's safe to keep it + // in a stale status. + if (interesting == Tester::Interestingness::True) { + // This module may has been updated. Reset the range. + ranges.clear(); + ranges.push_back({0, std::distance(region->op_begin(), region->op_end())}); + } } -std::vector +llvm::ArrayRef ReductionNode::iterator::getNeighbors(ReductionNode *node) { // Single Path: Traverses the smallest successful variant at each level until // no new successful variants can be created at that level. 
@@ -139,7 +164,8 @@ smallest = node; } - if (smallest != nullptr) { + if (smallest != nullptr && + smallest->getSize() < node->getParent()->getSize()) { // We got a smallest one, keep traversing from this node. node = smallest; } else { diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -0,0 +1,191 @@ +//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Reduction Tree Pass class. It provides a framework for +// the implementation of different reduction passes in the MLIR Reduce tool. It +// allows for custom specification of the variant generation behavior. It +// implements methods that define the different possible traversals of the +// reduction tree. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Reducer/ReductionTreePass.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Reducer/Passes.h" +#include "mlir/Reducer/ReducerRegistry.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ManagedStatic.h" + +#include +#include + +using namespace mlir; + +static llvm::ManagedStatic> + reducerCollector; + +void mlir::registerReducerFunc(const ReducerCollectorFunction &function) { + reducerCollector->push_back(function); +} + +LogicalResult ReductionTreePass::initialize(MLIRContext *context) { + RewritePatternSet patterns(context); + for (auto &collect : *reducerCollector) + collect(patterns); + rewritePatterns = std::move(patterns); + + return success(); +} + +void ReductionTreePass::runOnOperation() { + Operation *topOperation = getOperation(); + while (topOperation->getParentOp() != nullptr) { + topOperation = topOperation->getParentOp(); + } + ModuleOp module = cast(topOperation); + + std::queue work_list; + work_list.push(getOperation()); + + do { + Operation *op = work_list.front(); + work_list.pop(); + + for (Region &region : op->getRegions()) + if (!region.empty()) + reduceOp(module, region); + + for (Region &region : op->getRegions()) + for (Operation &op : region.getOps()) + if (op.getNumRegions() != 0) + work_list.push(&op); + } while (!work_list.empty()); +} + +void ReductionTreePass::reduceOp(ModuleOp module, Region &region) { + switch (traversalModeId) { + case TraversalMode::SinglePath: + findOptimal>(module, + region); + break; + default: + llvm_unreachable("Unsupported mode"); + } +} + +/// We implicitly number each operation in the region and if an operation's +/// number falls into rangeToKeep, we need to keep it and apply the given +/// rewrite patterns on it.
+static void applyPatterns(Region &region, + const FrozenRewritePatternSet &patterns, + ArrayRef rangeToKeep, + bool eraseOpNotInRange) { + std::vector OpsNotInRange; + size_t keepIndex = 0; + for (auto op : enumerate(region.getOps())) { + int index = op.index(); + if (keepIndex < rangeToKeep.size() && + index == rangeToKeep[keepIndex].second) + ++keepIndex; + if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) + OpsNotInRange.push_back(&op.value()); + else + (void)applyOpPatternsAndFold(&op.value(), patterns); + } + + if (eraseOpNotInRange) + for (Operation *op : OpsNotInRange) { + op->dropAllUses(); + op->erase(); + } +} + +template +void ReductionTreePass::findOptimal(ModuleOp module, Region &region) { + // We separate the reduction process into 2 steps, the first one is to erase + // redundant operations and the second one is to apply the reducer patterns. + + // In the first phase, we don't apply any patterns so that we only select the + // range of operations to keep so that the module stays interesting. + findOptimal(module, region, /*patterns*/ {}, + /*eraseOpNotInRange=*/true); + // In the second phase, we suppose that no operation is redundant, so we try + // to rewrite the operation into a simpler form. + findOptimal(module, region, rewritePatterns); +} + +template +void ReductionTreePass::findOptimal(ModuleOp module, Region &region, + const FrozenRewritePatternSet &patterns, + bool eraseOpNotInRange) { + Tester test(testerName, testerArgs); + std::pair initStatus = + test.isInteresting(module); + // While exploring the reduction tree, we always branch from an interesting + // node. Thus the root node must be interesting.
+ if (initStatus.first != Tester::Interestingness::True) + return; + + llvm::SpecificBumpPtrAllocator allocator; + + std::vector ranges{ + {0, std::distance(region.op_begin(), region.op_end())}}; + + ReductionNode *root = allocator.Allocate(); + new (root) ReductionNode(nullptr, ranges, allocator); + // Duplicate the module for root node and locate the region in the copy. + root->init(module, region); + root->update(initStatus); + + ReductionNode *smallestNode = root; + IteratorType iter(root); + + while (iter != IteratorType::end()) { + ReductionNode &currentNode = *iter; + Region &curRegion = currentNode.getRegion(); + + applyPatterns(curRegion, patterns, currentNode.getRanges(), + eraseOpNotInRange); + currentNode.update(test.isInteresting(currentNode.getModule())); + + if (currentNode.isInteresting() == Tester::Interestingness::True && + currentNode.getSize() < smallestNode->getSize()) + smallestNode = &currentNode; + + ++iter; + } + + // At this point, we have found an optimal path to reduce the given region. + // Retrieve the path and apply the reducer to it.
+ std::stack trace; + ReductionNode *curNode = smallestNode; + do { + trace.push(curNode); + curNode = curNode->getParent(); + } while (curNode != root); + + while (!trace.empty()) { + ReductionNode *top = trace.top(); + trace.pop(); + applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange); + } + + if (test.isInteresting(module).first != Tester::Interestingness::True) + llvm::report_fatal_error("Reduced module is not interesting"); + if (test.isInteresting(module).second != smallestNode->getSize()) + llvm::report_fatal_error( + "Reduced module doesn't have consistent size with smallestNode"); +} + +std::unique_ptr mlir::createReductionTreePass() { + return std::make_unique(); +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -60,6 +60,7 @@ MLIRInferTypeOpInterface MLIRLinalgTransforms MLIRPass + MLIRReduce MLIRStandard MLIRStandardOpsTransforms MLIRTransformUtils diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -47,6 +47,7 @@ namespace mlir { namespace test { void registerTestDialect(DialectRegistry ®istry); +void registerTestReducerCollector(); } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2083,4 +2083,19 @@ let results = (outs AnyType:$res); } +//===----------------------------------------------------------------------===// +// Test Reducer Patterns +//===----------------------------------------------------------------------===// + +def OpCrashLong : TEST_Op<"op_crash_long"> { + let arguments = (ins I32, I32, I32); + let results = (outs I32); +} + +def OpCrashShort : 
TEST_Op<"op_crash_short"> { + let results = (outs I32); +} + +def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>; + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" +#include "mlir/Reducer/ReducerRegistry.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -49,6 +50,10 @@ #include "TestPatterns.inc" } // end anonymous namespace +void mlir::test::registerTestReducerCollector() { + registerReducerFunc(populateWithGenerated); +} + //===----------------------------------------------------------------------===// // Canonicalizer Driver. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Reducer/MLIRTestReducer.cpp b/mlir/test/lib/Reducer/MLIRTestReducer.cpp --- a/mlir/test/lib/Reducer/MLIRTestReducer.cpp +++ b/mlir/test/lib/Reducer/MLIRTestReducer.cpp @@ -38,7 +38,7 @@ op.walk([&](Operation *op) { StringRef opName = op->getName().getStringRef(); - if (opName == "test.crashOp") { + if (opName.contains("op_crash")) { llvm::errs() << "MLIR Reducer Test generated failure: Found " "\"crashOp\" operation\n"; exit(1); diff --git a/mlir/test/mlir-reduce/crashop-reduction.mlir b/mlir/test/mlir-reduce/crashop-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-reduce/crashop-reduction.mlir @@ -0,0 +1,20 @@ +// UNSUPPORTED: system-windows +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short". 
+ +// CHECK-NOT: func @simple1() { +func @simple1() { + return +} + +// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { +func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK-LABEL: %0 = "test.op_crash_short"() : () -> i32 + %0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32 + return +} + +// CHECK-NOT: func @simple5() { +func @simple5() { + return +} diff --git a/mlir/test/mlir-reduce/dce-test.mlir b/mlir/test/mlir-reduce/dce-test.mlir --- a/mlir/test/mlir-reduce/dce-test.mlir +++ b/mlir/test/mlir-reduce/dce-test.mlir @@ -12,6 +12,6 @@ // CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } diff --git a/mlir/test/mlir-reduce/multiple-function.mlir b/mlir/test/mlir-reduce/multiple-function.mlir --- a/mlir/test/mlir-reduce/multiple-function.mlir +++ b/mlir/test/mlir-reduce/multiple-function.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s // This input should be reduced by the pass pipeline so that only // the @simple5 function remains as this is the shortest function // containing the interesting behavior. 
@@ -16,7 +16,7 @@ // CHECK-LABEL: func @simple3() { func @simple3() { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } @@ -29,7 +29,7 @@ %0 = memref.alloc() : memref<2xf32> br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): - "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } diff --git a/mlir/test/mlir-reduce/simple-test.mlir b/mlir/test/mlir-reduce/simple-test.mlir --- a/mlir/test/mlir-reduce/simple-test.mlir +++ b/mlir/test/mlir-reduce/simple-test.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh' +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh' func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { cond_br %arg0, ^bb1, ^bb2 diff --git a/mlir/test/mlir-reduce/single-function.mlir b/mlir/test/mlir-reduce/single-function.mlir --- a/mlir/test/mlir-reduce/single-function.mlir +++ b/mlir/test/mlir-reduce/single-function.mlir @@ -2,6 +2,6 @@ // RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer func @test() { - "test.crashOp"() : () -> () + "test.op_crash"() : () -> () return } diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -44,9 +44,6 @@ ) add_llvm_tool(mlir-reduce - OptReductionPass.cpp - ReductionNode.cpp - ReductionTreePass.cpp mlir-reduce.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp deleted file mode 100644 --- a/mlir/tools/mlir-reduce/ReductionTreePass.cpp +++ /dev/null @@ -1,107 +0,0 @@ -//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the Reduction Tree Pass class. It provides a framework for -// the implementation of different reduction passes in the MLIR Reduce tool. It -// allows for custom specification of the variant generation behavior. It -// implements methods that define the different possible traversals of the -// reduction tree. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Passes.h" - -#include "llvm/Support/Allocator.h" - -using namespace mlir; - -static std::unique_ptr getOpReducer(llvm::StringRef opType) { - if (opType == ModuleOp::getOperationName()) - return std::make_unique>(); - else if (opType == FuncOp::getOperationName()) - return std::make_unique>(); - llvm_unreachable("Now only supports two built-in ops"); -} - -void ReductionTreePass::runOnOperation() { - ModuleOp module = this->getOperation(); - std::unique_ptr reducer = getOpReducer(opReducerName); - std::vector> ranges = { - {0, reducer->getNumTargetOps(module)}}; - - llvm::SpecificBumpPtrAllocator allocator; - - ReductionNode *root = allocator.Allocate(); - new (root) ReductionNode(nullptr, ranges, allocator); - - ModuleOp golden = module; - switch (traversalModeId) { - case TraversalMode::SinglePath: - golden = findOptimal>( - module, std::move(reducer), root); - break; - default: - llvm_unreachable("Unsupported mode"); - } - - if (golden != module) { - module.getBody()->clear(); - module.getBody()->getOperations().splice(module.getBody()->begin(), - golden.getBody()->getOperations()); - golden->destroy(); - } -} - -template -ModuleOp ReductionTreePass::findOptimal(ModuleOp module, - std::unique_ptr reducer, - ReductionNode *root) { - Tester test(testerName, testerArgs); - 
std::pair initStatus = - test.isInteresting(module); - - if (initStatus.first != Tester::Interestingness::True) { - LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested"); - return module; - } - - root->update(initStatus); - - ReductionNode *smallestNode = root; - ModuleOp golden = module; - - IteratorType iter(root); - - while (iter != IteratorType::end()) { - ModuleOp cloneModule = module.clone(); - - ReductionNode &currentNode = *iter; - reducer->reduce(cloneModule, currentNode.getRanges()); - - std::pair result = - test.isInteresting(cloneModule); - currentNode.update(result); - - if (result.first == Tester::Interestingness::True && - result.second < smallestNode->getSize()) { - smallestNode = &currentNode; - golden = cloneModule; - } else { - cloneModule->destroy(); - } - - ++iter; - } - - return golden; -} - -std::unique_ptr mlir::createReductionTreePass() { - return std::make_unique(); -} diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp --- a/mlir/tools/mlir-reduce/mlir-reduce.cpp +++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -22,7 +22,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Reducer/OptReductionPass.h" #include "mlir/Reducer/Passes.h" -#include "mlir/Reducer/Passes/OpReducer.h" +#include "mlir/Reducer/ReducerRegistry.h" #include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionTreePass.h" #include "mlir/Reducer/Tester.h" @@ -37,6 +37,7 @@ namespace mlir { namespace test { void registerTestDialect(DialectRegistry &); +void registerTestReducerCollector(); } // namespace test } // namespace mlir @@ -82,6 +83,7 @@ registerAllDialects(registry); #ifdef MLIR_INCLUDE_TESTS mlir::test::registerTestDialect(registry); + mlir::test::registerTestReducerCollector(); #endif mlir::MLIRContext context(registry); @@ -95,6 +97,9 @@ // Reduction pass pipeline. PassManager pm(&context); + // We may generate invalid operations during reduction.
Turn off the verifier + // and let the tester determine whether we need a verified output. + pm.enableVerifier(false); if (failed(parser.addToPipeline(pm, errorHandler))) llvm::report_fatal_error("Failed to add pipeline");