diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -37,15 +37,21 @@ /// generally more efficient in compile time. When set to false, its initial /// traversal of the region tree is bottom up on each block, which may match /// larger patterns when given an ambiguous pattern set. + /// + /// Note: Only applicable when simplifying entire regions. bool useTopDownTraversal = false; - // Perform control flow optimizations to the region tree after applying all - // patterns. + /// Perform control flow optimizations to the region tree after applying all + /// patterns. + /// + /// Note: Only applicable when simplifying entire regions. bool enableRegionSimplification = true; /// This specifies the maximum number of times the rewriter will iterate /// between applying patterns and simplifying regions. Use `kNoLimit` to /// disable this iteration limit. + /// + /// Note: Only applicable when simplifying entire regions. int64_t maxIterations = 10; /// This specifies the maximum number of rewrites within an iteration. Use @@ -53,6 +59,10 @@ int64_t maxNumRewrites = kNoLimit; static constexpr int64_t kNoLimit = -1; + + /// Only ops within the scope are added to the worklist. If no scope is + /// specified, the closest enclosing region is used as a scope. + Region *scope = nullptr; }; //===----------------------------------------------------------------------===// @@ -117,12 +127,12 @@ /// Returns success if the iterative process converged and no more patterns can /// be matched. `changed` is set to true if the IR was modified at all. /// `allOpsErased` is set to true if all ops in `ops` were erased. 
-LogicalResult applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, - bool *changed = nullptr, - bool *allErased = nullptr, - Region *scope = nullptr); +LogicalResult +applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + GreedyRewriteStrictness strictMode, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr, bool *allErased = nullptr); /// Applies the specified patterns on `op` alone while also trying to fold it, /// by selecting the highest benefits patterns in a greedy manner. Returns @@ -133,9 +143,10 @@ /// be matched. inline LogicalResult applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), bool *erased = nullptr) { return applyOpPatternsAndFold(ArrayRef(op), patterns, - GreedyRewriteStrictness::ExistingOps, + GreedyRewriteStrictness::ExistingOps, config, /*changed=*/nullptr, erased); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -322,8 +322,8 @@ RewritePatternSet patterns(res.getContext()); AffineForOp::getCanonicalizationPatterns(patterns, res.getContext()); bool erased; - (void)applyOpPatternsAndFold(res, std::move(patterns), &erased); - + (void)applyOpPatternsAndFold(res, std::move(patterns), + GreedyRewriteConfig(), &erased); if (!erased && !prologue) prologue = res; if (!erased) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -415,7 +415,8 @@ AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); bool erased; FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); + 
(void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(), + &erased); if (erased) { if (folded) *folded = true; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,8 +39,7 @@ public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, - const Region &scope); + const GreedyRewriteConfig &config); /// Simplify the ops within the given region. bool simplify(Region &region) &&; @@ -103,9 +102,6 @@ /// Configuration information for how to simplify. const GreedyRewriteConfig config; - /// Only ops within this scope are simplified. - const Region &scope; - private: #ifndef NDEBUG /// A logger used to emit information during the application process. @@ -116,9 +112,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, const Region &scope) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config), - scope(scope) { + const GreedyRewriteConfig &config) + : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + assert(config.scope && "scope is not specified"); worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. @@ -313,7 +309,7 @@ SmallVector ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (&scope == region) { + if (config.scope == region) { // All gathered ops are in fact ancestors. for (Operation *op : ancestors) addSingleOpToWorklist(op); @@ -434,9 +430,12 @@ assert(region.getParentOp()->hasTrait() && "patterns can only be applied to operations IsolatedFromAbove"); + // Set scope if not specified. + if (!config.scope) + config.scope = &region; + // Start the pattern driver. 
- GreedyPatternRewriteDriver driver(region.getContext(), patterns, config, - region); + GreedyPatternRewriteDriver driver(region.getContext(), patterns, config); bool converged = std::move(driver).simplify(region); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " @@ -460,9 +459,9 @@ public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const Region &scope, GreedyRewriteStrictness strictMode, + GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config, llvm::SmallDenseSet *survivingOps = nullptr) - : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope), + : GreedyPatternRewriteDriver(ctx, patterns, config), strictMode(strictMode), survivingOps(survivingOps) {} /// Performs the specified rewrites on `ops` while also trying to fold these @@ -636,11 +635,10 @@ return region; } -LogicalResult -mlir::applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, bool *changed, - bool *allErased, Region *scope) { +LogicalResult mlir::applyOpPatternsAndFold( + ArrayRef ops, const FrozenRewritePatternSet &patterns, + GreedyRewriteStrictness strictMode, GreedyRewriteConfig config, + bool *changed, bool *allErased) { if (ops.empty()) { if (changed) *changed = false; @@ -649,14 +647,15 @@ return success(); } - if (!scope) { + // Determine scope of rewrite. + if (!config.scope) { // Compute scope if none was provided. - scope = findCommonAncestor(ops); + config.scope = findCommonAncestor(ops); } else { // If a scope was provided, make sure that all ops are in scope. 
#ifndef NDEBUG bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) { - return static_cast(scope->findAncestorOpInRegion(*op)); + return static_cast(config.scope->findAncestorOpInRegion(*op)); }); assert(allOpsInScope && "ops must be within the specified scope"); #endif // NDEBUG @@ -665,14 +664,14 @@ // Start the pattern driver. llvm::SmallDenseSet surviving; MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - *scope, strictMode, + strictMode, config, allErased ? &surviving : nullptr); LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) *allErased = surviving.empty(); LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after " - << GreedyRewriteConfig().maxNumRewrites << " rewrites"; + << config.maxNumRewrites << " rewrites"; }); return converged; } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -283,7 +283,7 @@ bool changed = false; bool allErased = false; (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode, - &changed, &allErased); + GreedyRewriteConfig(), &changed, &allErased); Builder b(ctx); getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); getOperation()->setAttr("pattern_driver_all_erased",