Remove OpPatternRewriteDriver; make single-op applyOpPatternsAndFold a wrapper.

The single-op `applyOpPatternsAndFold(Operation *, ...)` is no longer an
out-of-line function backed by its own `OpPatternRewriteDriver` class. It
becomes an inline function in the header that forwards to the `ArrayRef`-based
overload with `GreedyRewriteStrictness::ExistingOps`, passing `nullptr` for
`changed` and forwarding `erased`. The class, its `simplifyLocally`
implementation and the old out-of-line definition are deleted from the .cpp
file. The "did not converge" LLVM_DEBUG message is re-added next to the
`allErased` handling in the .cpp file (presumably the multi-op entry point;
its signature lies outside the hunk), now reading the limit from a
default-constructed `GreedyRewriteConfig`.

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -86,18 +86,6 @@
   return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
 }
 
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`.
-///
-/// Returns success if the iterative process converged and no more patterns can
-/// be matched.
-LogicalResult applyOpPatternsAndFold(Operation *op,
-                                     const FrozenRewritePatternSet &patterns,
-                                     bool *erased = nullptr);
-
 /// Applies the specified rewrite patterns on `ops` while also trying to fold
 /// these ops.
 ///
@@ -132,6 +120,21 @@
                                      bool *allErased = nullptr,
                                      Region *scope = nullptr);
 
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefits patterns in a greedy manner. Returns
+/// success if no more patterns can be matched. `erased` is set to true if `op`
+/// was folded away or erased as a result of becoming dead.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched.
+inline LogicalResult
+applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
+                       bool *erased = nullptr) {
+  return applyOpPatternsAndFold(ArrayRef(op), patterns,
+                                GreedyRewriteStrictness::ExistingOps,
+                                /*changed=*/nullptr, erased);
+}
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -459,109 +459,6 @@
   return success(converged);
 }
 
-//===----------------------------------------------------------------------===//
-// OpPatternRewriteDriver
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a simple driver for the PatternMatcher to apply patterns and perform
-/// folding on a single op. It repeatedly applies locally optimal patterns.
-class OpPatternRewriteDriver : public PatternRewriter {
-public:
-  explicit OpPatternRewriteDriver(MLIRContext *ctx,
-                                  const FrozenRewritePatternSet &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
-    // Apply a simple cost model based solely on pattern benefit.
-    matcher.applyDefaultCostModel();
-  }
-
-  LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
-                                bool &erased);
-
-  // These are hooks implemented for PatternRewriter.
-protected:
-  /// If an operation is about to be removed, mark it so that we can let clients
-  /// know.
-  void notifyOperationRemoved(Operation *op) override {
-    if (this->op == op)
-      opErasedViaPatternRewrites = true;
-  }
-
-  // When a root is going to be replaced, its removal will be notified as well.
-  // So there is nothing to do here.
-  void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
-
-private:
-  /// The low-level pattern applicator.
-  PatternApplicator matcher;
-
-  /// Non-pattern based folder for operations.
-  OperationFolder folder;
-
-  /// Op that is being processed.
-  Operation *op = nullptr;
-
-  /// Set to true if the operation has been erased via pattern rewrites.
-  bool opErasedViaPatternRewrites = false;
-};
-
-} // namespace
-
-/// Performs the rewrites and folding only on `op`. The simplification
-/// converges if the op is erased as a result of being folded, replaced, or
-/// becoming dead, or no more changes happen in an iteration. Returns success if
-/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
-/// gets erased.
-LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
-                                                      int64_t maxNumRewrites,
-                                                      bool &erased) {
-  this->op = op;
-  bool changed = false;
-  erased = false;
-  opErasedViaPatternRewrites = false;
-  int64_t numRewrites = 0;
-  // Iterate until convergence or until maxNumRewrites. Deletion of the op as
-  // a result of being dead or folded is convergence.
-  do {
-    if (numRewrites >= maxNumRewrites &&
-        maxNumRewrites != GreedyRewriteConfig::kNoLimit)
-      break;
-
-    changed = false;
-
-    // If the operation is trivially dead - remove it.
-    if (isOpTriviallyDead(op)) {
-      op->erase();
-      erased = true;
-      return success();
-    }
-
-    // Try to fold this op.
-    bool inPlaceUpdate;
-    if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
-                                   /*preReplaceAction=*/nullptr,
-                                   &inPlaceUpdate))) {
-      changed = true;
-      if (!inPlaceUpdate) {
-        erased = true;
-        return success();
-      }
-    }
-
-    // Try to match one of the patterns. The rewriter is automatically
-    // notified of any necessary changes, so there is nothing else to do here.
-    if (succeeded(matcher.matchAndRewrite(op, *this))) {
-      changed = true;
-      ++numRewrites;
-    }
-    if ((erased = opErasedViaPatternRewrites))
-      return success();
-  } while (changed);
-
-  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return failure(changed);
-}
-
 //===----------------------------------------------------------------------===//
 // MultiOpPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -734,23 +631,6 @@
   return success(worklist.empty());
 }
 
-LogicalResult mlir::applyOpPatternsAndFold(
-    Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
-  // Start the pattern driver.
-  GreedyRewriteConfig config;
-  OpPatternRewriteDriver driver(op->getContext(), patterns);
-  bool opErased;
-  LogicalResult converged =
-      driver.simplifyLocally(op, config.maxNumRewrites, opErased);
-  if (erased)
-    *erased = opErased;
-  LLVM_DEBUG(if (failed(converged)) {
-    llvm::dbgs() << "The pattern rewrite did not converge after "
-                 << config.maxNumRewrites << " rewrites";
-  });
-  return converged;
-}
-
 /// Find the region that is the closest common ancestor of all given ops.
 static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   assert(!ops.empty() && "expected at least one op");
@@ -811,5 +691,9 @@
       ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
   if (allErased)
     *allErased = surviving.empty();
+  LLVM_DEBUG(if (failed(converged)) {
+    llvm::dbgs() << "The pattern rewrite did not converge after "
+                 << GreedyRewriteConfig().maxNumRewrites << " rewrites";
+  });
   return converged;
 }