diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
--- a/mlir/include/mlir/Rewrite/PatternApplicator.h
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -16,6 +16,8 @@
 
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 
+#include "mlir/IR/Action.h"
+
 namespace mlir {
 class PatternRewriter;
 
@@ -23,6 +25,32 @@
 class PDLByteCodeMutableState;
 } // namespace detail
 
+/// This is the type of Action that is dispatched when a pattern is applied.
+/// It captures the pattern to apply on top of the usual context.
+class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
+public:
+  using Base = tracing::ActionImpl<ApplyPatternAction>;
+  /// Create an action applying `pattern` to the given IR units. The pattern
+  /// is held by reference and must outlive this action.
+  ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
+      : Base(irUnits), pattern(pattern) {}
+  /// Identifier of this action kind; also emitted by `print`.
+  static constexpr StringLiteral tag = "apply-pattern-action";
+  /// Human-readable summary of what this action represents.
+  static constexpr StringLiteral desc =
+      "Encapsulate the application of rewrite patterns";
+
+  /// Print the action tag followed by the debug name of the pattern.
+  void print(raw_ostream &os) const override {
+    os << "`" << tag << "`\n"
+       << " pattern: " << pattern.getDebugName() << '\n';
+  }
+
+private:
+  /// The pattern being applied; not owned by this action.
+  const Pattern &pattern;
+};
+
 /// This class manages the application of a group of rewrite patterns, with a
 /// user-provided cost model.
 class PatternApplicator {
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -185,35 +185,50 @@
     // Try to match and rewrite this pattern. The patterns are sorted by
     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
     // match has already been performed, we just need to rewrite.
-    rewriter.setInsertionPoint(op);
+    // The callback below returns void, so a successful rewrite is reported
+    // back through `matched`. If the callback is not invoked, `matched` stays
+    // false and the loop continues.
+    bool matched = false;
+    op->getContext()->executeAction<ApplyPatternAction>(
+        [&]() {
+          rewriter.setInsertionPoint(op);
 #ifndef NDEBUG
-    // Operation `op` may be invalidated after applying the rewrite pattern.
-    Operation *dumpRootOp = getDumpRootOp(op);
+          // Operation `op` may be invalidated after applying the rewrite
+          // pattern.
+          Operation *dumpRootOp = getDumpRootOp(op);
 #endif
-    if (pdlMatch) {
-      result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
-    } else {
-      LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
-                              << bestPattern->getDebugName() << "\"\n");
-
-      const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
-      result = pattern->matchAndRewrite(op, rewriter);
-
-      LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
-                              << "\" result " << succeeded(result) << "\n");
-    }
-
-    // Process the result of the pattern application.
-    if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
-      result = failure();
-    if (succeeded(result)) {
-      LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+          if (pdlMatch) {
+            result =
+                bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+          } else {
+            LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
+                                    << bestPattern->getDebugName() << "\"\n");
+
+            const auto *pattern =
+                static_cast<const RewritePattern *>(bestPattern);
+            result = pattern->matchAndRewrite(op, rewriter);
+
+            LLVM_DEBUG(llvm::dbgs()
+                       << "\"" << bestPattern->getDebugName() << "\" result "
+                       << succeeded(result) << "\n");
+          }
+
+          // Process the result of the pattern application.
+          if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
+            result = failure();
+          if (succeeded(result)) {
+            LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+            matched = true;
+            return;
+          }
+
+          // Perform any necessary cleanups.
+          if (onFailure)
+            onFailure(*bestPattern);
+        },
+        {op}, *bestPattern);
+    if (matched)
       break;
-    }
-
-    // Perform any necessary cleanups.
-    if (onFailure)
-      onFailure(*bestPattern);
   } while (true);
 
   if (mutableByteCodeState)