diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -29,11 +29,95 @@ #define DEBUG_TYPE "greedy-rewriter" +namespace { + +//===----------------------------------------------------------------------===// +// Worklist +//===----------------------------------------------------------------------===// + +/// A LIFO worklist of operations with efficient removal and set semantics. +/// +/// This class maintains a vector of operations and a mapping of operations to +/// positions in the vector, so that operations can be removed efficiently at +/// random. When an operation is removed, it is replaced with nullptr. Such +/// nullptr are skipped when pop'ing elements. +class Worklist { +public: + Worklist() { list.reserve(64); } + + /// Clear the worklist. + void clear() { + list.clear(); + map.clear(); + } + + /// Return whether the worklist is empty. + bool empty() const { + // Skip all nullptr. + for (Operation *op : list) + if (op) + return false; + return true; + } + + /// Push an operation to the end of the worklist, unless the operation is + /// already on the worklist. + void push(Operation *op) { + assert(op && "cannot push nullptr to worklist"); + // Check to see if the worklist already contains this op. + if (map.count(op)) + return; + map[op] = list.size(); + list.push_back(op); + } + + /// Pop an operation from the end of the worklist. Only allowed on + /// non-empty worklists. + Operation *pop() { + assert(!empty() && "cannot pop from empty worklist"); + // Skip and remove all trailing nullptr. + while (!list.back()) + list.pop_back(); + Operation *op = list.back(); + list.pop_back(); + map.erase(op); + // Cleanup: Remove all trailing nullptr. 
+ while (!list.empty() && !list.back()) + list.pop_back(); + return op; + } + + /// Remove an operation from the worklist. + void remove(Operation *op) { + assert(op && "cannot remove nullptr from worklist"); + auto it = map.find(op); + if (it != map.end()) { + assert(list[it->second] == op && "malformed worklist data structure"); + list[it->second] = nullptr; + map.erase(it); + } + } + + /// Reverse the worklist. Placeholders left by `remove` stay unmapped. + void reverse() { + std::reverse(list.begin(), list.end()); + for (size_t i = 0, e = list.size(); i != e; ++i) + if (list[i]) + map[list[i]] = i; + } + +private: + /// The worklist of operations. + std::vector<Operation *> list; + + /// A mapping of operations to positions in `list`. + DenseMap<Operation *, unsigned> map; +}; + //===----------------------------------------------------------------------===// // GreedyPatternRewriteDriver //===----------------------------------------------------------------------===// -namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly /// applies the locally optimal patterns. /// @@ -76,11 +160,8 @@ bool processWorklist(); /// The worklist for this transformation keeps track of the operations that - /// need to be revisited, plus their index in the worklist. This allows us to - /// efficiently remove operations from the worklist when they are erased, even - /// if they aren't the root of a pattern. - std::vector<Operation *> worklist; - DenseMap<Operation *, unsigned> worklistMap; + /// need to be (re)visited. + Worklist worklist; /// Non-pattern based folder for operations. OperationFolder folder; @@ -101,9 +182,6 @@ /// simplifications. void addOperandsToWorklist(ValueRange operands); - /// Pop the next operation from the worklist. - Operation *popFromWorklist(); - /// Notify the driver that the given block was created. void notifyBlockCreated(Block *block) override; @@ -112,9 +190,6 @@ notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override; - /// If the specified operation is in the worklist, remove it. 
- void removeFromWorklist(Operation *op); - #ifndef NDEBUG /// A logger used to emit information during the application process. llvm::ScopedPrinter logger{llvm::dbgs()}; @@ -130,8 +205,6 @@ const GreedyRewriteConfig &config) : PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns) { - worklist.reserve(64); - // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); @@ -163,12 +236,7 @@ while (!worklist.empty() && (numRewrites < config.maxNumRewrites || config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { - auto *op = popFromWorklist(); - - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) - continue; + auto *op = worklist.pop(); LLVM_DEBUG({ logger.getOStream() << "\n"; @@ -265,33 +333,8 @@ void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { if (config.strictMode == GreedyRewriteStrictness::AnyOp || - strictModeFilteredOps.contains(op)) { - // Check to see if the worklist already contains this op. - if (worklistMap.count(op)) - return; - - worklistMap[op] = worklist.size(); - worklist.push_back(op); - } -} - -Operation *GreedyPatternRewriteDriver::popFromWorklist() { - auto *op = worklist.back(); - worklist.pop_back(); - - // This operation is no longer in the worklist, keep worklistMap up to date. 
- if (op) - worklistMap.erase(op); - return op; -} - -void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) { - auto it = worklistMap.find(op); - if (it != worklistMap.end()) { - assert(worklist[it->second] == op && "malformed worklist data structure"); - worklist[it->second] = nullptr; - worklistMap.erase(it); - } + strictModeFilteredOps.contains(op)) + worklist.push(op); } void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) { @@ -345,7 +388,7 @@ addOperandsToWorklist(op->getOperands()); op->walk([this](Operation *operation) { - removeFromWorklist(operation); + worklist.remove(operation); folder.notifyRemoval(operation); }); @@ -450,7 +493,6 @@ break; worklist.clear(); - worklistMap.clear(); if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. @@ -469,10 +511,7 @@ }); // Reverse the list so our pop-back loop processes them in-order. - std::reverse(worklist.begin(), worklist.end()); - // Remember the reverse index. - for (size_t i = 0, e = worklist.size(); i != e; ++i) - worklistMap[worklist[i]] = i; + worklist.reverse(); } ctx->executeAction(