NOTE(review): This patch turns RewriterBase::Listener into a separate
interface derived from OpBuilder::Listener. Previously RewriterBase was its
own listener. RewriterBase now reaches the rewrite-specific hooks
(notifyRootReplaced, notifyOperationRemoved, notifyMatchFailure) through the
now-protected OpBuilder::listener member.
NOTE(review): Every such lookup is written as `dynamic_cast(listener)`.
LLVM/MLIR builds with -fno-rtti unless LLVM_ENABLE_RTTI is set, and
dynamic_cast does not compile without RTTI. Consider LLVM-style RTTI
instead: a kind tag on OpBuilder::Listener plus
RewriterBase::Listener::classof, queried via dyn_cast_if_present. That
needs a Builders.h change that is not part of this patch.
TODO confirm the target build configuration.
NOTE(review): RewriterBase(const OpBuilder &) no longer calls
setListener(this). A rewriter copied from another builder now keeps that
builder's listener, and it receives rewrite notifications only if it is a
RewriterBase::Listener. Classes that implement the hooks install themselves
explicitly, as ConversionPatternRewriter, BufferizationRewriter and
GreedyPatternRewriteDriver do below.
NOTE(review): ConversionPatternRewriter, BufferizationRewriter and
GreedyPatternRewriteDriver now inherit `notifyMatchFailure` from two bases:
the RewriterBase templates and the RewriterBase::Listener virtual. Verify
that unqualified calls still resolve, e.g. via
`using PatternRewriter::notifyMatchFailure;`.
NOTE(review): The unchanged context comments "Notify the rewriter subclass
that we're about to replace this root." in PatternMatch.cpp are stale after
this change, because the notification now goes to the listener.
NOTE(review): In this copy, template argument lists appear stripped (e.g.
`cast(`, `function_ref reasonCallback`, `std::unique_ptr impl`) and the
newlines are collapsed, so the patch will not apply as-is. Regenerate it
from the original commit before use.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -535,14 +535,16 @@ return cast(cloneWithoutRegions(*op.getOperation())); } +protected: + /// The optional listener for events of this builder. + Listener *listener; + private: /// The current block this builder is inserting into. Block *block = nullptr; /// The insertion point within the block that this builder is inserting /// before. Block::iterator insertPoint; - /// The optional listener for events of this builder. - Listener *listener; }; } // namespace mlir diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -396,8 +396,30 @@ /// This class serves as a common API for IR mutation between pattern rewrites /// and non-pattern rewrites, and facilitates the development of shared /// IR transformation utilities. -class RewriterBase : public OpBuilder, public OpBuilder::Listener { +class RewriterBase : public OpBuilder { public: + struct Listener : public OpBuilder::Listener { + /// Notify the listener that the specified operation is about to be replaced + /// with the set of values potentially produced by new operations. This is + /// called before the uses of the operation have been changed. + virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {} + + /// This is called on an operation that a rewrite is removing, right before + /// the operation is deleted. At this point, the operation has zero uses. + virtual void notifyOperationRemoved(Operation *op) {} + + /// Notify the listener that the pattern failed to match the given + /// operation, and provide a callback to populate a diagnostic with the + /// reason why the failure occurred. This method allows for derived + /// listeners to optionally hook into the reason why a rewrite failed, and + /// display it to users.
+ virtual LogicalResult + notifyMatchFailure(Location loc, + function_ref reasonCallback) { + return failure(); + } + }; + /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. The caller /// is responsible for creating or updating the operation transferring flow @@ -532,8 +554,10 @@ std::enable_if_t::value, LogicalResult> notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { #ifndef NDEBUG - return notifyMatchFailure(loc, - function_ref(reasonCallback)); + if (auto *rewriteListener = dynamic_cast(listener)) + return rewriteListener->notifyMatchFailure( + loc, function_ref(reasonCallback)); + return failure(); #else return failure(); #endif @@ -541,8 +565,10 @@ template std::enable_if_t::value, LogicalResult> notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { - return notifyMatchFailure(op->getLoc(), - function_ref(reasonCallback)); + if (auto *rewriteListener = dynamic_cast(listener)) + return rewriteListener->notifyMatchFailure( + op->getLoc(), function_ref(reasonCallback)); + return failure(); } template LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) { @@ -555,35 +581,11 @@ } protected: - /// Initialize the builder with this rewriter as the listener. - explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {} + /// Initialize the builder. + explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx) {} explicit RewriterBase(const OpBuilder &otherBuilder) - : OpBuilder(otherBuilder) { - setListener(this); - } - ~RewriterBase() override; - - /// These are the callback methods that subclasses can choose to implement if - /// they would like to be notified about certain types of mutations. - - /// Notify the rewriter that the specified operation is about to be replaced - /// with the set of values potentially produced by new operations. This is - /// called before the uses of the operation have been changed.
- virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {} - - /// This is called on an operation that a rewrite is removing, right before - /// the operation is deleted. At this point, the operation has zero uses. - virtual void notifyOperationRemoved(Operation *op) {} - - /// Notify the rewriter that the pattern failed to match the given operation, - /// and provide a callback to populate a diagnostic with the reason why the - /// failure occurred. This method allows for derived rewriters to optionally - /// hook into the reason why a rewrite failed, and display it to users. - virtual LogicalResult - notifyMatchFailure(Location loc, - function_ref reasonCallback) { - return failure(); - } + : OpBuilder(otherBuilder) {} + virtual ~RewriterBase(); private: void operator=(const RewriterBase &) = delete; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -522,7 +522,8 @@ /// This class implements a pattern rewriter for use with ConversionPatterns. It /// extends the base PatternRewriter and provides special conversion specific /// hooks.
-class ConversionPatternRewriter final : public PatternRewriter { +class ConversionPatternRewriter final : public PatternRewriter, + public RewriterBase::Listener { public: explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; @@ -646,6 +647,9 @@ detail::ConversionPatternRewriterImpl &getImpl(); private: + using OpBuilder::getListener; + using OpBuilder::setListener; + std::unique_ptr impl; }; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -555,7 +555,7 @@ // Inside regular functions we use the blocking wait operation to wait for // the async object (token, value or group) to become available. if (!isInCoroutine) { - ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); + ImplicitLocOpBuilder builder(loc, op, &rewriter); builder.create(loc, operand); // Assert that the awaited operands is not in the error state. @@ -574,7 +574,7 @@ CoroMachinery &coro = funcCoro->getSecond(); Block *suspended = op->getBlock(); - ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); + ImplicitLocOpBuilder builder(loc, op, &rewriter); MLIRContext *ctx = op->getContext(); // Save the coroutine state and resume on a runtime managed thread when diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -354,7 +354,7 @@ namespace { /// A rewriter that keeps track of extra information during bufferization.
-class BufferizationRewriter : public IRRewriter { +class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { public: BufferizationRewriter(MLIRContext *ctx, DenseSet &erasedOps, DenseSet &toMemrefOps, @@ -364,18 +364,18 @@ BufferizationStatistics *statistics) : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps), worklist(worklist), analysisState(options), opFilter(opFilter), - statistics(statistics) {} + statistics(statistics) { + setListener(this); + } protected: void notifyOperationRemoved(Operation *op) override { - IRRewriter::notifyOperationRemoved(op); erasedOps.insert(op); // Erase if present. toMemrefOps.erase(op); } void notifyOperationInserted(Operation *op) override { - IRRewriter::notifyOperationInserted(op); erasedOps.erase(op); // Gather statistics about allocs and deallocs. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -232,7 +232,8 @@ "incorrect number of values to replace operation"); // Notify the rewriter subclass that we're about to replace this root. - notifyRootReplaced(op, newValues); + if (auto *rewriteListener = dynamic_cast(listener)) + rewriteListener->notifyRootReplaced(op, newValues); // Replace each use of the results when the functor is true. bool replacedAllUses = true; @@ -260,13 +261,15 @@ /// the operation. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { // Notify the rewriter subclass that we're about to replace this root.
- notifyRootReplaced(op, newValues); + if (auto *rewriteListener = dynamic_cast(listener)) + rewriteListener->notifyRootReplaced(op, newValues); assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); op->replaceAllUsesWith(newValues); - notifyOperationRemoved(op); + if (auto *rewriteListener = dynamic_cast(listener)) + rewriteListener->notifyOperationRemoved(op); op->erase(); } @@ -274,7 +277,8 @@ /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); - notifyOperationRemoved(op); + if (auto *rewriteListener = dynamic_cast(listener)) + rewriteListener->notifyOperationRemoved(op); op->erase(); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1494,7 +1494,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this)) {} + impl(new detail::ConversionPatternRewriterImpl(*this)) { + setListener(this); +} + ConversionPatternRewriter::~ConversionPatternRewriter() = default; void ConversionPatternRewriter::replaceOpWithIf( diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,7 +39,8 @@ /// This abstract class manages the worklist and contains helper methods for /// rewriting ops on the worklist. Derived classes specify how ops are added /// to the worklist in the beginning.
-class GreedyPatternRewriteDriver : public PatternRewriter { +class GreedyPatternRewriteDriver : public PatternRewriter, + public RewriterBase::Listener { protected: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, @@ -129,6 +130,9 @@ // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); + + // Set up listener. + setListener(this); } bool GreedyPatternRewriteDriver::processWorklist() {