diff --git a/mlir/include/mlir/Reducer/OptReductionPass.h b/mlir/include/mlir/Reducer/OptReductionPass.h --- a/mlir/include/mlir/Reducer/OptReductionPass.h +++ b/mlir/include/mlir/Reducer/OptReductionPass.h @@ -17,12 +17,7 @@ #include "PassDetail.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Reducer/ReductionNode.h" -#include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Tester.h" #include "mlir/Transforms/Passes.h" -#include "llvm/Support/Debug.h" namespace mlir { diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td --- a/mlir/include/mlir/Reducer/Passes.td +++ b/mlir/include/mlir/Reducer/Passes.td @@ -24,14 +24,12 @@ ]; } -def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { +def ReductionTree : Pass<"reduction-tree"> { let summary = "A general reduction tree pass for the MLIR Reduce Tool"; let constructor = "mlir::createReductionTreePass()"; let options = [ - Option<"opReducerName", "op-reducer", "std::string", /* default */"", - "The OpReducer to reduce the module">, Option<"traversalModeId", "traversal-mode", "unsigned", /* default */"0", "The graph traversal mode">, ] # CommonReductionPassOptions.options; diff --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h deleted file mode 100644 --- a/mlir/include/mlir/Reducer/Passes/OpReducer.h +++ /dev/null @@ -1,76 +0,0 @@ -//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the OpReducer class.
It defines a variant generator method -// with the purpose of producing different variants by eliminating a -// parameterizable type of operations from the parent module. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H -#define MLIR_REDUCER_PASSES_OPREDUCER_H - -#include - -#include "mlir/Reducer/ReductionNode.h" -#include "mlir/Reducer/Tester.h" - -namespace mlir { - -class OpReducer { -public: - virtual ~OpReducer() = default; - /// According to rangeToKeep, try to reduce the given module. We implicitly - /// number each interesting operation and rangeToKeep indicates that if an - /// operation's number falls into certain range, then we will not try to - /// reduce that operation. - virtual void reduce(ModuleOp module, - ArrayRef rangeToKeep) = 0; - /// Return the number of certain kind of operations that we would like to - /// reduce. This can be used to build a range map to exclude uninterested - /// operations. - virtual int getNumTargetOps(ModuleOp module) const = 0; -}; - -/// Reducer is a helper class to remove potential uninteresting operations from -/// module.
-template -class Reducer : public OpReducer { -public: - ~Reducer() override = default; - - int getNumTargetOps(ModuleOp module) const override { - return std::distance(module.getOps().begin(), - module.getOps().end()); - } - - void reduce(ModuleOp module, - ArrayRef rangeToKeep) override { - std::vector opsToRemove; - size_t keepIndex = 0; - - for (auto op : enumerate(module.getOps())) { - int index = op.index(); - if (keepIndex < rangeToKeep.size() && - index == rangeToKeep[keepIndex].second) - ++keepIndex; - if (keepIndex == rangeToKeep.size() || - index < rangeToKeep[keepIndex].first) - opsToRemove.push_back(op.value()); - } - - for (Operation *o : opsToRemove) { - o->dropAllUses(); - o->erase(); - } - } -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -21,19 +21,24 @@ #include #include "mlir/Reducer/Tester.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ToolOutputFile.h" namespace mlir { +class ModuleOp; +class Region; + /// Defines the traversal method options to be used in the reduction tree /// traversal. enum TraversalMode { SinglePath, Backtrack, MultiPath }; -/// This class defines the ReductionNode which is used to generate variant and -/// keep track of the necessary metadata for the reduction pass. The nodes are -/// linked together in a reduction tree structure which defines the relationship -/// between all the different generated variants. +/// ReductionTreePass will build a reduction tree during module reduction and +/// the ReductionNode represents the vertex of the tree. A ReductionNode records +/// information such as the reduced module, how this node is reduced from the +/// parent node, etc. This information will be used to construct a reduction +/// path to reduce the module.
class ReductionNode { public: template @@ -44,23 +49,42 @@ ReductionNode(ReductionNode *parent, std::vector range, llvm::SpecificBumpPtrAllocator &allocator); - ReductionNode *getParent() const; + ReductionNode *getParent() const { return parent; } + + /// Return the reduced module if the reducer patterns have been applied; + /// otherwise, it is the same module as the one in the parent node. + ModuleOp getModule() { return module; } + + /// Return the region we're reducing. + Region &getRegion() { return *region; } - size_t getSize() const; + /// Return the size of the module. + size_t getSize() const { return size; } /// Returns true if the module exhibits the interesting behavior. - Tester::Interestingness isInteresting() const; + Tester::Interestingness isInteresting() const { return interesting; } - std::vector getRanges() const; + /// Return the ranges describing how this node is reduced from the parent + /// node. + ArrayRef getStartRanges() const { return startRanges; } - std::vector &getVariants(); + /// Return the range set we are using to generate variants. + ArrayRef getRanges() const { return ranges; } + + ArrayRef getVariants() { return variants; } /// Split the ranges and generate new variants. - std::vector generateNewVariants(); + ArrayRef generateNewVariants(); /// Update the interestingness result from tester. void update(std::pair result); + /// Each ReductionNode contains a copy of the module for applying rewrite + /// patterns. In addition, we only apply rewrite patterns in a certain region. + /// In init(), we duplicate the module from the parent node and locate the + /// corresponding region. + void init(ModuleOp parentModule, Region &parentRegion); + private: /// A custom BFS iterator. The difference between /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
@@ -87,8 +111,7 @@ BaseIterator &operator++() { ReductionNode *top = visitQueue.front(); visitQueue.pop(); - std::vector neighbors = getNeighbors(top); - for (ReductionNode *node : neighbors) + for (ReductionNode *node : getNeighbors(top)) visitQueue.push(node); return *this; } @@ -103,7 +126,7 @@ ReductionNode *operator->() const { return visitQueue.front(); } protected: - std::vector getNeighbors(ReductionNode *node) { + ArrayRef getNeighbors(ReductionNode *node) { return static_cast(this)->getNeighbors(node); } @@ -111,21 +134,39 @@ std::queue visitQueue; }; - /// The size of module after applying the range constraints. + /// This is a copy of the module from the parent node. All the reducer + /// patterns will be applied to this instance. + ModuleOp module; + + /// The region of a certain operation we're reducing in the module. + Region *region; + + ReductionNode *parent; + + /// The size of the module after applying the reducer patterns with range + /// constraints. This is only valid once the interestingness has been tested. size_t size; /// This is true if the module has been evaluated and it exhibits the /// interesting behavior. Tester::Interestingness interesting; - ReductionNode *parent; - - /// We will only keep the operation with index falls into the ranges. - /// For example, number each function in a certain module and then we will - /// remove the functions with index outside the ranges and see if the - /// resulting module is still interesting. + /// `ranges` selects a subset of operations in the region. We implicitly + /// number each operation in the region and ReductionTreePass will apply + /// reducer patterns on the operations that fall into the `ranges`. We will + /// generate new ReductionNodes with subsets of `ranges` to see if we can do + /// further reduction. We may split elements of `ranges` so that we can
+ /// have more subset variants from `ranges`. + /// Note that after applying the reducer patterns the number of operations in + /// the region may have changed, so we need to update `ranges` after that. std::vector ranges; + /// `startRanges` means how this ReductionNode is reduced from the parent + /// node. I.e., if we apply the same reducer patterns and `startRanges` + /// selection on the parent region, we will get the same module as this node. + /// This can be used to construct the reduction path from the root. + const std::vector startRanges; + /// This points to the child variants that were created using this node as a /// starting point. std::vector variants; @@ -139,9 +180,9 @@ : public BaseIterator> { friend BaseIterator>; using BaseIterator::BaseIterator; - std::vector getNeighbors(ReductionNode *node); + ArrayRef getNeighbors(ReductionNode *node); }; } // end namespace mlir -#endif +#endif // MLIR_REDUCER_REDUCTIONNODE_H diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h --- a/mlir/include/mlir/Reducer/ReductionTreePass.h +++ b/mlir/include/mlir/Reducer/ReductionTreePass.h @@ -19,15 +19,16 @@ #include -#include "PassDetail.h" -#include "ReductionNode.h" -#include "mlir/Reducer/Passes/OpReducer.h" +#include "mlir/Reducer/PassDetail.h" #include "mlir/Reducer/Tester.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #define DEBUG_TYPE "mlir-reduce" namespace mlir { +class FrozenRewritePatternSet; + /// This class defines the Reduction Tree Pass. It provides a framework to /// to implement a reduction pass using a tree structure to keep track of the /// generated reduced variants. @@ -39,10 +40,12 @@ /// Runs the pass instance in the pass pipeline.
void runOnOperation() override; + void populateReducerPatterns(const FrozenRewritePatternSet &patterns); + private: - template - ModuleOp findOptimal(ModuleOp module, std::unique_ptr reducer, - ReductionNode *node); + void reduceOp(ModuleOp module, Region &region); + + FrozenRewritePatternSet reducerPatterns; }; } // end namespace mlir diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt --- a/mlir/lib/Reducer/CMakeLists.txt +++ b/mlir/lib/Reducer/CMakeLists.txt @@ -1,7 +1,13 @@ add_mlir_library(MLIRReduce + OptReductionPass.cpp + ReductionNode.cpp + ReductionTreePass.cpp Tester.cpp LINK_LIBS PUBLIC MLIRIR + MLIRPass + MLIRRewrite + MLIRTransformUtils ) mlir_check_all_link_libraries(MLIRReduce) diff --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp rename from mlir/tools/mlir-reduce/OptReductionPass.cpp rename to mlir/lib/Reducer/OptReductionPass.cpp --- a/mlir/tools/mlir-reduce/OptReductionPass.cpp +++ b/mlir/lib/Reducer/OptReductionPass.cpp @@ -13,9 +13,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Reducer/OptReductionPass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Reducer/Passes.h" #include "mlir/Reducer/Tester.h" +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "mlir-reduce" diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp rename from mlir/tools/mlir-reduce/ReductionNode.cpp rename to mlir/lib/Reducer/ReductionNode.cpp --- a/mlir/tools/mlir-reduce/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -15,6 +15,7 @@ #include "mlir/Reducer/ReductionNode.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "llvm/ADT/STLExtras.h" #include @@ -23,102 +24,99 @@ using namespace mlir; ReductionNode::ReductionNode( - ReductionNode *parent, std::vector ranges, +
ReductionNode *parentNode, std::vector ranges, llvm::SpecificBumpPtrAllocator &allocator) - : size(std::numeric_limits::max()), - interesting(Tester::Interestingness::Untested), - /// Root node will have the parent pointer point to themselves. - parent(parent == nullptr ? this : parent), ranges(ranges), - allocator(allocator) {} - -/// Returns the size in bytes of the module. -size_t ReductionNode::getSize() const { return size; } - -ReductionNode *ReductionNode::getParent() const { return parent; } - -/// Returns true if the module exhibits the interesting behavior. -Tester::Interestingness ReductionNode::isInteresting() const { - return interesting; + // The root node's parent pointer points to the root node itself. + : parent(parentNode == nullptr ? this : parentNode), + size(std::numeric_limits::max()), + interesting(Tester::Interestingness::Untested), ranges(ranges), + startRanges(ranges), allocator(allocator) { + if (parent != this) + init(parent->getModule(), parent->getRegion()); } -std::vector ReductionNode::getRanges() const { - return ranges; +void ReductionNode::init(ModuleOp parentModule, Region &targetRegion) { + // Use the mapper to find the corresponding region after cloning the module. + BlockAndValueMapping mapper; + module = cast(parentModule->clone(mapper)); + // Use the first block of targetRegion to locate the cloned region. + Block *block = mapper.lookup(&*targetRegion.begin()); + region = block->getParent(); } -std::vector &ReductionNode::getVariants() { return variants; } - -#include - /// If we haven't explored any variants from this node, we will create N /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the /// max element in `ranges` and create 2 new variants for each call.
-std::vector ReductionNode::generateNewVariants() { - std::vector newNodes; +ArrayRef ReductionNode::generateNewVariants() { + int oldNumVariant = getVariants().size(); + + auto createNewNode = [this](std::vector ranges) { + return new (allocator.Allocate()) + ReductionNode(this, std::move(ranges), allocator); + }; // If we haven't created new variant, then we can create varients by removing // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can // produce variants with range {{1, 3}} and {{4, 9}}. - if (variants.size() == 0 && ranges.size() != 1) { - for (const Range &range : ranges) { - std::vector subRanges = ranges; + if (variants.size() == 0 && getRanges().size() != 1) { + for (const Range &range : getRanges()) { + std::vector subRanges = getRanges(); llvm::erase_value(subRanges, range); - ReductionNode *newNode = allocator.Allocate(); - new (newNode) ReductionNode(this, subRanges, allocator); - newNodes.push_back(newNode); - variants.push_back(newNode); + variants.push_back(createNewNode(std::move(subRanges))); } - return newNodes; + return getVariants().drop_front(oldNumVariant); } // At here, we have created the type of variants mentioned above. We would // like to split the max range into 2 to create 2 new variants. Continue on // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The - // result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. + // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. auto maxElement = std::max_element( ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) { - return (lhs.second - lhs.first) > (rhs.second - rhs.first); + return (lhs.second - lhs.first) < (rhs.second - rhs.first); }); - // We can't split range with lenght 1, which means we can't produce new + // The length of the range is at most 1, so we can't split it to create new // variant.
- if (maxElement->second - maxElement->first == 1) + if (maxElement->second - maxElement->first <= 1) return {}; - auto createNewNode = [this](const std::vector &ranges) { - ReductionNode *newNode = allocator.Allocate(); - new (newNode) ReductionNode(this, ranges, allocator); - return newNode; - }; - Range maxRange = *maxElement; - std::vector subRanges = ranges; + std::vector subRanges = getRanges(); auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin()); int half = (maxRange.first + maxRange.second) / 2; *subRangesIter = std::make_pair(maxRange.first, half); - newNodes.push_back(createNewNode(subRanges)); + variants.push_back(createNewNode(subRanges)); *subRangesIter = std::make_pair(half, maxRange.second); - newNodes.push_back(createNewNode(subRanges)); + variants.push_back(createNewNode(std::move(subRanges))); - variants.insert(variants.end(), newNodes.begin(), newNodes.end()); auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second)); it = ranges.insert(it, std::make_pair(maxRange.first, half)); // Remove the range that has been split. ranges.erase(it + 2); - return newNodes; + return getVariants().drop_front(oldNumVariant); } void ReductionNode::update(std::pair result) { std::tie(interesting, size) = result; + // After applying reduction, the number of operations in the region may have + // changed. Non-interesting nodes won't be explored, so it's safe to leave + // their ranges stale. + if (interesting == Tester::Interestingness::True) { + // This module may have been updated. Reset the range. + ranges.clear(); + ranges.push_back({0, std::distance(region->op_begin(), region->op_end())}); + } } -std::vector +ArrayRef ReductionNode::iterator::getNeighbors(ReductionNode *node) { // Single Path: Traverses the smallest successful variant at each level until // no new successful variants can be created at that level.
- llvm::ArrayRef variantsFromParent = + ArrayRef variantsFromParent = node->getParent()->getVariants(); // The parent node created several variants and they may be waiting for @@ -139,7 +137,8 @@ smallest = node; } - if (smallest != nullptr) { + if (smallest != nullptr && + smallest->getSize() < node->getParent()->getSize()) { // We got a smallest one, keep traversing from this node. node = smallest; } else { diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -0,0 +1,204 @@ +//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Reduction Tree Pass class. It provides a framework for +// the implementation of different reduction passes in the MLIR Reduce tool. It +// allows for custom specification of the variant generation behavior. It +// implements methods that define the different possible traversals of the +// reduction tree. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Reducer/ReductionTreePass.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Reducer/Passes.h" +#include "mlir/Reducer/ReductionNode.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ManagedStatic.h" + +#include <queue> +#include <stack> + +using namespace mlir; + +/// We implicitly number each operation in the region and if an operation's +/// number falls into rangeToKeep, we need to keep it and apply the given +/// rewrite patterns on it.
+static void applyPatterns(Region &region, + const FrozenRewritePatternSet &patterns, + ArrayRef<ReductionNode::Range> rangeToKeep, + bool eraseOpNotInRange) { + std::vector<Operation *> opsNotInRange; + std::vector<Operation *> opsInRange; + size_t keepIndex = 0; + for (auto op : enumerate(region.getOps())) { + int index = op.index(); + if (keepIndex < rangeToKeep.size() && + index == rangeToKeep[keepIndex].second) + ++keepIndex; + if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) + opsNotInRange.push_back(&op.value()); + else + opsInRange.push_back(&op.value()); + } + + // A rewrite may erase or replace the op it matched, so patterns must not be + // applied while the region is still being iterated. + for (Operation *op : opsInRange) + (void)applyOpPatternsAndFold(op, patterns); + + if (eraseOpNotInRange) + for (Operation *op : opsNotInRange) { + op->dropAllUses(); + op->erase(); + } +} + +/// We will apply the reducer patterns to the operations in the ranges specified +/// by ReductionNode. Note that we are not able to remove an operation without +/// replacing it with another valid operation. However, the validity of module +/// reduction is based on the Tester provided by the user, so a module that is +/// invalid may still be interesting to the user. Thus we provide an +/// alternative way to remove operations, which is using `eraseOpNotInRange` to +/// erase the operations not in the range specified by ReductionNode. +template <typename IteratorType> +static void findOptimal(ModuleOp module, Region &region, + const FrozenRewritePatternSet &patterns, + const Tester &test, bool eraseOpNotInRange) { + std::pair<Tester::Interestingness, size_t> initStatus = + test.isInteresting(module); + // While exploring the reduction tree, we always branch from an interesting + // node. Thus the root node must be interesting. + if (initStatus.first != Tester::Interestingness::True) + return; + + llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; + + std::vector<ReductionNode::Range> ranges{ + {0, std::distance(region.op_begin(), region.op_end())}}; + + ReductionNode *root = allocator.Allocate(); + new (root) ReductionNode(nullptr, std::move(ranges), allocator); + // Duplicate the module for root node and locate the region in the copy.
+ root->init(module, region); + root->update(initStatus); + + ReductionNode *smallestNode = root; + IteratorType iter(root); + + while (iter != IteratorType::end()) { + ReductionNode &currentNode = *iter; + Region &curRegion = currentNode.getRegion(); + + applyPatterns(curRegion, patterns, currentNode.getRanges(), + eraseOpNotInRange); + currentNode.update(test.isInteresting(currentNode.getModule())); + + if (currentNode.isInteresting() == Tester::Interestingness::True && + currentNode.getSize() < smallestNode->getSize()) + smallestNode = &currentNode; + + ++iter; + } + + // Here we have found an optimal path to reduce the given region. Retrieve the + // path from smallestNode up to and including the root: each node's module is + // cloned from its parent after the parent applied its own patterns. + std::stack<ReductionNode *> trace; + ReductionNode *curNode = smallestNode; + trace.push(curNode); + while (curNode != root) { + curNode = curNode->getParent(); + trace.push(curNode); + } + + while (!trace.empty()) { + ReductionNode *top = trace.top(); + trace.pop(); + applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange); + } + + // Query the tester only once: a query may run the external test script. + std::pair<Tester::Interestingness, size_t> finalStatus = + test.isInteresting(module); + if (finalStatus.first != Tester::Interestingness::True) + llvm::report_fatal_error("Reduced module is not interesting"); + if (finalStatus.second != smallestNode->getSize()) + llvm::report_fatal_error( + "Reduced module doesn't have consistent size with smallestNode"); +} + +template <typename IteratorType> +static void findOptimal(ModuleOp module, Region &region, + const FrozenRewritePatternSet &patterns, + const Tester &test) { + // We separate the reduction process into 2 steps, the first one is to erase + // redundant operations and the second one is to apply the reducer patterns. + + // In the first phase, we don't apply any patterns so that we only select the + // range of operations to keep so that the module stays interesting. + findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, + /*eraseOpNotInRange=*/true); + // In the second phase, we suppose that no operation is redundant, so we try + // to rewrite the operations into a simpler form.
+ findOptimal<IteratorType>(module, region, patterns, test, + /*eraseOpNotInRange=*/false); +} + +//===----------------------------------------------------------------------===// +// ReductionTreePass +//===----------------------------------------------------------------------===// + +void ReductionTreePass::populateReducerPatterns( + const FrozenRewritePatternSet &patterns) { + reducerPatterns = patterns; +} + +void ReductionTreePass::runOnOperation() { + Operation *topOperation = getOperation(); + while (topOperation->getParentOp() != nullptr) + topOperation = topOperation->getParentOp(); + ModuleOp module = cast<ModuleOp>(topOperation); + + std::queue<Operation *> workList; + workList.push(getOperation()); + + do { + Operation *op = workList.front(); + workList.pop(); + + for (Region &region : op->getRegions()) + if (!region.empty()) + reduceOp(module, region); + + for (Region &region : op->getRegions()) + for (Operation &nestedOp : region.getOps()) + if (nestedOp.getNumRegions() != 0) + workList.push(&nestedOp); + } while (!workList.empty()); +} + +void ReductionTreePass::reduceOp(ModuleOp module, Region &region) { + Tester test(testerName, testerArgs); + switch (traversalModeId) { + case TraversalMode::SinglePath: + findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( + module, region, reducerPatterns, test); + break; + default: + llvm::report_fatal_error("Unsupported traversal mode"); + } +} + +std::unique_ptr<Pass> mlir::createReductionTreePass() { + return std::make_unique<ReductionTreePass>(); +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -60,6 +60,7 @@ MLIRInferTypeOpInterface MLIRLinalgTransforms MLIRPass + MLIRReduce MLIRStandard MLIRStandardOpsTransforms MLIRTransformUtils diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -34,6 +34,7 @@ namespace mlir { class DLTIDialect; +class RewritePatternSet; } //
namespace mlir #include "TestOpEnums.h.inc" @@ -47,6 +48,7 @@ namespace mlir { namespace test { void registerTestDialect(DialectRegistry &registry); +void populateTestReducerPatterns(RewritePatternSet &patterns); } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2083,4 +2083,19 @@ let results = (outs AnyType:$res); } +//===----------------------------------------------------------------------===// +// Test Reducer Patterns +//===----------------------------------------------------------------------===// + +def OpCrashLong : TEST_Op<"op_crash_long"> { + let arguments = (ins I32, I32, I32); + let results = (outs I32); +} + +def OpCrashShort : TEST_Op<"op_crash_short"> { + let results = (outs I32); +} + +def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>; + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -49,6 +49,10 @@ #include "TestPatterns.inc" } // end anonymous namespace +void mlir::test::populateTestReducerPatterns(RewritePatternSet &patterns) { + populateWithGenerated(patterns); +} + //===----------------------------------------------------------------------===// // Canonicalizer Driver.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Reducer/MLIRTestReducer.cpp b/mlir/test/lib/Reducer/MLIRTestReducer.cpp --- a/mlir/test/lib/Reducer/MLIRTestReducer.cpp +++ b/mlir/test/lib/Reducer/MLIRTestReducer.cpp @@ -38,7 +38,7 @@ op.walk([&](Operation *op) { StringRef opName = op->getName().getStringRef(); - if (opName == "test.crashOp") { + if (opName.contains("op_crash")) { llvm::errs() << "MLIR Reducer Test generated failure: Found " "\"crashOp\" operation\n"; exit(1); diff --git a/mlir/test/mlir-reduce/crashop-reduction.mlir b/mlir/test/mlir-reduce/crashop-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-reduce/crashop-reduction.mlir @@ -0,0 +1,20 @@ +// UNSUPPORTED: system-windows +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short". + +// CHECK-NOT: func @simple1() { +func @simple1() { + return +} + +// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { +func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: %0 = "test.op_crash_short"() : () -> i32 + %0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32 + return +} + +// CHECK-NOT: func @simple5() { +func @simple5() { + return +} diff --git a/mlir/test/mlir-reduce/dce-test.mlir b/mlir/test/mlir-reduce/dce-test.mlir --- a/mlir/test/mlir-reduce/dce-test.mlir +++ b/mlir/test/mlir-reduce/dce-test.mlir @@ -12,6 +12,6 @@ // CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } diff --git a/mlir/test/mlir-reduce/multiple-function.mlir b/mlir/test/mlir-reduce/multiple-function.mlir --- a/mlir/test/mlir-reduce/multiple-function.mlir +++ b/mlir/test/mlir-reduce/multiple-function.mlir @@
-1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s // This input should be reduced by the pass pipeline so that only // the @simple5 function remains as this is the shortest function // containing the interesting behavior. @@ -16,7 +16,7 @@ // CHECK-LABEL: func @simple3() { func @simple3() { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } @@ -29,7 +29,7 @@ %0 = memref.alloc() : memref<2xf32> br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): - "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } diff --git a/mlir/test/mlir-reduce/simple-test.mlir b/mlir/test/mlir-reduce/simple-test.mlir --- a/mlir/test/mlir-reduce/simple-test.mlir +++ b/mlir/test/mlir-reduce/simple-test.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh' +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh' func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { cond_br %arg0, ^bb1, ^bb2 diff --git a/mlir/test/mlir-reduce/single-function.mlir b/mlir/test/mlir-reduce/single-function.mlir --- a/mlir/test/mlir-reduce/single-function.mlir +++ b/mlir/test/mlir-reduce/single-function.mlir @@ -2,6 +2,6 @@ // RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer func @test() { - "test.crashOp"() : () -> () + "test.op_crash"() : () -> () return } diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -44,9 +44,6 @@ ) add_llvm_tool(mlir-reduce - OptReductionPass.cpp - ReductionNode.cpp - ReductionTreePass.cpp mlir-reduce.cpp
ADDITIONAL_HEADER_DIRS diff --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp deleted file mode 100644 --- a/mlir/tools/mlir-reduce/ReductionTreePass.cpp +++ /dev/null @@ -1,107 +0,0 @@ -//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the Reduction Tree Pass class. It provides a framework for -// the implementation of different reduction passes in the MLIR Reduce tool. It -// allows for custom specification of the variant generation behavior. It -// implements methods that define the different possible traversals of the -// reduction tree. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Passes.h" - -#include "llvm/Support/Allocator.h" - -using namespace mlir; - -static std::unique_ptr getOpReducer(llvm::StringRef opType) { - if (opType == ModuleOp::getOperationName()) - return std::make_unique>(); - else if (opType == FuncOp::getOperationName()) - return std::make_unique>(); - llvm_unreachable("Now only supports two built-in ops"); -} - -void ReductionTreePass::runOnOperation() { - ModuleOp module = this->getOperation(); - std::unique_ptr reducer = getOpReducer(opReducerName); - std::vector> ranges = { - {0, reducer->getNumTargetOps(module)}}; - - llvm::SpecificBumpPtrAllocator allocator; - - ReductionNode *root = allocator.Allocate(); - new (root) ReductionNode(nullptr, ranges, allocator); - - ModuleOp golden = module; - switch (traversalModeId) { - case TraversalMode::SinglePath: - golden = findOptimal>( - module, std::move(reducer), root); - break; - default: 
- llvm_unreachable("Unsupported mode"); - } - - if (golden != module) { - module.getBody()->clear(); - module.getBody()->getOperations().splice(module.getBody()->begin(), - golden.getBody()->getOperations()); - golden->destroy(); - } -} - -template -ModuleOp ReductionTreePass::findOptimal(ModuleOp module, - std::unique_ptr reducer, - ReductionNode *root) { - Tester test(testerName, testerArgs); - std::pair initStatus = - test.isInteresting(module); - - if (initStatus.first != Tester::Interestingness::True) { - LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested"); - return module; - } - - root->update(initStatus); - - ReductionNode *smallestNode = root; - ModuleOp golden = module; - - IteratorType iter(root); - - while (iter != IteratorType::end()) { - ModuleOp cloneModule = module.clone(); - - ReductionNode ¤tNode = *iter; - reducer->reduce(cloneModule, currentNode.getRanges()); - - std::pair result = - test.isInteresting(cloneModule); - currentNode.update(result); - - if (result.first == Tester::Interestingness::True && - result.second < smallestNode->getSize()) { - smallestNode = ¤tNode; - golden = cloneModule; - } else { - cloneModule->destroy(); - } - - ++iter; - } - - return golden; -} - -std::unique_ptr mlir::createReductionTreePass() { - return std::make_unique(); -} diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp --- a/mlir/tools/mlir-reduce/mlir-reduce.cpp +++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -15,20 +15,18 @@ #include +#include "mlir/IR/PatternMatch.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Reducer/OptReductionPass.h" #include "mlir/Reducer/Passes.h" -#include "mlir/Reducer/Passes/OpReducer.h" -#include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Tester.h" +#include 
"mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Passes.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/ToolOutputFile.h" @@ -37,6 +35,7 @@ namespace mlir { namespace test { void registerTestDialect(DialectRegistry &); +void populateTestReducerPatterns(RewritePatternSet &patterns); } // namespace test } // namespace mlir @@ -95,9 +94,24 @@ // Reduction pass pipeline. PassManager pm(&context); + // We may generate invalid operations during reduction. Turn off the verifier + // and let the tester to determine whether we need a verified output. + pm.enableVerifier(false); if (failed(parser.addToPipeline(pm, errorHandler))) llvm::report_fatal_error("Failed to add pipeline"); + RewritePatternSet patterns(&context); +#ifdef MLIR_INCLUDE_TESTS + mlir::test::populateTestReducerPatterns(patterns); +#endif + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + // Only ReductionTreePass will apply reducer patterns. + for (Pass &pass : pm) + if (ReductionTreePass *reductionTreePass = + llvm::dyn_cast_or_null(&pass)) + reductionTreePass->populateReducerPatterns(frozenPatterns); + ModuleOp m = moduleRef.get().clone(); if (failed(pm.run(m)))