GreedyPatternRewriteDriver: move the rewrite scope into GreedyRewriteConfig

The `Region` scope that used to be a separate constructor argument of
GreedyPatternRewriteDriver is now the `GreedyRewriteConfig::scope` field.

* applyPatternsAndFoldGreedily: if no scope is set, the simplified region
  itself is used as the scope.
* applyOpPatternsAndFold: takes a `GreedyRewriteConfig` right after
  `strictMode` (the trailing `Region *scope` parameter is removed); if no
  scope is set, the region returned by `findCommonAncestor(ops)` is used.
  Callers that pass `changed`/`allErased` now pass `GreedyRewriteConfig()`
  before them.
* MultiOpPatternRewriteDriver receives the caller's config instead of a
  default-constructed one; the non-convergence debug message reports the
  configured `maxNumRewrites`.
* The worklist scope check compares the `Region *` values directly
  (`config.scope == region`), not the address of the field.

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -37,15 +37,21 @@
   /// generally more efficient in compile time. When set to false, its initial
   /// traversal of the region tree is bottom up on each block, which may match
   /// larger patterns when given an ambiguous pattern set.
+  ///
+  /// Note: Only applicable when simplifying entire regions.
   bool useTopDownTraversal = false;
 
-  // Perform control flow optimizations to the region tree after applying all
-  // patterns.
+  /// Perform control flow optimizations to the region tree after applying all
+  /// patterns.
+  ///
+  /// Note: Only applicable when simplifying entire regions.
   bool enableRegionSimplification = true;
 
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
   /// disable this iteration limit.
+  ///
+  /// Note: Only applicable when simplifying entire regions.
   int64_t maxIterations = 10;
 
   /// This specifies the maximum number of rewrites within an iteration. Use
@@ -53,6 +59,11 @@
   int64_t maxNumRewrites = kNoLimit;
 
   static constexpr int64_t kNoLimit = -1;
+
+  /// Only ops within the scope are added to the worklist. If no scope is
+  /// specified, the closest enclosing region around the initial list of ops
+  /// is used as a scope.
+  Region *scope = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
@@ -117,12 +128,12 @@
 /// Returns success if the iterative process converged and no more patterns can
 /// be matched. `changed` is set to true if the IR was modified at all.
 /// `allOpsErased` is set to true if all ops in `ops` were erased.
-LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
-                                     const FrozenRewritePatternSet &patterns,
-                                     GreedyRewriteStrictness strictMode,
-                                     bool *changed = nullptr,
-                                     bool *allErased = nullptr,
-                                     Region *scope = nullptr);
+LogicalResult
+applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+                       const FrozenRewritePatternSet &patterns,
+                       GreedyRewriteStrictness strictMode,
+                       GreedyRewriteConfig config = GreedyRewriteConfig(),
+                       bool *changed = nullptr, bool *allErased = nullptr);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -324,6 +324,7 @@
   bool erased;
   (void)applyOpPatternsAndFold(res.getOperation(), std::move(patterns),
                                GreedyRewriteStrictness::ExistingOps,
+                               GreedyRewriteConfig(),
                                /*changed=*/nullptr, &erased);
 
   if (!erased && !prologue)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -417,6 +417,7 @@
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   (void)applyOpPatternsAndFold(ifOp.getOperation(), frozenPatterns,
                                GreedyRewriteStrictness::ExistingOps,
+                               GreedyRewriteConfig(),
                                /*changed=*/nullptr, &erased);
   if (erased) {
     if (folded)
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,8 +39,7 @@
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
-                                      const GreedyRewriteConfig &config,
-                                      const Region &scope);
+                                      const GreedyRewriteConfig &config);
 
   /// Simplify the operations within the given regions.
   bool simplify(MutableArrayRef<Region> regions) &&;
@@ -103,9 +102,6 @@
   /// Configuration information for how to simplify.
   const GreedyRewriteConfig config;
 
-  /// Only ops within this scope are simplified.
-  const Region &scope;
-
 private:
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
@@ -116,9 +112,9 @@
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config, const Region &scope)
-    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
-      scope(scope) {
+    const GreedyRewriteConfig &config)
+    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+  assert(config.scope && "scope is not specified");
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -317,7 +313,7 @@
   SmallVector<Operation *, 8> ancestors;
   ancestors.push_back(op);
   while (Region *region = op->getParentRegion()) {
-    if (&scope == region) {
+    if (config.scope == region) {
       // All gathered ops are in fact ancestors.
       for (Operation *op : ancestors)
         addSingleOpToWorklist(op);
@@ -438,9 +434,12 @@
   assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
          "patterns can only be applied to operations IsolatedFromAbove");
 
+  // Set scope if not specified.
+  if (!config.scope)
+    config.scope = &region;
+
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
-                                    region);
+  GreedyPatternRewriteDriver driver(region.getContext(), patterns, config);
   bool converged = std::move(driver).simplify(region);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -464,9 +463,9 @@
 public:
   explicit MultiOpPatternRewriteDriver(
       MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-      const Region &scope, GreedyRewriteStrictness strictMode,
+      GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
-      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
+      : GreedyPatternRewriteDriver(ctx, patterns, config),
         strictMode(strictMode), survivingOps(survivingOps) {}
 
   /// Performs the specified rewrites on `ops` while also trying to fold these
@@ -640,11 +639,10 @@
   return region;
 }
 
-LogicalResult
-mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
-                             const FrozenRewritePatternSet &patterns,
-                             GreedyRewriteStrictness strictMode, bool *changed,
-                             bool *allErased, Region *scope) {
+LogicalResult mlir::applyOpPatternsAndFold(
+    ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
+    GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
+    bool *changed, bool *allErased) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
@@ -653,14 +651,15 @@
     return success();
   }
 
-  if (!scope) {
+  // Determine scope of rewrite.
+  if (!config.scope) {
     // Compute scope if none was provided.
-    scope = findCommonAncestor(ops);
+    config.scope = findCommonAncestor(ops);
   } else {
     // If a scope was provided, make sure that all ops are in scope.
 #ifndef NDEBUG
     bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
-      return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+      return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
     });
     assert(allOpsInScope && "ops must be within the specified scope");
 #endif // NDEBUG
@@ -669,14 +668,14 @@
   // Start the pattern driver.
   llvm::SmallDenseSet<Operation *, 4> surviving;
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     *scope, strictMode,
+                                     strictMode, config,
                                      allErased ? &surviving : nullptr);
   LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
   if (allErased)
     *allErased = surviving.empty();
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after "
-                 << GreedyRewriteConfig().maxNumRewrites << " rewrites";
+                 << config.maxNumRewrites << " rewrites";
   });
   return converged;
 }