NOTE(review): this patch was re-flowed from a whitespace-collapsed copy in which
text inside angle brackets had been lost. Template arguments were restored from
the surrounding code. The `#include <optional>` context line, blank context
lines and leading indentation are assumptions; confirm against the target tree.

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Listeners.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Compiler.h"
 #include <optional>
 
@@ -252,8 +253,28 @@
     return OpBuilder(block, Block::iterator(terminator), listener);
   }
 
-  /// Sets the listener of this builder to the one provided.
-  void setListener(RewriteListener *newListener) { listener = newListener; }
+  /// Attach the given listener to this builder and return an object that resets
+  /// the listener to the original state when the object is destructed at scope
+  /// exit. While attached, the builder's listener is a forwarding listener that
+  /// is constructed from `newListener` and the previous listener (if any). The
+  /// returned object owns that forwarding listener: keep it alive for as long
+  /// as `newListener` should stay attached, and destroy it before this builder.
+  [[nodiscard]] auto attachScopedListener(RewriteListener *newListener) {
+    assert(newListener && "expected non-null listener");
+    RewriteListener *previousListener = getListener();
+    SmallVector<RewriteListener *> delegatedListeners({newListener});
+    if (previousListener)
+      delegatedListeners.push_back(previousListener);
+    // Forwarding listener is deleted when the returned object is destructed.
+    auto forwardingListener =
+        std::make_unique<ForwardingRewriteListener>(delegatedListeners);
+    setListener(forwardingListener.get());
+    return llvm::make_scope_exit(
+        [this, previousListener, owner = std::move(forwardingListener)]() {
+          // Reset the listener to the original value.
+          this->setListener(previousListener);
+        });
+  }
 
   /// Returns the current listener of this builder, or nullptr if this builder
   /// doesn't have a listener.
@@ -523,6 +544,9 @@
   }
 
 protected:
+  /// Sets the listener of this builder to the one provided.
+  void setListener(RewriteListener *newListener) { listener = newListener; }
+
   /// The optional listener for events of this builder.
   RewriteListener *listener = nullptr;
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -200,9 +200,9 @@
 }
 
 namespace {
-class NewOpsListener : public ForwardingRewriteListener {
+class NewOpsListener : public RewriteListener {
 public:
-  using ForwardingRewriteListener::ForwardingRewriteListener;
+  using RewriteListener::RewriteListener;
 
   SmallVector<Operation *> getNewOps() const {
     return SmallVector<Operation *>(newOps.begin(), newOps.end());
@@ -210,14 +210,14 @@
 
 private:
   void notifyOperationInserted(Operation *op) override {
-    ForwardingRewriteListener::notifyOperationInserted(op);
+    RewriteListener::notifyOperationInserted(op);
     auto inserted = newOps.insert(op);
     (void)inserted;
     assert(inserted.second && "expected newly created op");
   }
 
   void notifyOperationRemoved(Operation *op) override {
-    ForwardingRewriteListener::notifyOperationRemoved(op);
+    RewriteListener::notifyOperationRemoved(op);
     op->walk([&](Operation *op) { newOps.erase(op); });
   }
 
@@ -229,11 +229,8 @@
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
   // Attach listener to keep track of newly created ops.
-  RewriteListener *previousListener = rewriter.getListener();
-  auto resetListener =
-      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
-  NewOpsListener newOpsListener(previousListener);
-  rewriter.setListener(&newOpsListener);
+  NewOpsListener newOpsListener;
+  auto resetListener = rewriter.attachScopedListener(&newOpsListener);
 
   linalg::BufferizeToAllocationOptions options;
   if (getMemcpyOp() == "memref.tensor_store") {