diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
--- a/mlir/include/mlir/Config/mlir-config.h.cmake
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -19,4 +19,12 @@
    easier debugging. */
 #cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
+/* If set, greedy pattern application is randomized: ops on the worklist are
+   chosen at random. For testing/debugging purposes only. This feature can be
+   used to ensure that lowering pipelines work correctly regardless of the order
+   in which ops are processed by the GreedyPatternRewriteDriver. The value of
+   this flag is the numeric seed passed to the random number generator. Note
+   that CMake treats 0 as false, so a seed of 0 leaves the feature disabled. */
+#cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}
+
 #endif
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -27,6 +27,10 @@
 #include "llvm/Support/ScopedPrinter.h"
 #include "llvm/Support/raw_ostream.h"
 
+#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+#include <random>
+#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+
 using namespace mlir;
 
 #define DEBUG_TYPE "greedy-rewriter"
@@ -165,7 +169,7 @@
   /// Reverse the worklist.
   void reverse();
 
-private:
+protected:
   /// The worklist of operations.
   std::vector<Operation *> list;
 
@@ -225,6 +229,46 @@
     map[list[i]] = i;
 }
 
+#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+/// A worklist that pops elements at a random position. This worklist is for
+/// testing/debugging purposes only. It can be used to ensure that lowering
+/// pipelines work correctly regardless of the order in which ops are processed
+/// by the GreedyPatternRewriteDriver.
+class RandomizedWorklist : public Worklist {
+public:
+  RandomizedWorklist() : Worklist() {
+    generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
+  }
+
+  /// Pop a random non-empty op from the worklist. Null entries encountered
+  /// along the way are discarded. Removal is O(1): the popped slot is filled
+  /// with the last element, since the order of the list is irrelevant here.
+  Operation *pop() {
+    Operation *op = nullptr;
+    do {
+      assert(!list.empty() && "cannot pop from empty worklist");
+      // Use a plain modulo rather than std::uniform_int_distribution: the
+      // distributions' algorithms are implementation-defined, so the same seed
+      // would yield different op orders with different standard libraries.
+      size_t pos = generator() % list.size();
+      op = list[pos];
+      if (pos + 1 != list.size()) {
+        // Move the last element into the vacated slot and re-index it.
+        list[pos] = list.back();
+        if (list[pos])
+          map[list[pos]] = pos;
+      }
+      list.pop_back();
+      map.erase(op);
+    } while (!op);
+    return op;
+  }
+
+private:
+  std::minstd_rand0 generator;
+};
+#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+
 //===----------------------------------------------------------------------===//
 // GreedyPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -272,7 +316,11 @@
 
   /// The worklist for this transformation keeps track of the operations that
   /// need to be (re)visited.
+#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
+  RandomizedWorklist worklist;
+#else
   Worklist worklist;
+#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;