diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -40,30 +40,29 @@ /// Extends the MLIR OpBuilder to provide methods for building common FIR /// patterns. -class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener { +class FirOpBuilder : public mlir::OpBuilder, public mlir::RewriteListener { public: explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap) : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {} explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap) - : OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} { + : OpBuilder(builder), RewriteListener(), kindMap{std::move(kindMap)} { setListener(this); } explicit FirOpBuilder(mlir::OpBuilder &builder, mlir::ModuleOp mod) - : OpBuilder(builder), OpBuilder::Listener(), - kindMap{getKindMapping(mod)} { + : OpBuilder(builder), RewriteListener(), kindMap{getKindMapping(mod)} { setListener(this); } // The listener self-reference has to be updated in case of copy-construction. 
FirOpBuilder(const FirOpBuilder &other) - : OpBuilder(other), OpBuilder::Listener(), kindMap{other.kindMap}, + : OpBuilder(other), RewriteListener(), kindMap{other.kindMap}, fastMathFlags{other.fastMathFlags} { setListener(this); } FirOpBuilder(FirOpBuilder &&other) - : OpBuilder(other), OpBuilder::Listener(), - kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} { + : OpBuilder(other), RewriteListener(), kindMap{std::move(other.kindMap)}, + fastMathFlags{other.fastMathFlags} { setListener(this); } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -576,7 +576,7 @@ /// listeners. This is required when a pattern uses a firBuilder helper that /// may create illegal operations that will need to be translated and requires /// notifying the rewriter. -struct HLFIRListener : public mlir::OpBuilder::Listener { +struct HLFIRListener : public mlir::RewriteListener { HLFIRListener(fir::FirOpBuilder &builder, mlir::ConversionPatternRewriter &rewriter) : builder{builder}, rewriter{rewriter} {} diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -787,12 +787,12 @@ ]; let extraClassDeclaration = [{ - OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) { + OpBuilder getThenBodyBuilder(RewriteListener *listener = nullptr) { Block* body = getBody(0); return getResults().empty() ? OpBuilder::atBlockTerminator(body, listener) : OpBuilder::atBlockEnd(body, listener); } - OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) { + OpBuilder getElseBodyBuilder(RewriteListener *listener = nullptr) { Block* body = getBody(1); return getResults().empty() ? 
OpBuilder::atBlockTerminator(body, listener) : OpBuilder::atBlockEnd(body, listener); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -846,7 +846,7 @@ /// A listener that updates a TransformState based on IR modifications. This /// listener can be used during a greedy pattern rewrite to keep the transform /// state up-to-date. -class TrackingListener : public RewriterBase::Listener, +class TrackingListener : public RewriteListener, public TransformState::Extension { public: /// Create a new TrackingListener for usage in the specified transform op. @@ -929,7 +929,7 @@ void notifyOperationRemoved(Operation *op) override; void notifyOperationReplaced(Operation *op, ValueRange newValues) override; - using Listener::notifyOperationReplaced; + using RewriteListener::notifyOperationReplaced; /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILDERS_H #define MLIR_IR_BUILDERS_H +#include "mlir/IR/Listeners.h" #include "mlir/IR/OpDefinition.h" #include "llvm/Support/Compiler.h" #include <optional> @@ -202,104 +203,61 @@ /// automatically inserted at an insertion point. The builder is copyable. class OpBuilder : public Builder { public: - struct Listener; - /// Create a builder with the given context. - explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr) + explicit OpBuilder(MLIRContext *ctx, RewriteListener *listener = nullptr) : Builder(ctx), listener(listener) {} /// Create a builder and set the insertion point to the start of the region.
- explicit OpBuilder(Region *region, Listener *listener = nullptr) + explicit OpBuilder(Region *region, RewriteListener *listener = nullptr) : OpBuilder(region->getContext(), listener) { if (!region->empty()) setInsertionPoint(&region->front(), region->front().begin()); } - explicit OpBuilder(Region &region, Listener *listener = nullptr) + explicit OpBuilder(Region &region, RewriteListener *listener = nullptr) : OpBuilder(&region, listener) {} /// Create a builder and set insertion point to the given operation, which /// will cause subsequent insertions to go right before it. - explicit OpBuilder(Operation *op, Listener *listener = nullptr) + explicit OpBuilder(Operation *op, RewriteListener *listener = nullptr) : OpBuilder(op->getContext(), listener) { setInsertionPoint(op); } OpBuilder(Block *block, Block::iterator insertPoint, - Listener *listener = nullptr) + RewriteListener *listener = nullptr) : OpBuilder(block->getParent()->getContext(), listener) { setInsertionPoint(block, insertPoint); } /// Create a builder and set the insertion point to before the first operation /// in the block but still inside the block. - static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) { + static OpBuilder atBlockBegin(Block *block, + RewriteListener *listener = nullptr) { return OpBuilder(block, block->begin(), listener); } /// Create a builder and set the insertion point to after the last operation /// in the block but still inside the block. - static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) { + static OpBuilder atBlockEnd(Block *block, + RewriteListener *listener = nullptr) { return OpBuilder(block, block->end(), listener); } /// Create a builder and set the insertion point to before the block /// terminator.
static OpBuilder atBlockTerminator(Block *block, - Listener *listener = nullptr) { + RewriteListener *listener = nullptr) { auto *terminator = block->getTerminator(); assert(terminator != nullptr && "the block has no terminator"); return OpBuilder(block, Block::iterator(terminator), listener); } - //===--------------------------------------------------------------------===// - // Listeners - //===--------------------------------------------------------------------===// - - /// Base class for listeners. - struct ListenerBase { - /// The kind of listener. - enum class Kind { - /// OpBuilder::Listener or user-derived class. - OpBuilderListener = 0, - - /// RewriterBase::Listener or user-derived class. - RewriterBaseListener = 1 - }; - - Kind getKind() const { return kind; } - - protected: - ListenerBase(Kind kind) : kind(kind) {} - - private: - const Kind kind; - }; - - /// This class represents a listener that may be used to hook into various - /// actions within an OpBuilder. - struct Listener : public ListenerBase { - Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {} - - virtual ~Listener() = default; - - /// Notification handler for when an operation is inserted into the builder. - /// `op` is the operation that was inserted. - virtual void notifyOperationInserted(Operation *op) {} - - /// Notification handler for when a block is created using the builder. - /// `block` is the block that was created. - virtual void notifyBlockCreated(Block *block) {} - - protected: - Listener(Kind kind) : ListenerBase(kind) {} - }; - /// Sets the listener of this builder to the one provided. - void setListener(Listener *newListener) { listener = newListener; } + void setListener(RewriteListener *newListener) { listener = newListener; } /// Returns the current listener of this builder, or nullptr if this builder /// doesn't have a listener. 
- Listener *getListener() const { return listener; } + RewriteListener *getListener() const { return listener; } //===--------------------------------------------------------------------===// // Insertion Point Management @@ -566,7 +524,7 @@ protected: /// The optional listener for events of this builder. - Listener *listener; + RewriteListener *listener = nullptr; private: /// The current block this builder is inserting into. diff --git a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h --- a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h +++ b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h @@ -30,22 +30,24 @@ /// Create a builder and set the insertion point to before the first operation /// in the block but still inside the block. - static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, - Listener *listener = nullptr) { + static ImplicitLocOpBuilder + atBlockBegin(Location loc, Block *block, + RewriteListener *listener = nullptr) { return ImplicitLocOpBuilder(loc, block, block->begin(), listener); } /// Create a builder and set the insertion point to after the last operation /// in the block but still inside the block. static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, - Listener *listener = nullptr) { + RewriteListener *listener = nullptr) { return ImplicitLocOpBuilder(loc, block, block->end(), listener); } /// Create a builder and set the insertion point to before the block /// terminator. 
- static ImplicitLocOpBuilder atBlockTerminator(Location loc, Block *block, - Listener *listener = nullptr) { + static ImplicitLocOpBuilder + atBlockTerminator(Location loc, Block *block, + RewriteListener *listener = nullptr) { auto *terminator = block->getTerminator(); assert(terminator != nullptr && "the block has no terminator"); return ImplicitLocOpBuilder(loc, block, Block::iterator(terminator), diff --git a/mlir/include/mlir/IR/Listeners.h b/mlir/include/mlir/IR/Listeners.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/Listeners.h @@ -0,0 +1,97 @@ +//===- Listeners.h - Listeners for op builders/rewriters --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_LISTENERS_H +#define MLIR_IR_LISTENERS_H + +#include "mlir/IR/Operation.h" + +namespace mlir { +/// This class represents a listener that may be used to hook into various +/// actions within an OpBuilder or RewriterBase. +struct RewriteListener { + virtual ~RewriteListener() = default; + + /// Notification handler for when an operation is inserted into the builder. + /// `op` is the operation that was inserted. + virtual void notifyOperationInserted(Operation *op) {} + + /// Notification handler for when a block is created using the builder. + /// `block` is the block that was created. + virtual void notifyBlockCreated(Block *block) {} + + /// Notify the listener that the specified operation was modified in-place. + virtual void notifyOperationModified(Operation *op) {} + + /// Notify the listener that the specified operation is about to be replaced + /// with another operation. This is called before the uses of the old + /// operation have been changed. 
+ /// + /// By default, this function calls the "operation replaced with values" + /// notification. + virtual void notifyOperationReplaced(Operation *op, Operation *replacement) { + notifyOperationReplaced(op, replacement->getResults()); + } + + /// Notify the listener that the specified operation is about to be replaced + /// with a range of values, potentially produced by other operations. + /// This is called before the uses of the operation have been changed. + virtual void notifyOperationReplaced(Operation *op, ValueRange replacement) {} + + /// Notify the listener that the specified operation is about to be erased. + /// At this point, the operation has zero uses. + virtual void notifyOperationRemoved(Operation *op) {} + + /// Notify the listener that the pattern failed to match the given + /// operation, and provide a callback to populate a diagnostic with the + /// reason why the failure occurred. This method allows for derived + /// listeners to optionally hook into the reason why a rewrite failed, and + /// display it to users. + virtual LogicalResult + notifyMatchFailure(Location loc, + function_ref<void(Diagnostic &)> reasonCallback) { + return failure(); + } +}; + +/// A listener that forwards all notifications to another listener. This +/// struct can be used as a base to create listener chains, so that multiple +/// listeners can be notified of IR changes.
+struct ForwardingRewriteListener : public RewriteListener { + ForwardingRewriteListener(RewriteListener *listener) : listener(listener) {} + + void notifyOperationInserted(Operation *op) override { + listener->notifyOperationInserted(op); + } + void notifyBlockCreated(Block *block) override { + listener->notifyBlockCreated(block); + } + void notifyOperationModified(Operation *op) override { + listener->notifyOperationModified(op); + } + void notifyOperationReplaced(Operation *op, Operation *newOp) override { + listener->notifyOperationReplaced(op, newOp); + } + void notifyOperationReplaced(Operation *op, ValueRange replacement) override { + listener->notifyOperationReplaced(op, replacement); + } + void notifyOperationRemoved(Operation *op) override { + listener->notifyOperationRemoved(op); + } + LogicalResult + notifyMatchFailure(Location loc, + function_ref<void(Diagnostic &)> reasonCallback) override { + return listener->notifyMatchFailure(loc, reasonCallback); + } + +private: + RewriteListener *listener; +}; +} // namespace mlir + +#endif // MLIR_IR_LISTENERS_H diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -398,89 +398,6 @@ /// IR transformation utilities. class RewriterBase : public OpBuilder { public: - struct Listener : public OpBuilder::Listener { - Listener() - : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {} - - /// Notify the listener that the specified operation was modified in-place. - virtual void notifyOperationModified(Operation *op) {} - - /// Notify the listener that the specified operation is about to be replaced - /// with another operation. This is called before the uses of the old - /// operation have been changed. - /// - /// By default, this function calls the "operation replaced with values" - /// notification.
- virtual void notifyOperationReplaced(Operation *op, - Operation *replacement) { - notifyOperationReplaced(op, replacement->getResults()); - } - - /// Notify the listener that the specified operation is about to be replaced - /// with the a range of values, potentially produced by other operations. - /// This is called before the uses of the operation have been changed. - virtual void notifyOperationReplaced(Operation *op, - ValueRange replacement) {} - - /// Notify the listener that the specified operation is about to be erased. - /// At this point, the operation has zero uses. - virtual void notifyOperationRemoved(Operation *op) {} - - /// Notify the listener that the pattern failed to match the given - /// operation, and provide a callback to populate a diagnostic with the - /// reason why the failure occurred. This method allows for derived - /// listeners to optionally hook into the reason why a rewrite failed, and - /// display it to users. - virtual LogicalResult - notifyMatchFailure(Location loc, - function_ref<void(Diagnostic &)> reasonCallback) { - return failure(); - } - - static bool classof(const OpBuilder::Listener *base); - }; - - /// A listener that forwards all notifications to another listener. This - /// struct can be used as a base to create listener chains, so that multiple - /// listeners can be notified of IR changes.
- struct ForwardingListener : public RewriterBase::Listener { - ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {} - - void notifyOperationInserted(Operation *op) override { - listener->notifyOperationInserted(op); - } - void notifyBlockCreated(Block *block) override { - listener->notifyBlockCreated(block); - } - void notifyOperationModified(Operation *op) override { - if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationModified(op); - } - void notifyOperationReplaced(Operation *op, Operation *newOp) override { - if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationReplaced(op, newOp); - } - void notifyOperationReplaced(Operation *op, - ValueRange replacement) override { - if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationReplaced(op, replacement); - } - void notifyOperationRemoved(Operation *op) override { - if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationRemoved(op); - } - LogicalResult notifyMatchFailure( - Location loc, - function_ref<void(Diagnostic &)> reasonCallback) override { - if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) - return rewriteListener->notifyMatchFailure(loc, reasonCallback); - return failure(); - } - - private: - OpBuilder::Listener *listener; - }; - /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different.
The caller /// is responsible for creating or updating the operation transferring flow @@ -659,8 +576,8 @@ std::enable_if_t::value, LogicalResult> notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { #ifndef NDEBUG - if (auto *rewriteListener = dyn_cast_if_present(listener)) - return rewriteListener->notifyMatchFailure( + if (listener) + return listener->notifyMatchFailure( loc, function_ref(reasonCallback)); return failure(); #else @@ -670,8 +587,8 @@ template std::enable_if_t::value, LogicalResult> notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { - if (auto *rewriteListener = dyn_cast_if_present(listener)) - return rewriteListener->notifyMatchFailure( + if (listener) + return listener->notifyMatchFailure( op->getLoc(), function_ref(reasonCallback)); return failure(); } @@ -687,8 +604,7 @@ protected: /// Initialize the builder. - explicit RewriterBase(MLIRContext *ctx, - OpBuilder::Listener *listener = nullptr) + explicit RewriterBase(MLIRContext *ctx, RewriteListener *listener = nullptr) : OpBuilder(ctx, listener) {} explicit RewriterBase(const OpBuilder &otherBuilder) : OpBuilder(otherBuilder) {} @@ -709,7 +625,7 @@ /// such as a `PatternRewriter`, is not available. class IRRewriter : public RewriterBase { public: - explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) + explicit IRRewriter(MLIRContext *ctx, RewriteListener *listener = nullptr) : RewriterBase(ctx, listener) {} explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -637,7 +637,7 @@ /// extends the base PatternRewriter and provides special conversion specific /// hooks. 
class ConversionPatternRewriter final : public PatternRewriter, - public RewriterBase::Listener { + public RewriteListener { public: explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -32,7 +32,7 @@ /// generated along the way. class OperationFolder { public: - OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) + OperationFolder(MLIRContext *ctx, RewriteListener *listener = nullptr) : interfaces(ctx), rewriter(ctx, listener) {} /// Tries to perform folding on the given `op`, including unifying diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -78,7 +78,7 @@ GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; /// An optional listener that should be notified about IR modifications. 
- RewriterBase::Listener *listener = nullptr; + RewriteListener *listener = nullptr; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -152,7 +152,7 @@ class OneToNPatternRewriter : public PatternRewriter { public: OneToNPatternRewriter(MLIRContext *context, - OpBuilder::Listener *listener = nullptr) + RewriteListener *listener = nullptr) : PatternRewriter(context, listener) {} /// Replaces the results of the operation with the specified list of values diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1297,7 +1297,7 @@ SmallVector foldResults; if (failed(applyOp->fold(constOperands, foldResults)) || foldResults.empty()) { - if (OpBuilder::Listener *listener = b.getListener()) + if (RewriteListener *listener = b.getListener()) listener->notifyOperationInserted(applyOp); return applyOp.getResult(); } @@ -1363,7 +1363,7 @@ SmallVector foldResults; if (failed(minMaxOp->fold(constOperands, foldResults)) || foldResults.empty()) { - if (OpBuilder::Listener *listener = b.getListener()) + if (RewriteListener *listener = b.getListener()) listener->notifyOperationInserted(minMaxOp); return minMaxOp.getResult(); } diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -128,8 +128,7 @@ SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; - config.listener = - 
static_cast(rewriter.getListener()); + config.listener = rewriter.getListener(); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -343,7 +343,7 @@ namespace { /// A rewriter that keeps track of extra information during bufferization. -class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { +class BufferizationRewriter : public IRRewriter, public RewriteListener { public: BufferizationRewriter(MLIRContext *ctx, DenseSet &erasedOps, DenseSet &toMemrefOps, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -200,9 +200,9 @@ } namespace { -class NewOpsListener : public RewriterBase::ForwardingListener { +class NewOpsListener : public ForwardingRewriteListener { public: - using RewriterBase::ForwardingListener::ForwardingListener; + using ForwardingRewriteListener::ForwardingRewriteListener; SmallVector getNewOps() const { return SmallVector(newOps.begin(), newOps.end()); @@ -210,14 +210,14 @@ private: void notifyOperationInserted(Operation *op) override { - ForwardingListener::notifyOperationInserted(op); + ForwardingRewriteListener::notifyOperationInserted(op); auto inserted = newOps.insert(op); (void)inserted; assert(inserted.second && "expected newly created op"); } void notifyOperationRemoved(Operation *op) override { - ForwardingListener::notifyOperationRemoved(op); + ForwardingRewriteListener::notifyOperationRemoved(op); 
op->walk([&](Operation *op) { newOps.erase(op); }); } @@ -229,7 +229,7 @@ transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { // Attach listener to keep track of newly created ops. - OpBuilder::Listener *previousListener = rewriter.getListener(); + RewriteListener *previousListener = rewriter.getListener(); auto resetListener = llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); NewOpsListener newOpsListener(previousListener); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -315,8 +315,7 @@ // Configure the GreedyPatternRewriteDriver. GreedyRewriteConfig config; - config.listener = - static_cast(rewriter.getListener()); + config.listener = rewriter.getListener(); LogicalResult result = failure(); if (target->hasTrait()) { diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -217,10 +217,6 @@ // RewriterBase //===----------------------------------------------------------------------===// -bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) { - return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener; -} - RewriterBase::~RewriterBase() { // Out of line to provide a vtable anchor for the class. } @@ -236,8 +232,8 @@ "incorrect number of values to replace operation"); // Notify the listener that we're about to replace this op. - if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationReplaced(op, newValues); + if (listener) + listener->notifyOperationReplaced(op, newValues); // Replace each use of the results when the functor is true. 
bool replacedAllUses = true; @@ -268,8 +264,8 @@ "incorrect # of replacement values"); // Notify the listener that we're about to replace this op. - if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationReplaced(op, newValues); + if (listener) + listener->notifyOperationReplaced(op, newValues); // Replace results one-by-one. Also notifies the listener of modifications. for (auto it : llvm::zip(op->getResults(), newValues)) @@ -288,8 +284,8 @@ "ops have different number of results"); // Notify the listener that we're about to replace this op. - if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationReplaced(op, newOp); + if (listener) + listener->notifyOperationReplaced(op, newOp); // Replace results one-by-one. Also notifies the listener of modifications. for (auto it : llvm::zip(op->getResults(), newOp->getResults())) @@ -303,8 +299,8 @@ /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); - if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationRemoved(op); + if (listener) + listener->notifyOperationRemoved(op); op->erase(); } @@ -318,8 +314,8 @@ void RewriterBase::finalizeRootUpdate(Operation *op) { // Notify the listener that the operation was modified. - if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(listener)) - rewriteListener->notifyOperationModified(op); + if (listener) + listener->notifyOperationModified(op); } /// Find uses of `from` and replace them with `to` if the `functor` returns diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -137,8 +137,7 @@ if (hasSSADominance) { // If the region has SSA dominance, then we are guaranteed to have not // visited any use of the current operation.
- if (auto *rewriteListener = - dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener())) + if (RewriteListener *rewriteListener = rewriter.getListener()) rewriteListener->notifyOperationReplaced(op, existing); // Replace all uses, but do not remote the operation yet. This does not // notify the listener because the original op is not erased. @@ -150,8 +149,7 @@ auto wasVisited = [&](OpOperand &operand) { return !knownValues.count(operand.getOwner()); }; - if (auto *rewriteListener = - dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener())) + if (RewriteListener *rewriteListener = rewriter.getListener()) for (Value v : op->getResults()) if (all_of(v.getUses(), wasVisited)) rewriteListener->notifyOperationReplaced(op, existing); diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -91,10 +91,9 @@ if (results.empty()) { if (inPlaceUpdate) *inPlaceUpdate = true; - if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>( - rewriter.getListener())) { + if (RewriteListener *listener = rewriter.getListener()) { // Folding API does not notify listeners, so we have to notify manually. - rewriteListener->notifyOperationModified(op); + listener->notifyOperationModified(op); } return success(); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -45,9 +45,9 @@ /// A helper struct that stores finger prints of ops in order to detect broken /// RewritePatterns. A rewrite pattern is broken if it modifies IR without /// using the rewriter API or if it returns an inconsistent return value.
-struct DebugFingerPrints : public RewriterBase::ForwardingListener { - DebugFingerPrints(RewriterBase::Listener *driver) - : RewriterBase::ForwardingListener(driver) {} +struct DebugFingerPrints : public ForwardingRewriteListener { + DebugFingerPrints(RewriteListener *driver) + : ForwardingRewriteListener(driver) {} /// Compute finger prints of the given op and its nested ops. void computeFingerPrints(Operation *topLevel) { @@ -271,7 +271,7 @@ /// rewriting ops on the worklist. Derived classes specify how ops are added /// to the worklist in the beginning. class GreedyPatternRewriteDriver : public PatternRewriter, - public RewriterBase::Listener { + public RewriteListener { protected: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -14,7 +14,7 @@ namespace { /// Simple constant folding pass. struct TestConstantFold : public PassWrapper>, - public RewriterBase::Listener { + public RewriteListener { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold) StringRef getArgument() const final { return "test-constant-fold"; }