diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -455,6 +455,15 @@
 /// Rewrite the given regions, which must be isolated from above.
 bool applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                   const OwningRewritePatternList &patterns);
+
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// selecting the highest-benefit pattern in a greedy manner. Returns true if
+/// no more patterns can be matched. `erased` is set to true if `op` is folded
+/// away, erased by a pattern, or erased as a result of becoming dead.
+/// Note: This does not apply patterns to any regions of `op`.
+bool applyOpPatternsAndFold(Operation *op,
+                            const OwningRewritePatternList &patterns,
+                            bool &erased);
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -26,6 +26,10 @@
 /// The max number of iterations scanning for pattern match.
 static unsigned maxPatternMatchIterations = 10;
 
+//===----------------------------------------------------------------------===//
+// GreedyPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
 /// applies the locally optimal patterns in a roughly "bottom up" way.
@@ -37,8 +41,6 @@
     worklist.reserve(64);
   }
 
-  /// Perform the rewrites while folding and erasing any dead ops. Return true
-  /// if the rewrite converges in `maxIterations`.
   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
 
   void addToWorklist(Operation *op) {
@@ -248,3 +250,124 @@
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// OpPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a simple driver for the PatternMatcher to apply patterns and perform
+/// folding on a single op. It repeatedly applies the locally optimal patterns.
+class OpPatternRewriteDriver : public PatternRewriter {
+public:
+  explicit OpPatternRewriteDriver(MLIRContext *ctx,
+                                  const OwningRewritePatternList &patterns)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {}
+
+  /// Simplifies `op` until convergence or until `maxIterations` is reached.
+  /// Returns true if the rewrite converged.
+  bool simplifyLocally(Operation *op, int maxIterations);
+
+  /// No additional action is performed other than inserting the op for this
+  /// driver.
+  Operation *insert(Operation *op) override { return OpBuilder::insert(op); }
+
+  /// Returns true if the op passed to `simplifyLocally` has been erased.
+  bool isErased() const { return erased; }
+
+  // These are hooks implemented for PatternRewriter.
+protected:
+  // If the op being simplified is about to be removed, mark it so that we
+  // can let clients know. The removal of any other op (e.g. one erased by a
+  // pattern rooted at `rootOp`) leaves `rootOp` alive.
+  void notifyOperationRemoved(Operation *op) override {
+    if (op == rootOp)
+      erased = true;
+  }
+
+  // When a root is going to be replaced, its removal will be notified as well.
+  // So there is nothing to do here.
+  void notifyRootReplaced(Operation *op) override {}
+
+private:
+  /// The low-level pattern matcher.
+  RewritePatternMatcher matcher;
+
+  /// Non-pattern based folder for operations.
+  OperationFolder folder;
+
+  /// The operation currently being simplified.
+  Operation *rootOp = nullptr;
+
+  /// True if `rootOp` has been removed.
+  bool erased = false;
+};
+
+} // anonymous namespace
+
+/// Performs the rewrites and folding only on `op`. The simplification stops
+/// if the op is erased as a result of being folded, replaced, or dead, or no
+/// more changes happen in an iteration. Returns true if the rewrite converges
+/// in `maxIterations`.
+bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations) {
+  rootOp = op;
+  erased = false;
+  bool changed = false;
+  int i = 0;
+  // Iterate until convergence or until maxIterations. Deletion of the op as
+  // a result of being dead or folded is convergence.
+  do {
+    // Reset per iteration: convergence means a full iteration in which
+    // neither folding nor any pattern changed the op.
+    changed = false;
+
+    // If the operation is trivially dead - remove it.
+    if (isOpTriviallyDead(op)) {
+      op->erase();
+      erased = true;
+      return true;
+    }
+
+    // Try to fold this op. Unless the fold updated `op` in place, `op` has
+    // been replaced and erased, which ends the simplification.
+    bool inPlaceUpdate = false;
+    if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
+                                   /*preReplaceAction=*/nullptr,
+                                   &inPlaceUpdate))) {
+      changed = true;
+      if (!inPlaceUpdate) {
+        erased = true;
+        return true;
+      }
+    }
+
+    // Make sure that any new operations are inserted at this point.
+    setInsertionPoint(op);
+
+    // Try to match one of the patterns. The rewriter is automatically
+    // notified of any necessary changes, so there is nothing else to do here.
+    changed |= matcher.matchAndRewrite(op, *this);
+    if (erased)
+      return true;
+  } while (changed && ++i < maxIterations);
+
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+  return !changed;
+}
+
+/// Rewrites only `op` using the supplied canonicalization patterns and
+/// folding. `erased` is set to true if the op is erased as a result of being
+/// folded, replaced, or dead.
+bool mlir::applyOpPatternsAndFold(Operation *op,
+                                  const OwningRewritePatternList &patterns,
+                                  bool &erased) {
+  // Start the pattern driver.
+  OpPatternRewriteDriver driver(op->getContext(), patterns);
+  bool converged = driver.simplifyLocally(op, maxPatternMatchIterations);
+  LLVM_DEBUG(if (!converged) {
+    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+                 << maxPatternMatchIterations << " times";
+  });
+  erased = driver.isErased();
+  return converged;
+}