diff --git a/mlir/include/mlir/Reducer/OptReductionPass.h b/mlir/include/mlir/Reducer/OptReductionPass.h deleted file mode 100644 --- a/mlir/include/mlir/Reducer/OptReductionPass.h +++ /dev/null @@ -1,41 +0,0 @@ -//===- OptReductionPass.h - Optimization Reduction Pass Wrapper -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to -// run any optimization pass within it and only replaces the output module with -// the transformed version if it is smaller and interesting. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_REDUCER_OPTREDUCTIONPASS_H -#define MLIR_REDUCER_OPTREDUCTIONPASS_H - -#include "PassDetail.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Reducer/ReductionNode.h" -#include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Tester.h" -#include "mlir/Transforms/Passes.h" -#include "llvm/Support/Debug.h" - -namespace mlir { - -class OptReductionPass : public OptReductionBase { -public: - OptReductionPass() = default; - - OptReductionPass(const OptReductionPass &srcPass) = default; - - /// Runs the pass instance in the pass pipeline. 
- void runOnOperation() override; -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/Reducer/Passes.h b/mlir/include/mlir/Reducer/Passes.h --- a/mlir/include/mlir/Reducer/Passes.h +++ b/mlir/include/mlir/Reducer/Passes.h @@ -9,8 +9,6 @@ #define MLIR_REDUCER_PASSES_H #include "mlir/Pass/Pass.h" -#include "mlir/Reducer/OptReductionPass.h" -#include "mlir/Reducer/ReductionTreePass.h" namespace mlir { diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td --- a/mlir/include/mlir/Reducer/Passes.td +++ b/mlir/include/mlir/Reducer/Passes.td @@ -24,14 +24,12 @@ ]; } -def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { +def ReductionTree : Pass<"reduction-tree"> { let summary = "A general reduction tree pass for the MLIR Reduce Tool"; let constructor = "mlir::createReductionTreePass()"; let options = [ - Option<"opReducerName", "op-reducer", "std::string", /* default */"", - "The OpReducer to reduce the module">, Option<"traversalModeId", "traversal-mode", "unsigned", /* default */"0", "The graph traversal mode">, ] # CommonReductionPassOptions.options; diff --git a/mlir/include/mlir/Reducer/Passes/OpReducer.h b/mlir/include/mlir/Reducer/Passes/OpReducer.h deleted file mode 100644 --- a/mlir/include/mlir/Reducer/Passes/OpReducer.h +++ /dev/null @@ -1,76 +0,0 @@ -//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the OpReducer class. It defines a variant generator method -// with the purpose of producing different variants by eliminating a -// parameterizable type of operations from the parent module. 
-// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H -#define MLIR_REDUCER_PASSES_OPREDUCER_H - -#include - -#include "mlir/Reducer/ReductionNode.h" -#include "mlir/Reducer/Tester.h" - -namespace mlir { - -class OpReducer { -public: - virtual ~OpReducer() = default; - /// According to rangeToKeep, try to reduce the given module. We implicitly - /// number each interesting operation and rangeToKeep indicates that if an - /// operation's number falls into certain range, then we will not try to - /// reduce that operation. - virtual void reduce(ModuleOp module, - ArrayRef rangeToKeep) = 0; - /// Return the number of certain kind of operations that we would like to - /// reduce. This can be used to build a range map to exclude uninterested - /// operations. - virtual int getNumTargetOps(ModuleOp module) const = 0; -}; - -/// Reducer is a helper class to remove potential uninteresting operations from -/// module. -template -class Reducer : public OpReducer { -public: - ~Reducer() override = default; - - int getNumTargetOps(ModuleOp module) const override { - return std::distance(module.getOps().begin(), - module.getOps().end()); - } - - void reduce(ModuleOp module, - ArrayRef rangeToKeep) override { - std::vector opsToRemove; - size_t keepIndex = 0; - - for (auto op : enumerate(module.getOps())) { - int index = op.index(); - if (keepIndex < rangeToKeep.size() && - index == rangeToKeep[keepIndex].second) - ++keepIndex; - if (keepIndex == rangeToKeep.size() || - index < rangeToKeep[keepIndex].first) - opsToRemove.push_back(op.value()); - } - - for (Operation *o : opsToRemove) { - o->dropAllUses(); - o->erase(); - } - } -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -21,19 +21,25 @@ #include #include 
"mlir/Reducer/Tester.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ToolOutputFile.h" namespace mlir { +class ModuleOp; +class Region; + /// Defines the traversal method options to be used in the reduction tree /// traversal. enum TraversalMode { SinglePath, Backtrack, MultiPath }; -/// This class defines the ReductionNode which is used to generate variant and -/// keep track of the necessary metadata for the reduction pass. The nodes are -/// linked together in a reduction tree structure which defines the relationship -/// between all the different generated variants. +/// ReductionTreePass will build a reduction tree during module reduction and +/// the ReductionNode represents the vertex of the tree. A ReductionNode records +/// the information such as the reduced module, how this node is reduced from +/// the parent node, etc. This information will be used to construct a reduction +/// path to reduce the certain module. class ReductionNode { public: template @@ -44,23 +50,46 @@ ReductionNode(ReductionNode *parent, std::vector range, llvm::SpecificBumpPtrAllocator &allocator); - ReductionNode *getParent() const; + ReductionNode *getParent() const { return parent; } + + /// If the ReductionNode hasn't been tested the interestingness, it'll be the + /// same module as the one in the parent node. Otherwise, the returned module + /// will have been applied certain reduction strategies. Note that it's not + /// necessary to be an interesting case or a reduced module (has smaller size + /// than parent's). + ModuleOp getModule() const { return module; } + + /// Return the region we're reducing. + Region &getRegion() const { return *region; } - size_t getSize() const; + /// Return the size of the module. + size_t getSize() const { return size; } /// Returns true if the module exhibits the interesting behavior. 
- Tester::Interestingness isInteresting() const; + Tester::Interestingness isInteresting() const { return interesting; } - std::vector getRanges() const; + /// Return the range information of how this node is reduced from the parent + /// node. + ArrayRef getStartRanges() const { return startRanges; } - std::vector &getVariants(); + /// Return the range set we are using to generate variants. + ArrayRef getRanges() const { return ranges; } + + /// Return the generated variants (the child nodes). + ArrayRef getVariants() const { return variants; } /// Split the ranges and generate new variants. - std::vector generateNewVariants(); + ArrayRef generateNewVariants(); /// Update the interestingness result from tester. void update(std::pair result); + /// Each Reduction Node contains a copy of module for applying rewrite + /// patterns. In addition, we only apply rewrite patterns in a certain region. + /// In initialize(), we will duplicate the module from parent node and locate + /// the corresponding region. + LogicalResult initialize(ModuleOp parentModule, Region &parentRegion); + private: /// A custom BFS iterator. The difference between /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. @@ -87,8 +116,7 @@ BaseIterator &operator++() { ReductionNode *top = visitQueue.front(); visitQueue.pop(); - std::vector neighbors = getNeighbors(top); - for (ReductionNode *node : neighbors) + for (ReductionNode *node : getNeighbors(top)) visitQueue.push(node); return *this; } @@ -103,7 +131,7 @@ ReductionNode *operator->() const { return visitQueue.front(); } protected: - std::vector getNeighbors(ReductionNode *node) { + ArrayRef getNeighbors(ReductionNode *node) { return static_cast(this)->getNeighbors(node); } @@ -111,21 +139,42 @@ std::queue visitQueue; }; - /// The size of module after applying the range constraints. + /// This is a copy of module from parent node. All the reducer patterns will + /// be applied to this instance.
+ ModuleOp module; + + /// The region of certain operation we're reducing in the module + Region *region; + + /// The node we are reduced from. It means we will be in variants of parent + /// node. + ReductionNode *parent; + + /// The size of module after applying the reducer patterns with range + /// constraints. This is only valid while the interestingness has been tested. size_t size; /// This is true if the module has been evaluated and it exhibits the /// interesting behavior. Tester::Interestingness interesting; - ReductionNode *parent; - - /// We will only keep the operation with index falls into the ranges. - /// For example, number each function in a certain module and then we will - /// remove the functions with index outside the ranges and see if the - /// resulting module is still interesting. + /// `ranges` represents the selected subset of operations in the region. We + /// implicitly number each operation in the region and ReductionTreePass will + /// apply reducer patterns on the operation falls into the `ranges`. We will + /// generate new ReductionNode with subset of `ranges` to see if we can do + /// further reduction. We may split the element in the `ranges` so that we can + /// have more subset variants from `ranges`. + /// Note that after applying the reducer patterns the number of operation in + /// the region may have changed, we need to update the `ranges` after that. std::vector ranges; + /// `startRanges` records the ranges of operations selected from the parent + /// node to produce this ReductionNode. It can be used to construct the + /// reduction path from the root. I.e., if we apply the same reducer patterns + /// and `startRanges` selection on the parent region, we will get the same + /// module as this node. + const std::vector startRanges; + /// This points to the child variants that were created using this node as a /// starting point.
std::vector variants; @@ -139,9 +188,9 @@ : public BaseIterator> { friend BaseIterator>; using BaseIterator::BaseIterator; - std::vector getNeighbors(ReductionNode *node); + ArrayRef getNeighbors(ReductionNode *node); }; } // end namespace mlir -#endif +#endif // MLIR_REDUCER_REDUCTIONNODE_H diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h @@ -0,0 +1,56 @@ +//===- ReducePatternInterface.h - Collecting Reduce Patterns ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H +#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H + +#include "mlir/IR/DialectInterface.h" + +namespace mlir { + +class RewritePatternSet; + +/// This is used to report the reduction patterns for a Dialect. While using +/// mlir-reduce to reduce a module, we may want to transform certain cases into +/// simpler forms by applying certain rewrite patterns. Implement the +/// `populateReductionPatterns` to report those patterns by adding them to the +/// RewritePatternSet. +/// +/// Example: +/// MyDialectReductionPattern::populateReductionPatterns( +/// RewritePatternSet &patterns) { +/// patterns.add(patterns.getContext()); +/// } +/// +/// For DRR, mlir-tblgen will generate a helper function +/// `populateWithGenerated` which has the same signature therefore you can +/// delegate to the helper function as well. +/// +/// Example: +/// MyDialectReductionPattern::populateReductionPatterns( +/// RewritePatternSet &patterns) { +/// // Include the autogen file somewhere above. 
+/// populateWithGenerated(patterns); +/// } +class DialectReductionPatternInterface + : public DialectInterface::Base { +public: + /// Patterns provided here are intended to transform operations from a complex + /// form to a simpler form, without breaking the semantics of the program + /// being reduced. For example, you may want to replace the + /// tensor with a known rank and type, e.g. tensor<1xi32>, or + /// replacing an operation with a constant. + virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0; + +protected: + DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {} +}; + +} // end namespace mlir + +#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H diff --git a/mlir/include/mlir/Reducer/ReductionTreePass.h b/mlir/include/mlir/Reducer/ReductionTreePass.h deleted file mode 100644 --- a/mlir/include/mlir/Reducer/ReductionTreePass.h +++ /dev/null @@ -1,50 +0,0 @@ -//===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the Reduction Tree Pass class. It provides a framework for -// the implementation of different reduction passes in the MLIR Reduce tool. It -// allows for custom specification of the variant generation behavior. It -// implements methods that define the different possible traversals of the -// reduction tree. 
-// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H -#define MLIR_REDUCER_REDUCTIONTREEPASS_H - -#include - -#include "PassDetail.h" -#include "ReductionNode.h" -#include "mlir/Reducer/Passes/OpReducer.h" -#include "mlir/Reducer/Tester.h" - -#define DEBUG_TYPE "mlir-reduce" - -namespace mlir { - -/// This class defines the Reduction Tree Pass. It provides a framework to -/// to implement a reduction pass using a tree structure to keep track of the -/// generated reduced variants. -class ReductionTreePass : public ReductionTreeBase { -public: - ReductionTreePass() = default; - ReductionTreePass(const ReductionTreePass &pass) = default; - - /// Runs the pass instance in the pass pipeline. - void runOnOperation() override; - -private: - template - ModuleOp findOptimal(ModuleOp module, std::unique_ptr reducer, - ReductionNode *node); -}; - -} // end namespace mlir - -#endif diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt --- a/mlir/lib/Reducer/CMakeLists.txt +++ b/mlir/lib/Reducer/CMakeLists.txt @@ -1,7 +1,13 @@ add_mlir_library(MLIRReduce + OptReductionPass.cpp + ReductionNode.cpp + ReductionTreePass.cpp Tester.cpp LINK_LIBS PUBLIC MLIRIR + MLIRPass + MLIRRewrite + MLIRTransformUtils ) mlir_check_all_link_libraries(MLIRReduce) diff --git a/mlir/tools/mlir-reduce/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp rename from mlir/tools/mlir-reduce/OptReductionPass.cpp rename to mlir/lib/Reducer/OptReductionPass.cpp --- a/mlir/tools/mlir-reduce/OptReductionPass.cpp +++ b/mlir/lib/Reducer/OptReductionPass.cpp @@ -12,15 +12,27 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Reducer/OptReductionPass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" +#include "mlir/Reducer/PassDetail.h" #include "mlir/Reducer/Passes.h" #include "mlir/Reducer/Tester.h" +#include 
"llvm/Support/Debug.h" #define DEBUG_TYPE "mlir-reduce" using namespace mlir; +namespace { + +class OptReductionPass : public OptReductionBase { +public: + /// Runs the pass instance in the pass pipeline. + void runOnOperation() override; +}; + +} // end anonymous namespace + /// Runs the pass instance in the pass pipeline. void OptReductionPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: "); diff --git a/mlir/tools/mlir-reduce/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp rename from mlir/tools/mlir-reduce/ReductionNode.cpp rename to mlir/lib/Reducer/ReductionNode.cpp --- a/mlir/tools/mlir-reduce/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Reducer/ReductionNode.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "llvm/ADT/STLExtras.h" #include @@ -23,102 +24,102 @@ using namespace mlir; ReductionNode::ReductionNode( - ReductionNode *parent, std::vector ranges, + ReductionNode *parentNode, std::vector ranges, llvm::SpecificBumpPtrAllocator &allocator) - : size(std::numeric_limits::max()), - interesting(Tester::Interestingness::Untested), - /// Root node will have the parent pointer point to themselves. - parent(parent == nullptr ? this : parent), ranges(ranges), - allocator(allocator) {} - -/// Returns the size in bytes of the module. -size_t ReductionNode::getSize() const { return size; } - -ReductionNode *ReductionNode::getParent() const { return parent; } - -/// Returns true if the module exhibits the interesting behavior. -Tester::Interestingness ReductionNode::isInteresting() const { - return interesting; + /// Root node will have the parent pointer point to themselves. + : parent(parentNode == nullptr ? 
this : parentNode), + size(std::numeric_limits::max()), + interesting(Tester::Interestingness::Untested), ranges(ranges), + startRanges(ranges), allocator(allocator) { + if (parent != this) + if (failed(initialize(parent->getModule(), parent->getRegion()))) + llvm_unreachable("unexpected initialization failure"); } -std::vector ReductionNode::getRanges() const { - return ranges; +LogicalResult ReductionNode::initialize(ModuleOp parentModule, + Region &targetRegion) { + // Use the mapper help us find the corresponding region after module clone. + BlockAndValueMapping mapper; + module = cast(parentModule->clone(mapper)); + // Use the first block of targetRegion to locate the cloned region. + Block *block = mapper.lookup(&*targetRegion.begin()); + region = block->getParent(); + return success(); } -std::vector &ReductionNode::getVariants() { return variants; } - -#include - /// If we haven't explored any variants from this node, we will create N /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the /// max element in `ranges` and create 2 new variants for each call. -std::vector ReductionNode::generateNewVariants() { - std::vector newNodes; +ArrayRef ReductionNode::generateNewVariants() { + int oldNumVariant = getVariants().size(); + + auto createNewNode = [this](std::vector ranges) { + return new (allocator.Allocate()) + ReductionNode(this, std::move(ranges), allocator); + }; // If we haven't created new variant, then we can create varients by removing // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can // produce variants with range {{1, 3}} and {{4, 9}}. 
- if (variants.size() == 0 && ranges.size() != 1) { - for (const Range &range : ranges) { - std::vector subRanges = ranges; + if (variants.size() == 0 && getRanges().size() > 1) { + for (const Range &range : getRanges()) { + std::vector subRanges = getRanges(); llvm::erase_value(subRanges, range); - ReductionNode *newNode = allocator.Allocate(); - new (newNode) ReductionNode(this, subRanges, allocator); - newNodes.push_back(newNode); - variants.push_back(newNode); + variants.push_back(createNewNode(std::move(subRanges))); } - return newNodes; + return getVariants().drop_front(oldNumVariant); } // At here, we have created the type of variants mentioned above. We would // like to split the max range into 2 to create 2 new variants. Continue on // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The - // result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. + // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. auto maxElement = std::max_element( ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) { return (lhs.second - lhs.first) > (rhs.second - rhs.first); }); - // We can't split range with lenght 1, which means we can't produce new + // The length of range is less than 1, we can't split it to create new // variant. 
- if (maxElement->second - maxElement->first == 1) + if (maxElement->second - maxElement->first <= 1) return {}; - auto createNewNode = [this](const std::vector &ranges) { - ReductionNode *newNode = allocator.Allocate(); - new (newNode) ReductionNode(this, ranges, allocator); - return newNode; - }; - Range maxRange = *maxElement; - std::vector subRanges = ranges; + std::vector subRanges = getRanges(); auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin()); int half = (maxRange.first + maxRange.second) / 2; *subRangesIter = std::make_pair(maxRange.first, half); - newNodes.push_back(createNewNode(subRanges)); + variants.push_back(createNewNode(subRanges)); *subRangesIter = std::make_pair(half, maxRange.second); - newNodes.push_back(createNewNode(subRanges)); + variants.push_back(createNewNode(std::move(subRanges))); - variants.insert(variants.end(), newNodes.begin(), newNodes.end()); auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second)); it = ranges.insert(it, std::make_pair(maxRange.first, half)); // Remove the range that has been split. ranges.erase(it + 2); - return newNodes; + return getVariants().drop_front(oldNumVariant); } void ReductionNode::update(std::pair result) { std::tie(interesting, size) = result; + // After applying reduction, the number of operation in the region may have + // changed. Non-interesting case won't be explored thus it's safe to keep it + // in a stale status. + if (interesting == Tester::Interestingness::True) { + // This module may has been updated. Reset the range. + ranges.clear(); + ranges.push_back({0, std::distance(region->op_begin(), region->op_end())}); + } } -std::vector +ArrayRef ReductionNode::iterator::getNeighbors(ReductionNode *node) { // Single Path: Traverses the smallest successful variant at each level until // no new successful variants can be created at that level. 
- llvm::ArrayRef variantsFromParent = + ArrayRef variantsFromParent = node->getParent()->getVariants(); // The parent node created several variants and they may be waiting for @@ -139,7 +140,8 @@ smallest = node; } - if (smallest != nullptr) { + if (smallest != nullptr && + smallest->getSize() < node->getParent()->getSize()) { // We got a smallest one, keep traversing from this node. node = smallest; } else { diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -0,0 +1,247 @@ +//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Reduction Tree Pass class. It provides a framework for +// the implementation of different reduction passes in the MLIR Reduce tool. It +// allows for custom specification of the variant generation behavior. It +// implements methods that define the different possible traversals of the +// reduction tree. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Reducer/PassDetail.h" +#include "mlir/Reducer/Passes.h" +#include "mlir/Reducer/ReductionNode.h" +#include "mlir/Reducer/ReductionPatternInterface.h" +#include "mlir/Reducer/Tester.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ManagedStatic.h" + +using namespace mlir; + +/// We implicitly number each operation in the region and if an operation's +/// number falls into rangeToKeep, we need to keep it and apply the given +/// rewrite patterns on it. +static void applyPatterns(Region &region, + const FrozenRewritePatternSet &patterns, + ArrayRef rangeToKeep, + bool eraseOpNotInRange) { + std::vector opsNotInRange; + std::vector opsInRange; + size_t keepIndex = 0; + for (auto op : enumerate(region.getOps())) { + int index = op.index(); + if (keepIndex < rangeToKeep.size() && + index == rangeToKeep[keepIndex].second) + ++keepIndex; + if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) + opsNotInRange.push_back(&op.value()); + else + opsInRange.push_back(&op.value()); + } + + // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern + // matching in above iteration. Besides, erase op not-in-range may end up in + // invalid module, so `applyOpPatternsAndFold` should come before that + // transform. + for (Operation *op : opsInRange) + // `applyOpPatternsAndFold` returns whether the op is converged. Omit it + // because we don't have expectation this reduction will be success or not.
+ (void)applyOpPatternsAndFold(op, patterns); + + if (eraseOpNotInRange) + for (Operation *op : opsNotInRange) { + op->dropAllUses(); + op->erase(); + } +} + +/// We will apply the reducer patterns to the operations in the ranges specified +/// by ReductionNode. Note that we are not able to remove an operation without +/// replacing it with another valid operation. However, the validity of module +/// reduction is based on the Tester provided by the user and that means certain +/// invalid modules may still be of interest to the user. Thus we provide an +/// alternative way to remove operations, which is using `eraseOpNotInRange` to +/// erase the operations not in the range specified by ReductionNode. +template +static void findOptimal(ModuleOp module, Region &region, + const FrozenRewritePatternSet &patterns, + const Tester &test, bool eraseOpNotInRange) { + std::pair initStatus = + test.isInteresting(module); + // While exploring the reduction tree, we always branch from an interesting + // node. Thus the root node must be interesting. + if (initStatus.first != Tester::Interestingness::True) + return; + + llvm::SpecificBumpPtrAllocator allocator; + + std::vector ranges{ + {0, std::distance(region.op_begin(), region.op_end())}}; + + ReductionNode *root = allocator.Allocate(); + new (root) ReductionNode(nullptr, std::move(ranges), allocator); + // Duplicate the module for root node and locate the region in the copy.
+ if (failed(root->initialize(module, region))) + llvm_unreachable("unexpected initialization failure"); + root->update(initStatus); + + ReductionNode *smallestNode = root; + IteratorType iter(root); + + while (iter != IteratorType::end()) { + ReductionNode &currentNode = *iter; + Region &curRegion = currentNode.getRegion(); + + applyPatterns(curRegion, patterns, currentNode.getRanges(), + eraseOpNotInRange); + currentNode.update(test.isInteresting(currentNode.getModule())); + + if (currentNode.isInteresting() == Tester::Interestingness::True && + currentNode.getSize() < smallestNode->getSize()) + smallestNode = &currentNode; + + ++iter; + } + + // At here, we have found an optimal path to reduce the given region. Retrieve + // the path and apply the reducer to it. + SmallVector trace; + ReductionNode *curNode = smallestNode; + trace.push_back(curNode); + while (curNode != root) { + curNode = curNode->getParent(); + trace.push_back(curNode); + } + + // Reduce the region through the optimal path. + while (!trace.empty()) { + ReductionNode *top = trace.pop_back_val(); + applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange); + } + + if (test.isInteresting(module).first != Tester::Interestingness::True) + llvm::report_fatal_error("Reduced module is not interesting"); + if (test.isInteresting(module).second != smallestNode->getSize()) + llvm::report_fatal_error( + "Reduced module doesn't have consistent size with smallestNode"); +} + +template +static void findOptimal(ModuleOp module, Region &region, + const FrozenRewritePatternSet &patterns, + const Tester &test) { + // We separate the reduction process into 2 steps, the first one is to erase + // redundant operations and the second one is to apply the reducer patterns. + + // In the first phase, we don't apply any patterns so that we only select the + // range of operations to keep so that the module stays interesting.
+ findOptimal(module, region, /*patterns=*/{}, test, + /*eraseOpNotInRange=*/true); + // In the second phase, we suppose that no operation is redundant, so we try + // to rewrite the operation into simpler form. + findOptimal(module, region, patterns, test, + /*eraseOpNotInRange=*/false); +} + +namespace { + +//===----------------------------------------------------------------------===// +// Reduction Pattern Interface Collection +//===----------------------------------------------------------------------===// + +class ReductionPatternInterfaceCollection + : public DialectInterfaceCollection { +public: + using Base::Base; + + // Collect the reduce patterns defined by each dialect. + void populateReductionPatterns(RewritePatternSet &pattern) const { + for (const DialectReductionPatternInterface &interface : *this) + interface.populateReductionPatterns(pattern); + } +}; + +//===----------------------------------------------------------------------===// +// ReductionTreePass +//===----------------------------------------------------------------------===// + +/// This class defines the Reduction Tree Pass. It provides a framework to +/// to implement a reduction pass using a tree structure to keep track of the +/// generated reduced variants. +class ReductionTreePass : public ReductionTreeBase { +public: + ReductionTreePass() = default; + ReductionTreePass(const ReductionTreePass &pass) = default; + + LogicalResult initialize(MLIRContext *context) override; + + /// Runs the pass instance in the pass pipeline. 
+ void runOnOperation() override; + +private: + void reduceOp(ModuleOp module, Region &region); + + FrozenRewritePatternSet reducerPatterns; +}; + +} // end anonymous namespace + +LogicalResult ReductionTreePass::initialize(MLIRContext *context) { + RewritePatternSet patterns(context); + ReductionPatternInterfaceCollection reducePatternCollection(context); + reducePatternCollection.populateReductionPatterns(patterns); + reducerPatterns = std::move(patterns); + return success(); +} + +void ReductionTreePass::runOnOperation() { + Operation *topOperation = getOperation(); + while (topOperation->getParentOp() != nullptr) + topOperation = topOperation->getParentOp(); + ModuleOp module = cast(topOperation); + + SmallVector workList; + workList.push_back(getOperation()); + + do { + Operation *op = workList.pop_back_val(); + + for (Region &region : op->getRegions()) + if (!region.empty()) + reduceOp(module, region); + + for (Region &region : op->getRegions()) + for (Operation &op : region.getOps()) + if (op.getNumRegions() != 0) + workList.push_back(&op); + } while (!workList.empty()); +} + +void ReductionTreePass::reduceOp(ModuleOp module, Region &region) { + Tester test(testerName, testerArgs); + switch (traversalModeId) { + case TraversalMode::SinglePath: + findOptimal>( + module, region, reducerPatterns, test); + break; + default: + llvm_unreachable("Unsupported mode"); + } +} + +std::unique_ptr mlir::createReductionTreePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Reducer/Tester.cpp b/mlir/lib/Reducer/Tester.cpp --- a/mlir/lib/Reducer/Tester.cpp +++ b/mlir/lib/Reducer/Tester.cpp @@ -15,7 +15,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Reducer/Tester.h" - +#include "mlir/IR/Verifier.h" #include "llvm/Support/ToolOutputFile.h" using namespace mlir; @@ -25,6 +25,12 @@ std::pair Tester::isInteresting(ModuleOp module) const { + // The reduced module should always be valid, or we may end up retaining the + //
error message by an invalid case. Besides, an invalid module may not be + // able to print properly. + if (failed(verify(module))) + return std::make_pair(Interestingness::False, /*size=*/0); + SmallString<128> filepath; int fd; @@ -50,7 +56,6 @@ /// true if the interesting behavior is present in the test case or false /// otherwise. Tester::Interestingness Tester::isInteresting(StringRef testCase) const { - std::vector testerArgs; testerArgs.push_back(testCase); diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -60,6 +60,7 @@ MLIRInferTypeOpInterface MLIRLinalgTransforms MLIRPass + MLIRReduce MLIRStandard MLIRStandardOpsTransforms MLIRTransformUtils diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -34,6 +34,7 @@ namespace mlir { class DLTIDialect; +class RewritePatternSet; } // namespace mlir #include "TestOpEnums.h.inc" @@ -47,6 +48,7 @@ namespace mlir { namespace test { void registerTestDialect(DialectRegistry &registry); +void populateTestReductionPatterns(RewritePatternSet &patterns); } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -8,6 +8,7 @@ #include "TestDialect.h" #include "TestAttributes.h" +#include "TestInterfaces.h" #include "TestTypes.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -16,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h" @@ -170,6 +172,18 @@ return builder.create(conversionLoc, resultType, input); } }; + +struct TestReductionPatternInterface : public DialectReductionPatternInterface { +public: + TestReductionPatternInterface(Dialect *dialect) + : DialectReductionPatternInterface(dialect) {} + + virtual void + populateReductionPatterns(RewritePatternSet &patterns) const final { + populateTestReductionPatterns(patterns); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -207,7 +221,7 @@ #include "TestOps.cpp.inc" >(); addInterfaces(); + TestInlinerInterface, TestReductionPatternInterface>(); allowUnknownOperations(); // Instantiate our fallback op interface that we'll use on specific diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2113,4 +2113,19 @@ let results = (outs AnyType:$res); } +//===----------------------------------------------------------------------===// +// Test Reducer Patterns +//===----------------------------------------------------------------------===// + +def OpCrashLong : TEST_Op<"op_crash_long"> { + let arguments = (ins I32, I32, I32); + let results = (outs I32); +} + +def OpCrashShort : TEST_Op<"op_crash_short"> { + let results = (outs I32); +} + +def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>; + #endif // TEST_OPS diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -58,6 +58,14 @@ #include "TestPatterns.inc" } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Test Reduce Pattern Interface +//===----------------------------------------------------------------------===// + +void 
mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) { + populateWithGenerated(patterns); +} + //===----------------------------------------------------------------------===// // Canonicalizer Driver. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Reducer/MLIRTestReducer.cpp b/mlir/test/lib/Reducer/MLIRTestReducer.cpp --- a/mlir/test/lib/Reducer/MLIRTestReducer.cpp +++ b/mlir/test/lib/Reducer/MLIRTestReducer.cpp @@ -38,7 +38,7 @@ op.walk([&](Operation *op) { StringRef opName = op->getName().getStringRef(); - if (opName == "test.crashOp") { + if (opName.contains("op_crash")) { llvm::errs() << "MLIR Reducer Test generated failure: Found " "\"crashOp\" operation\n"; exit(1); diff --git a/mlir/test/mlir-reduce/crashop-reduction.mlir b/mlir/test/mlir-reduce/crashop-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-reduce/crashop-reduction.mlir @@ -0,0 +1,20 @@ +// UNSUPPORTED: system-windows +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short". 
+ +// CHECK-NOT: func @simple1() { +func @simple1() { + return +} + +// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { +func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK-LABEL: %0 = "test.op_crash_short"() : () -> i32 + %0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32 + return +} + +// CHECK-NOT: func @simple5() { +func @simple5() { + return +} diff --git a/mlir/test/mlir-reduce/dce-test.mlir b/mlir/test/mlir-reduce/dce-test.mlir --- a/mlir/test/mlir-reduce/dce-test.mlir +++ b/mlir/test/mlir-reduce/dce-test.mlir @@ -12,6 +12,6 @@ // CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } diff --git a/mlir/test/mlir-reduce/multiple-function.mlir b/mlir/test/mlir-reduce/multiple-function.mlir --- a/mlir/test/mlir-reduce/multiple-function.mlir +++ b/mlir/test/mlir-reduce/multiple-function.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s // This input should be reduced by the pass pipeline so that only // the @simple5 function remains as this is the shortest function // containing the interesting behavior. 
@@ -16,7 +16,7 @@ // CHECK-LABEL: func @simple3() { func @simple3() { - "test.crashOp" () : () -> () + "test.op_crash" () : () -> () return } @@ -29,7 +29,7 @@ %0 = memref.alloc() : memref<2xf32> br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): - "test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } diff --git a/mlir/test/mlir-reduce/simple-test.mlir b/mlir/test/mlir-reduce/simple-test.mlir --- a/mlir/test/mlir-reduce/simple-test.mlir +++ b/mlir/test/mlir-reduce/simple-test.mlir @@ -1,5 +1,5 @@ // UNSUPPORTED: system-windows -// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh' +// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh' func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { cond_br %arg0, ^bb1, ^bb2 diff --git a/mlir/test/mlir-reduce/single-function.mlir b/mlir/test/mlir-reduce/single-function.mlir --- a/mlir/test/mlir-reduce/single-function.mlir +++ b/mlir/test/mlir-reduce/single-function.mlir @@ -2,6 +2,6 @@ // RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer func @test() { - "test.crashOp"() : () -> () + "test.op_crash"() : () -> () return } diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -43,9 +43,6 @@ ) add_llvm_tool(mlir-reduce - OptReductionPass.cpp - ReductionNode.cpp - ReductionTreePass.cpp mlir-reduce.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/tools/mlir-reduce/ReductionTreePass.cpp b/mlir/tools/mlir-reduce/ReductionTreePass.cpp deleted file mode 100644 --- a/mlir/tools/mlir-reduce/ReductionTreePass.cpp +++ /dev/null @@ -1,107 +0,0 @@ -//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the Reduction Tree Pass class. It provides a framework for -// the implementation of different reduction passes in the MLIR Reduce tool. It -// allows for custom specification of the variant generation behavior. It -// implements methods that define the different possible traversals of the -// reduction tree. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Passes.h" - -#include "llvm/Support/Allocator.h" - -using namespace mlir; - -static std::unique_ptr getOpReducer(llvm::StringRef opType) { - if (opType == ModuleOp::getOperationName()) - return std::make_unique>(); - else if (opType == FuncOp::getOperationName()) - return std::make_unique>(); - llvm_unreachable("Now only supports two built-in ops"); -} - -void ReductionTreePass::runOnOperation() { - ModuleOp module = this->getOperation(); - std::unique_ptr reducer = getOpReducer(opReducerName); - std::vector> ranges = { - {0, reducer->getNumTargetOps(module)}}; - - llvm::SpecificBumpPtrAllocator allocator; - - ReductionNode *root = allocator.Allocate(); - new (root) ReductionNode(nullptr, ranges, allocator); - - ModuleOp golden = module; - switch (traversalModeId) { - case TraversalMode::SinglePath: - golden = findOptimal>( - module, std::move(reducer), root); - break; - default: - llvm_unreachable("Unsupported mode"); - } - - if (golden != module) { - module.getBody()->clear(); - module.getBody()->getOperations().splice(module.getBody()->begin(), - golden.getBody()->getOperations()); - golden->destroy(); - } -} - -template -ModuleOp ReductionTreePass::findOptimal(ModuleOp module, - std::unique_ptr reducer, - ReductionNode *root) { - Tester test(testerName, testerArgs); - 
std::pair initStatus = - test.isInteresting(module); - - if (initStatus.first != Tester::Interestingness::True) { - LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested"); - return module; - } - - root->update(initStatus); - - ReductionNode *smallestNode = root; - ModuleOp golden = module; - - IteratorType iter(root); - - while (iter != IteratorType::end()) { - ModuleOp cloneModule = module.clone(); - - ReductionNode &currentNode = *iter; - reducer->reduce(cloneModule, currentNode.getRanges()); - - std::pair result = - test.isInteresting(cloneModule); - currentNode.update(result); - - if (result.first == Tester::Interestingness::True && - result.second < smallestNode->getSize()) { - smallestNode = &currentNode; - golden = cloneModule; - } else { - cloneModule->destroy(); - } - - ++iter; - } - - return golden; -} - -std::unique_ptr mlir::createReductionTreePass() { - return std::make_unique(); -} diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp --- a/mlir/tools/mlir-reduce/mlir-reduce.cpp +++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp @@ -13,22 +13,14 @@ // //===----------------------------------------------------------------------===// -#include - #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Reducer/OptReductionPass.h" #include "mlir/Reducer/Passes.h" -#include "mlir/Reducer/Passes/OpReducer.h" -#include "mlir/Reducer/ReductionNode.h" -#include "mlir/Reducer/ReductionTreePass.h" -#include "mlir/Reducer/Tester.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Passes.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/ToolOutputFile.h"