diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -446,6 +446,15 @@
 /// Rewrite the given regions, which must be isolated from above.
 bool applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                   const OwningRewritePatternList &patterns);
+
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefit patterns in a greedy manner. Returns true
+/// if no more patterns can be matched. `erased` is set to true if `op` is
+/// folded away or erased as a result of becoming dead.
+/// Note: This does not apply patterns to any regions of `op`.
+bool applyOpPatternsAndFold(Operation *op,
+                            const OwningRewritePatternList &patterns,
+                            bool &erased);
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -40,9 +40,8 @@
     worklist.reserve(64);
   }
 
-  /// Perform the rewrites while folding and erasing any dead ops. Return true
-  /// if the rewrite converges in `maxIterations`.
   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
+  bool simplifyLocally(Operation *op, int maxIterations, bool &erased);
 
   void addToWorklist(Operation *op) {
     // Check to see if the worklist already contains this op.
@@ -215,6 +214,56 @@
   return !changed;
 }
 
+/// Performs the rewrites and folding only on `op`. `erased` is set to true if
+/// the op is folded away or erased as a result of being dead. Returns true if
+/// the rewrite converges within `maxIterations`; erasing or folding away `op`
+/// counts as convergence.
+bool GreedyPatternRewriteDriver::simplifyLocally(Operation *op,
+                                                 int maxIterations,
+                                                 bool &erased) {
+  bool changed = false;
+  int i = 0;
+  erased = false;
+  do {
+    // Reset per iteration so that `changed` reflects only the latest pass;
+    // otherwise a single early change would keep the loop spinning until
+    // `maxIterations` and the driver would wrongly report non-convergence.
+    changed = false;
+
+    // If the operation is trivially dead - remove it.
+    if (isOpTriviallyDead(op)) {
+      op->erase();
+      erased = true;
+      return true;
+    }
+
+    // Try to fold this op.
+    bool inPlaceUpdate = false;
+    if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
+                                   /*preReplaceAction=*/nullptr,
+                                   &inPlaceUpdate))) {
+      changed = true;
+      // A fold that is not an in-place update means `op` was folded away.
+      if (!inPlaceUpdate) {
+        erased = true;
+        return true;
+      }
+    }
+
+    // Make sure that any new operations are inserted at this point.
+    setInsertionPoint(op);
+
+    // Try to match one of the patterns. The rewriter is automatically
+    // notified of any necessary changes, so there is nothing else to do here.
+    // NOTE(review): if a pattern erases or replaces `op`, the next iteration
+    // would use a dangling pointer - confirm patterns used here cannot do so.
+    changed |= matcher.matchAndRewrite(op, *this);
+  } while (changed && ++i < maxIterations);
+
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+  return !changed;
+}
+
 /// Rewrite the regions of the specified operation, which must be isolated from
 /// above, by repeatedly applying the highest benefit patterns in a greedy
 /// work-list driven manner. Return true if no more patterns can be matched in
@@ -251,3 +300,20 @@
   });
   return converged;
 }
+
+/// Rewrites `op` alone using the supplied canonicalization patterns and
+/// folding. `erased` is set to true if `op` was folded away or erased. Returns
+/// true if the rewrite converged within `maxPatternMatchIterations`.
+bool mlir::applyOpPatternsAndFold(Operation *op,
+                                  const OwningRewritePatternList &patterns,
+                                  bool &erased) {
+  // Start the pattern driver.
+  GreedyPatternRewriteDriver driver(op->getContext(), patterns);
+  bool converged =
+      driver.simplifyLocally(op, maxPatternMatchIterations, erased);
+  LLVM_DEBUG(if (!converged) {
+    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+                 << maxPatternMatchIterations << " times";
+  });
+  return converged;
+}