diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -880,7 +880,8 @@
 /// place.
 class PatternRewriter : public RewriterBase {
 public:
-  using RewriterBase::RewriterBase;
+  explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+  explicit PatternRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -59,16 +59,29 @@
 /// These methods also perform folding and simple dead-code elimination
 /// before attempting to match any of the provided patterns.
 ///
-/// You may configure several aspects of this with GreedyRewriteConfig.
+/// You may configure several aspects of this with GreedyRewriteConfig. A
+/// pattern rewriter instance can be supplied to hook into rewrite events, such
+/// as operations being removed or replaced, or to otherwise alter the behavior
+/// of pattern rewrite functions.
 LogicalResult applyPatternsAndFoldGreedily(
     MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter,
     GreedyRewriteConfig config = GreedyRewriteConfig());
 
-/// Rewrite the given regions, which must be isolated from above.
+/// Rewrite the given regions, which must be isolated from above, using the
+/// default pattern rewriter.
+LogicalResult applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    GreedyRewriteConfig config = GreedyRewriteConfig());
+
+/// Rewrite the regions of the given operation, which must be isolated from
+/// above, using the default pattern rewriter.
 inline LogicalResult applyPatternsAndFoldGreedily(
     Operation *op, const FrozenRewritePatternSet &patterns,
     GreedyRewriteConfig config = GreedyRewriteConfig()) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
+  PatternRewriter rewriter(op->getContext());
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, rewriter,
+                                      config);
 }
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -36,7 +36,8 @@
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
-                                      const GreedyRewriteConfig &config);
+                                      const GreedyRewriteConfig &config,
+                                      PatternRewriter &rewriter);
 
   /// Simplify the operations within the given regions.
   bool simplify(MutableArrayRef<Region> regions);
@@ -74,6 +75,9 @@
   /// PatternRewriter hook for erasing a dead operation.
   void eraseOp(Operation *op) override;
 
+  /// PatternRewriter hook for replacing an operation.
+  void replaceOp(Operation *op, ValueRange newValues) override;
+
   /// PatternRewriter hook for notifying match failure reasons.
   LogicalResult
   notifyMatchFailure(Operation *op,
@@ -92,6 +96,10 @@
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
+  /// The pattern rewriter to which this driver delegates op erasure and
+  /// replacement, and to which it forwards insertion notifications.
+  PatternRewriter &rewriter;
+
 private:
   /// Configuration information for how to simplify.
   GreedyRewriteConfig config;
@@ -105,8 +113,9 @@
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+    const GreedyRewriteConfig &config, PatternRewriter &rewriter)
+    : PatternRewriter(ctx), matcher(patterns), folder(ctx), rewriter(rewriter),
+      config(config) {
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -309,6 +318,7 @@
                        << ")\n";
   });
   addToWorklist(op);
+  rewriter.notifyOperationInserted(op);
 }
 
 template <typename Operands>
@@ -349,7 +359,18 @@
     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  PatternRewriter::eraseOp(op);
+  // Update the worklist before delegating the erasure to `rewriter`.
+  notifyOperationRemoved(op);
+  rewriter.eraseOp(op);
+}
+
+void GreedyPatternRewriteDriver::replaceOp(Operation *op,
+                                           ValueRange newValues) {
+  // Update the worklist before delegating the replacement to `rewriter`;
+  // notifyRootReplaced also emits the "** Replace" debug log.
+  notifyRootReplaced(op);
+  notifyOperationRemoved(op);
+  rewriter.replaceOp(op, newValues);
 }
 
 LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
@@ -368,10 +389,9 @@
 /// in the result operation regions. Note: This does not apply patterns to the
 /// top-level operation itself.
 ///
-LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternSet &patterns,
-                                   GreedyRewriteConfig config) {
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter, GreedyRewriteConfig config) {
   if (regions.empty())
     return success();
 
@@ -386,7 +406,8 @@
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
+                                    rewriter);
   bool converged = driver.simplify(regions);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
@@ -395,6 +416,18 @@
   return success(converged);
 }
 
+/// Same as above, but performs the rewrites with a default PatternRewriter.
+LogicalResult
+mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+                                   const FrozenRewritePatternSet &patterns,
+                                   GreedyRewriteConfig config) {
+  if (regions.empty())
+    return success();
+
+  PatternRewriter rewriter(regions.front().getContext());
+  return applyPatternsAndFoldGreedily(regions, patterns, rewriter, config);
+}
+
 //===----------------------------------------------------------------------===//
 // OpPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -501,8 +534,9 @@
 public:
   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
                                        const FrozenRewritePatternSet &patterns,
-                                       bool strict)
-      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
+                                       bool strict, PatternRewriter &rewriter)
+      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(),
+                                   rewriter),
         strictMode(strict) {}
 
   bool simplifyLocally(ArrayRef<Operation *> op);
@@ -664,7 +698,8 @@
     return false;
 
   // Start the pattern driver.
+  PatternRewriter rewriter(ops.front()->getContext());
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     strict);
+                                     strict, rewriter);
   return driver.simplifyLocally(ops);
 }