diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -86,18 +86,6 @@
   return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
 }
 
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`.
-///
-/// Returns success if the iterative process converged and no more patterns can
-/// be matched.
-LogicalResult applyOpPatternsAndFold(Operation *op,
-                                     const FrozenRewritePatternSet &patterns,
-                                     bool *erased = nullptr);
-
 /// Applies the specified rewrite patterns on `ops` while also trying to fold
 /// these ops.
 ///
@@ -125,6 +113,21 @@
                                      bool *changed = nullptr,
                                      bool *allOpsErased = nullptr);
 
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefits patterns in a greedy manner. `erased` is
+/// set to true if `op` was folded away or erased as a result of becoming dead.
+/// Note: This does not apply any patterns recursively to the regions of `op`.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched.
+inline LogicalResult
+applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
+                       bool *erased = nullptr) {
+  return applyOpPatternsAndFold(ArrayRef(op), patterns,
+                                GreedyRewriteStrictness::ExistingOps,
+                                /*changed=*/nullptr, erased);
+}
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -457,109 +457,6 @@
   return success(converged);
 }
 
-//===----------------------------------------------------------------------===//
-// OpPatternRewriteDriver
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a simple driver for the PatternMatcher to apply patterns and perform
-/// folding on a single op. It repeatedly applies locally optimal patterns.
-class OpPatternRewriteDriver : public PatternRewriter {
-public:
-  explicit OpPatternRewriteDriver(MLIRContext *ctx,
-                                  const FrozenRewritePatternSet &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
-    // Apply a simple cost model based solely on pattern benefit.
-    matcher.applyDefaultCostModel();
-  }
-
-  LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
-                                bool &erased);
-
-  // These are hooks implemented for PatternRewriter.
-protected:
-  /// If an operation is about to be removed, mark it so that we can let clients
-  /// know.
-  void notifyOperationRemoved(Operation *op) override {
-    if (this->op == op)
-      opErasedViaPatternRewrites = true;
-  }
-
-  // When a root is going to be replaced, its removal will be notified as well.
-  // So there is nothing to do here.
-  void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
-
-private:
-  /// Op that is being processed.
-  Operation *op = nullptr;
-
-  /// The low-level pattern applicator.
-  PatternApplicator matcher;
-
-  /// Non-pattern based folder for operations.
-  OperationFolder folder;
-
-  /// Set to true if the operation has been erased via pattern rewrites.
-  bool opErasedViaPatternRewrites = false;
-};
-
-} // namespace
-
-/// Performs the rewrites and folding only on `op`. The simplification
-/// converges if the op is erased as a result of being folded, replaced, or
-/// becoming dead, or no more changes happen in an iteration. Returns success if
-/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
-/// gets erased.
-LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
-                                                      int64_t maxNumRewrites,
-                                                      bool &erased) {
-  this->op = op;
-  bool changed = false;
-  erased = false;
-  opErasedViaPatternRewrites = false;
-  int64_t numRewrites = 0;
-  // Iterate until convergence or until maxNumRewrites. Deletion of the op as
-  // a result of being dead or folded is convergence.
-  do {
-    if (numRewrites >= maxNumRewrites &&
-        maxNumRewrites != GreedyRewriteConfig::kNoLimit)
-      break;
-
-    changed = false;
-
-    // If the operation is trivially dead - remove it.
-    if (isOpTriviallyDead(op)) {
-      op->erase();
-      erased = true;
-      return success();
-    }
-
-    // Try to fold this op.
-    bool inPlaceUpdate;
-    if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
-                                   /*preReplaceAction=*/nullptr,
-                                   &inPlaceUpdate))) {
-      changed = true;
-      if (!inPlaceUpdate) {
-        erased = true;
-        return success();
-      }
-    }
-
-    // Try to match one of the patterns. The rewriter is automatically
-    // notified of any necessary changes, so there is nothing else to do here.
-    if (succeeded(matcher.matchAndRewrite(op, *this))) {
-      changed = true;
-      ++numRewrites;
-    }
-    if ((erased = opErasedViaPatternRewrites))
-      return success();
-  } while (changed);
-
-  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return failure(changed);
-}
-
 //===----------------------------------------------------------------------===//
 // MultiOpPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -717,23 +614,6 @@
   return success(worklist.empty());
 }
 
-LogicalResult mlir::applyOpPatternsAndFold(
-    Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
-  // Start the pattern driver.
-  GreedyRewriteConfig config;
-  OpPatternRewriteDriver driver(op->getContext(), patterns);
-  bool opErased;
-  LogicalResult converged =
-      driver.simplifyLocally(op, config.maxNumRewrites, opErased);
-  if (erased)
-    *erased = opErased;
-  LLVM_DEBUG(if (failed(converged)) {
-    llvm::dbgs() << "The pattern rewrite did not converge after "
-                 << config.maxNumRewrites << " rewrites";
-  });
-  return converged;
-}
-
 LogicalResult mlir::applyOpPatternsAndFold(
     ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
     GreedyRewriteStrictness strictMode, bool *changed, bool *allOpsErased) {
@@ -754,5 +634,9 @@
   if (allOpsErased)
     *allOpsErased =
         llvm::all_of(ops, [&](Operation *op) { return erased.contains(op); });
+  LLVM_DEBUG(if (failed(converged)) {
+    llvm::dbgs() << "The pattern rewrite did not converge after "
+                 << GreedyRewriteConfig().maxNumRewrites << " rewrites\n";
+  });
   return converged;
 }