[mlir] GreedyPatternRewriteDriver: factor out a Worklist class

Replaces the driver's `std::vector<Operation *> worklist` and
`DenseMap<Operation *, unsigned> worklistMap` pair with a `Worklist` class
offering clear/empty/push/pop/remove/reverse, and removes the driver's
popFromWorklist/removeFromWorklist helpers. `Worklist::map` holds exactly the
non-null entries of `Worklist::list`, which lets `empty()` run in O(1).

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -31,11 +31,12 @@
 
 #define DEBUG_TYPE "greedy-rewriter"
 
+namespace {
+
 //===----------------------------------------------------------------------===//
 // Debugging Infrastructure
 //===----------------------------------------------------------------------===//
 
-namespace {
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 /// A helper struct that stores finger prints of ops in order to detect broken
 /// RewritePatterns. A rewrite pattern is broken if it modifies IR without
@@ -130,6 +131,104 @@
 };
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
+//===----------------------------------------------------------------------===//
+// Worklist
+//===----------------------------------------------------------------------===//
+
+/// A LIFO worklist of operations with efficient removal and set semantics.
+///
+/// This class maintains a vector of operations and a mapping of operations to
+/// positions in the vector, so that operations can be removed efficiently at
+/// random. When an operation is removed, it is replaced with nullptr. Such
+/// nullptr are skipped when pop'ing elements.
+class Worklist {
+public:
+  Worklist();
+
+  /// Clear the worklist.
+  void clear();
+
+  /// Return whether the worklist is empty.
+  bool empty() const;
+
+  /// Push an operation to the end of the worklist, unless the operation is
+  /// already on the worklist.
+  void push(Operation *op);
+
+  /// Pop an operation from the end of the worklist. Only allowed on
+  /// non-empty worklists.
+  Operation *pop();
+
+  /// Remove an operation from the worklist.
+  void remove(Operation *op);
+
+  /// Reverse the worklist.
+  void reverse();
+
+private:
+  /// The worklist of operations.
+  std::vector<Operation *> list;
+
+  /// A mapping of operations to positions in `list`. Invariant: contains
+  /// exactly the non-null entries of `list`.
+  DenseMap<Operation *, unsigned> map;
+};
+
+Worklist::Worklist() { list.reserve(64); }
+
+void Worklist::clear() {
+  list.clear();
+  map.clear();
+}
+
+bool Worklist::empty() const {
+  // `map` tracks exactly the non-null entries of `list`, so this is O(1)
+  // instead of a linear scan of `list` that skips nullptr.
+  return map.empty();
+}
+
+void Worklist::push(Operation *op) {
+  assert(op && "cannot push nullptr to worklist");
+  // Check to see if the worklist already contains this op.
+  if (map.count(op))
+    return;
+  map[op] = list.size();
+  list.push_back(op);
+}
+
+Operation *Worklist::pop() {
+  assert(!empty() && "cannot pop from empty worklist");
+  // Skip and remove all trailing nullptr.
+  while (!list.back())
+    list.pop_back();
+  Operation *op = list.back();
+  list.pop_back();
+  map.erase(op);
+  // Cleanup: Remove all trailing nullptr.
+  while (!list.empty() && !list.back())
+    list.pop_back();
+  return op;
+}
+
+void Worklist::remove(Operation *op) {
+  assert(op && "cannot remove nullptr from worklist");
+  auto it = map.find(op);
+  if (it != map.end()) {
+    assert(list[it->second] == op && "malformed worklist data structure");
+    list[it->second] = nullptr;
+    map.erase(it);
+  }
+}
+
+void Worklist::reverse() {
+  std::reverse(list.begin(), list.end());
+  for (size_t i = 0, e = list.size(); i != e; ++i) {
+    // Skip nullptr left behind by remove(); they must not enter `map`.
+    if (list[i])
+      map[list[i]] = i;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // GreedyPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -176,11 +275,8 @@
   bool processWorklist();
 
   /// The worklist for this transformation keeps track of the operations that
-  /// need to be revisited, plus their index in the worklist. This allows us to
-  /// efficiently remove operations from the worklist when they are erased, even
-  /// if they aren't the root of a pattern.
-  std::vector<Operation *> worklist;
-  DenseMap<Operation *, unsigned> worklistMap;
+  /// need to be (re)visited.
+  Worklist worklist;
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;
@@ -201,9 +297,6 @@
   /// simplifications.
   void addOperandsToWorklist(ValueRange operands);
 
-  /// Pop the next operation from the worklist.
-  Operation *popFromWorklist();
-
   /// Notify the driver that the given block was created.
   void notifyBlockCreated(Block *block) override;
 
@@ -212,9 +305,6 @@
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
-  /// If the specified operation is in the worklist, remove it.
-  void removeFromWorklist(Operation *op);
-
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -239,8 +329,6 @@
 // clang-format on
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 {
-  worklist.reserve(64);
-
   // Apply a simple cost model based solely on pattern benefit.
   matcher.applyDefaultCostModel();
 
@@ -278,12 +366,7 @@
   while (!worklist.empty() &&
          (numRewrites < config.maxNumRewrites ||
           config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
-    auto *op = popFromWorklist();
-
-    // Nulls get added to the worklist when operations are removed, ignore
-    // them.
-    if (op == nullptr)
-      continue;
+    auto *op = worklist.pop();
 
     LLVM_DEBUG({
       logger.getOStream() << "\n";
@@ -395,33 +478,8 @@
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
   if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
-      strictModeFilteredOps.contains(op)) {
-    // Check to see if the worklist already contains this op.
-    if (worklistMap.count(op))
-      return;
-
-    worklistMap[op] = worklist.size();
-    worklist.push_back(op);
-  }
-}
-
-Operation *GreedyPatternRewriteDriver::popFromWorklist() {
-  auto *op = worklist.back();
-  worklist.pop_back();
-
-  // This operation is no longer in the worklist, keep worklistMap up to date.
-  if (op)
-    worklistMap.erase(op);
-  return op;
-}
-
-void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
-  auto it = worklistMap.find(op);
-  if (it != worklistMap.end()) {
-    assert(worklist[it->second] == op && "malformed worklist data structure");
-    worklist[it->second] = nullptr;
-    worklistMap.erase(it);
-  }
+      strictModeFilteredOps.contains(op))
+    worklist.push(op);
 }
 
 void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
@@ -475,7 +533,7 @@
 
   addOperandsToWorklist(op->getOperands());
   op->walk([this](Operation *operation) {
-    removeFromWorklist(operation);
+    worklist.remove(operation);
     folder.notifyRemoval(operation);
   });
 
@@ -580,7 +638,6 @@
       break;
 
     worklist.clear();
-    worklistMap.clear();
 
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
@@ -599,10 +656,7 @@
       });
 
       // Reverse the list so our pop-back loop processes them in-order.
-      std::reverse(worklist.begin(), worklist.end());
-      // Remember the reverse index.
-      for (size_t i = 0, e = worklist.size(); i != e; ++i)
-        worklistMap[worklist[i]] = i;
+      worklist.reverse();
     }
 
     ctx->executeAction(