diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <random>
+
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -89,9 +91,9 @@
   PatternApplicator matcher;
 
   /// The worklist for this transformation keeps track of the operations that
-  /// need to be revisited, plus their index in the worklist. This allows us to
-  /// efficiently remove operations from the worklist when they are erased, even
-  /// if they aren't the root of a pattern.
+  /// need to be revisited, plus their index in the worklist (unless fuzzing is
+  /// enabled). This allows us to efficiently remove operations from the
+  /// worklist when they are erased, even if they aren't the root of a pattern.
   std::vector<Operation *> worklist;
   DenseMap<Operation *, unsigned> worklistMap;
 
@@ -103,6 +105,12 @@
   GreedyRewriteConfig config;
 
 private:
+  /// An optional fuzzer that randomizes the order in which ops are popped from
+  /// the worklist. It is only engaged when MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+  /// is defined at build time. While engaged, `worklistMap` is not maintained
+  /// and worklist lookups/removals are linear scans, so it is a testing aid.
+  std::optional<std::minstd_rand0> fuzzer;
+
   /// Only ops within this scope are simplified. This is set at the beginning
   /// of `simplify()` to the current scope the rewriter operates on.
   DenseSet<Region *> scope;
@@ -118,6 +126,12 @@
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
     : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+  // Debugging aid: randomize the worklist order. The seed is fixed at build
+  // time so that any failure this exposes can be reproduced.
+  fuzzer.emplace(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
+#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -190,8 +204,10 @@
       // Reverse the list so our pop-back loop processes them in-order.
       std::reverse(worklist.begin(), worklist.end());
       // Remember the reverse index.
-      for (size_t i = 0, e = worklist.size(); i != e; ++i)
-        worklistMap[worklist[i]] = i;
+      if (!fuzzer) {
+        for (size_t i = 0, e = worklist.size(); i != e; ++i)
+          worklistMap[worklist[i]] = i;
+      }
     }
 
     // These are scratch vectors used in the folding loop below.
@@ -334,29 +350,53 @@
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
   // Check to see if the worklist already contains this op.
-  if (worklistMap.count(op))
+  if (!fuzzer) {
+    if (worklistMap.count(op))
+      return;
+
+    worklistMap[op] = worklist.size();
+  } else if (llvm::find(worklist, op) != worklist.end()) {
+    // Already queued. (`worklistMap` is not maintained while fuzzing.)
     return;
+  }
 
-  worklistMap[op] = worklist.size();
   worklist.push_back(op);
 }
 
 Operation *GreedyPatternRewriteDriver::popFromWorklist() {
-  auto *op = worklist.back();
-  worklist.pop_back();
+  Operation *op = nullptr;
+  if (fuzzer) {
+    // Pick a random entry. The raw engine output is used (instead of a
+    // std::uniform_int_distribution) because the engine's sequence is fully
+    // specified by the standard, keeping the order reproducible across
+    // standard library implementations for a given seed.
+    size_t pos = (*fuzzer)() % worklist.size();
+    op = worklist[pos];
+    worklist.erase(worklist.begin() + pos);
+  } else {
+    op = worklist.back();
+    worklist.pop_back();
+    // This operation is no longer in the worklist, keep worklistMap up to date.
+    if (op)
+      worklistMap.erase(op);
+  }
 
-  // This operation is no longer in the worklist, keep worklistMap up to date.
-  if (op)
-    worklistMap.erase(op);
   return op;
 }
 
 void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
-  auto it = worklistMap.find(op);
-  if (it != worklistMap.end()) {
-    assert(worklist[it->second] == op && "malformed worklist data structure");
-    worklist[it->second] = nullptr;
-    worklistMap.erase(it);
+  if (fuzzer) {
+    // No index map in this mode: find the op and erase it outright.
+    auto it = llvm::find(worklist, op);
+    if (it != worklist.end())
+      worklist.erase(it);
+  } else {
+    auto it = worklistMap.find(op);
+    if (it != worklistMap.end()) {
+      assert(worklist[it->second] == op && "malformed worklist data structure");
+      worklist[it->second] = nullptr;
+      worklistMap.erase(it);
+    }
   }
 }
 