diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -253,10 +253,32 @@ // Listeners //===--------------------------------------------------------------------===// + /// Base class for listeners. The kind tag enables dyn_cast on listeners. + struct ListenerBase { + /// The kind of listener. + enum class Kind { + /// OpBuilder::Listener or user-derived class. + OpBuilderListener = 0, + + /// RewriterBase::Listener or user-derived class. + RewriterBaseListener = 1 + }; + + Kind getKind() const { return kind; } + + protected: + ListenerBase(Kind kind) : kind(kind) {} + + private: + const Kind kind; + }; + /// This class represents a listener that may be used to hook into various /// actions within an OpBuilder. - struct Listener { - virtual ~Listener(); + struct Listener : public ListenerBase { + Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {} + + virtual ~Listener() = default; /// Notification handler for when an operation is inserted into the builder. /// `op` is the operation that was inserted. @@ -265,6 +287,9 @@ /// Notification handler for when a block is created using the builder. /// `block` is the block that was created. virtual void notifyBlockCreated(Block *block) {} + + protected: + Listener(Kind kind) : ListenerBase(kind) {} }; /// Sets the listener of this builder to the one provided. @@ -537,14 +562,16 @@ return cast(cloneWithoutRegions(*op.getOperation())); } +protected: + /// The optional listener for events of this builder; readable by subclasses. + Listener *listener; + private: /// The current block this builder is inserting into. Block *block = nullptr; /// The insertion point within the block that this builder is inserting /// before. Block::iterator insertPoint; - /// The optional listener for events of this builder. 
- Listener *listener; }; } // namespace mlir diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -396,8 +396,36 @@ /// This class serves as a common API for IR mutation between pattern rewrites /// and non-pattern rewrites, and facilitates the development of shared /// IR transformation utilities. -class RewriterBase : public OpBuilder, public OpBuilder::Listener { +class RewriterBase : public OpBuilder { public: + struct Listener : public OpBuilder::Listener { + Listener() + : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {} + + /// Notify the listener that the specified operation is about to be replaced + /// with the set of values potentially produced by new operations. This is + /// called before the uses of the operation have been changed. + virtual void notifyOperationReplaced(Operation *op, + ValueRange replacement) {} + + /// This is called on an operation that a rewrite is removing, right before + /// the operation is deleted. At this point, the operation has zero uses. + virtual void notifyOperationRemoved(Operation *op) {} + + /// Notify the listener that the pattern failed to match the given + /// operation, and provide a callback to populate a diagnostic with the + /// reason why the failure occurred. This method allows for derived + /// listeners to optionally hook into the reason why a rewrite failed, and + /// display it to users. + virtual LogicalResult + notifyMatchFailure(Location loc, + function_ref reasonCallback) { + return failure(); + } + + static bool classof(const OpBuilder::Listener *base); + }; + /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. 
The caller /// is responsible for creating or updating the operation transferring flow @@ -541,8 +569,10 @@ std::enable_if_t::value, LogicalResult> notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { #ifndef NDEBUG - return notifyMatchFailure(loc, - function_ref(reasonCallback)); + if (auto *rewriteListener = dyn_cast_if_present(listener)) + return rewriteListener->notifyMatchFailure( + loc, function_ref(reasonCallback)); + return failure(); #else return failure(); #endif @@ -550,8 +580,10 @@ template std::enable_if_t::value, LogicalResult> notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { - return notifyMatchFailure(op->getLoc(), - function_ref(reasonCallback)); + if (auto *rewriteListener = dyn_cast_if_present(listener)) + return rewriteListener->notifyMatchFailure( + op->getLoc(), function_ref(reasonCallback)); + return failure(); } template LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) { @@ -564,35 +596,11 @@ } protected: - /// Initialize the builder with this rewriter as the listener. - explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {} + /// No listener is passed to OpBuilder; subclasses call setListener() as needed. + explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx) {} explicit RewriterBase(const OpBuilder &otherBuilder) - : OpBuilder(otherBuilder) { - setListener(this); - } - ~RewriterBase() override; - - /// These are the callback methods that subclasses can choose to implement if - /// they would like to be notified about certain types of mutations. - - /// Notify the rewriter that the specified operation is about to be replaced - /// with the set of values potentially produced by new operations. This is - /// called before the uses of the operation have been changed. - virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {} - - /// This is called on an operation that a rewrite is removing, right before - /// the operation is deleted. At this point, the operation has zero uses. 
- virtual void notifyOperationRemoved(Operation *op) {} - - /// Notify the rewriter that the pattern failed to match the given operation, - /// and provide a callback to populate a diagnostic with the reason why the - /// failure occurred. This method allows for derived rewriters to optionally - /// hook into the reason why a rewrite failed, and display it to users. - virtual LogicalResult - notifyMatchFailure(Location loc, - function_ref reasonCallback) { - return failure(); - } + : OpBuilder(otherBuilder) {} + virtual ~RewriterBase(); private: void operator=(const RewriterBase &) = delete; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -618,7 +618,8 @@ /// This class implements a pattern rewriter for use with ConversionPatterns. It /// extends the base PatternRewriter and provides special conversion specific /// hooks. -class ConversionPatternRewriter final : public PatternRewriter { +class ConversionPatternRewriter final : public PatternRewriter, + public RewriterBase::Listener { public: explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; @@ -742,6 +743,9 @@ detail::ConversionPatternRewriterImpl &getImpl(); private: + using OpBuilder::getListener; + using OpBuilder::setListener; + std::unique_ptr impl; }; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -555,7 +555,7 @@ // Inside regular functions we use the blocking wait operation to wait for // the async object (token, value or group) to become available. 
if (!isInCoroutine) { - ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); + ImplicitLocOpBuilder builder(loc, op, &rewriter); builder.create(loc, operand); // Assert that the awaited operands is not in the error state. @@ -574,7 +574,7 @@ CoroMachinery &coro = funcCoro->getSecond(); Block *suspended = op->getBlock(); - ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); + ImplicitLocOpBuilder builder(loc, op, &rewriter); MLIRContext *ctx = op->getContext(); // Save the coroutine state and resume on a runtime managed thread when diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -342,7 +342,7 @@ namespace { /// A rewriter that keeps track of extra information during bufferization. -class BufferizationRewriter : public IRRewriter { +class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { public: BufferizationRewriter(MLIRContext *ctx, DenseSet &erasedOps, DenseSet &toMemrefOps, @@ -352,18 +352,18 @@ BufferizationStatistics *statistics) : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps), worklist(worklist), analysisState(options), opFilter(opFilter), - statistics(statistics) {} + statistics(statistics) { + setListener(this); + } protected: void notifyOperationRemoved(Operation *op) override { - IRRewriter::notifyOperationRemoved(op); erasedOps.insert(op); // Erase if present. toMemrefOps.erase(op); } void notifyOperationInserted(Operation *op) override { - IRRewriter::notifyOperationInserted(op); erasedOps.erase(op); // Gather statistics about allocs and deallocs. 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -388,8 +388,6 @@ // OpBuilder //===----------------------------------------------------------------------===// -OpBuilder::Listener::~Listener() = default; - /// Insert the given operation at the current insertion point and return it. Operation *OpBuilder::insert(Operation *op) { if (block) diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -217,6 +217,10 @@ // RewriterBase //===----------------------------------------------------------------------===// +bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) { + return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener; +} + RewriterBase::~RewriterBase() { // Out of line to provide a vtable anchor for the class. } @@ -232,7 +236,8 @@ "incorrect number of values to replace operation"); // Notify the rewriter subclass that we're about to replace this root. - notifyRootReplaced(op, newValues); + if (auto *rewriteListener = dyn_cast_if_present(listener)) + rewriteListener->notifyOperationReplaced(op, newValues); // Replace each use of the results when the functor is true. bool replacedAllUses = true; @@ -260,13 +265,15 @@ /// the operation. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { // Notify the rewriter subclass that we're about to replace this root. 
- notifyRootReplaced(op, newValues); + if (auto *rewriteListener = dyn_cast_if_present(listener)) + rewriteListener->notifyOperationReplaced(op, newValues); assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); op->replaceAllUsesWith(newValues); - notifyOperationRemoved(op); + if (auto *rewriteListener = dyn_cast_if_present(listener)) + rewriteListener->notifyOperationRemoved(op); op->erase(); } @@ -274,7 +281,8 @@ /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); - notifyOperationRemoved(op); + if (auto *rewriteListener = dyn_cast_if_present(listener)) + rewriteListener->notifyOperationRemoved(op); op->erase(); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1495,7 +1495,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this)) {} + impl(new detail::ConversionPatternRewriterImpl(*this)) { + setListener(this); +} + ConversionPatternRewriter::~ConversionPatternRewriter() = default; void ConversionPatternRewriter::replaceOpWithIf( diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,7 +39,8 @@ /// This abstract class manages the worklist and contains helper methods for /// rewriting ops on the worklist. Derived classes specify how ops are added /// to the worklist in the beginning. 
-class GreedyPatternRewriteDriver : public PatternRewriter { +class GreedyPatternRewriteDriver : public PatternRewriter, + public RewriterBase::Listener { protected: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, @@ -67,7 +68,7 @@ /// Notify the driver that the specified operation was replaced. Update the /// worklist as needed: New users are added enqueued. - void notifyRootReplaced(Operation *op, ValueRange replacement) override; + void notifyOperationReplaced(Operation *op, ValueRange replacement) override; /// Process ops until the worklist is empty or `config.maxNumRewrites` is /// reached. Return `true` if any IR was changed. @@ -128,6 +129,9 @@ // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); + + // Set up listener: the driver receives its own rewrite notifications. + setListener(this); } bool GreedyPatternRewriteDriver::processWorklist() { @@ -359,8 +363,8 @@ strictModeFilteredOps.erase(op); } -void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op, - ValueRange replacement) { +void GreedyPatternRewriteDriver::notifyOperationReplaced( + Operation *op, ValueRange replacement) { LLVM_DEBUG({ logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n";