diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -171,49 +171,84 @@ /// automatically inserted at an insertion point. The builder is copyable. class OpBuilder : public Builder { public: + struct Listener; + /// Create a builder with the given context. - explicit OpBuilder(MLIRContext *ctx) : Builder(ctx) {} + explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr) + : Builder(ctx), listener(listener) {} /// Create a builder and set the insertion point to the start of the region. - explicit OpBuilder(Region *region) : Builder(region->getContext()) { + explicit OpBuilder(Region *region, Listener *listener = nullptr) + : OpBuilder(region->getContext(), listener) { if (!region->empty()) setInsertionPoint(&region->front(), region->front().begin()); } - explicit OpBuilder(Region &region) : OpBuilder(&region) {} - - virtual ~OpBuilder(); + explicit OpBuilder(Region &region, Listener *listener = nullptr) + : OpBuilder(&region, listener) {} /// Create a builder and set insertion point to the given operation, which /// will cause subsequent insertions to go right before it. - explicit OpBuilder(Operation *op) : Builder(op->getContext()) { + explicit OpBuilder(Operation *op, Listener *listener = nullptr) + : OpBuilder(op->getContext(), listener) { setInsertionPoint(op); } - OpBuilder(Block *block, Block::iterator insertPoint) - : OpBuilder(block->getParent()) { + OpBuilder(Block *block, Block::iterator insertPoint, + Listener *listener = nullptr) + : OpBuilder(block->getParent()->getContext(), listener) { setInsertionPoint(block, insertPoint); } /// Create a builder and set the insertion point to before the first operation /// in the block but still inside the block.
- static OpBuilder atBlockBegin(Block *block) { - return OpBuilder(block, block->begin()); + static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) { + return OpBuilder(block, block->begin(), listener); } /// Create a builder and set the insertion point to after the last operation /// in the block but still inside the block. - static OpBuilder atBlockEnd(Block *block) { - return OpBuilder(block, block->end()); + static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) { + return OpBuilder(block, block->end(), listener); } /// Create a builder and set the insertion point to before the block /// terminator. - static OpBuilder atBlockTerminator(Block *block) { + static OpBuilder atBlockTerminator(Block *block, + Listener *listener = nullptr) { auto *terminator = block->getTerminator(); assert(terminator != nullptr && "the block has no terminator"); - return OpBuilder(block, terminator->getIterator()); + return OpBuilder(block, Block::iterator(terminator), listener); } + //===--------------------------------------------------------------------===// + // Listeners + //===--------------------------------------------------------------------===// + + /// This class represents a listener that may be used to hook into various + /// actions within an OpBuilder. + struct Listener { + virtual ~Listener(); + + /// Notification handler for when an operation is inserted into the builder. + /// `op` is the operation that was inserted. + virtual void notifyOperationInserted(Operation *op) {} + + /// Notification handler for when a block is created using the builder. + /// `block` is the block that was created. + virtual void notifyBlockCreated(Block *block) {} + }; + + /// Sets the listener of this builder; pass nullptr to remove it. + void setListener(Listener *newListener) { listener = newListener; } + + /// Returns the current listener of this builder, or nullptr if this builder /// doesn't have a listener.
+ Listener *getListener() const { return listener; } + + //===--------------------------------------------------------------------===// + // Insertion Point Management + //===--------------------------------------------------------------------===// + /// This class represents a saved insertion point. class InsertPoint { public: @@ -304,21 +339,29 @@ /// Returns the current insertion point of the builder. Block::iterator getInsertionPoint() const { return insertPoint; } - /// Insert the given operation at the current insertion point and return it. - virtual Operation *insert(Operation *op); + /// Returns the current block of the builder, or nullptr if none is set. + Block *getBlock() const { return block; } + + //===--------------------------------------------------------------------===// + // Block Creation + //===--------------------------------------------------------------------===// /// Add new block with 'argTypes' arguments and set the insertion point to the /// end of it. The block is inserted at the provided insertion point of /// 'parent'. - virtual Block *createBlock(Region *parent, Region::iterator insertPt = {}, - TypeRange argTypes = llvm::None); + Block *createBlock(Region *parent, Region::iterator insertPt = {}, + TypeRange argTypes = llvm::None); /// Add new block with 'argTypes' arguments and set the insertion point to the /// end of it. The block is placed before 'insertBefore'. Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None); - /// Returns the current block of the builder. - Block *getBlock() const { return block; } + //===--------------------------------------------------------------------===// + // Operation Creation + //===--------------------------------------------------------------------===// + + /// Insert the given operation at the current insertion point and return it. + Operation *insert(Operation *op); /// Creates an operation given the fields represented as an OperationState.
Operation *createOperation(const OperationState &state); @@ -406,8 +449,13 @@ } private: + /// The current block this builder is inserting into. Block *block = nullptr; + /// The insertion point within the block that this builder is inserting + /// before. Block::iterator insertPoint; + /// The optional listener for events of this builder, or nullptr if none. + Listener *listener = nullptr; }; } // namespace mlir diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -211,7 +211,7 @@ /// to apply patterns and observe their effects (e.g. to keep worklists or /// other data structures up to date). /// -class PatternRewriter : public OpBuilder { +class PatternRewriter : public OpBuilder, public OpBuilder::Listener { public: /// Create operation of specific op type at the current insertion point /// without verifying to see if it is valid. @@ -247,10 +247,6 @@ return OpTy(); } - /// This is implemented to insert the specified operation and serves as a - /// notification hook for rewriters that want to know about new operations. - virtual Operation *insert(Operation *op) = 0; - /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. The caller /// is responsible for creating or updating the operation transferring flow @@ -349,11 +345,13 @@ } protected: - explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {} - virtual ~PatternRewriter(); + /// Initialize the builder with this rewriter as the listener. + explicit PatternRewriter(MLIRContext *ctx) + : OpBuilder(ctx, /*listener=*/this) {} + ~PatternRewriter() override; - // These are the callback methods that subclasses can choose to implement if - // they would like to be notified about certain types of mutations.
+ // These are the callback methods that subclasses can choose to implement if + // they would like to be notified about certain types of mutations. /// Notify the pattern rewriter that the specified operation is about to be /// replaced with another set of operations. This is called before the uses diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -348,9 +348,8 @@ /// implemented for dialect conversion. void eraseBlock(Block *block) override; - /// PatternRewriter hook for creating a new block with the given arguments. - Block *createBlock(Region *parent, Region::iterator insertPt = {}, - TypeRange argTypes = llvm::None) override; + /// PatternRewriter hook for notifying about a newly created block. + void notifyBlockCreated(Block *block) override; /// PatternRewriter hook for splitting a block into two parts. Block *splitBlock(Block *block, Block::iterator before) override; @@ -373,7 +372,7 @@ using PatternRewriter::cloneRegionBefore; /// PatternRewriter hook for inserting a new operation. - Operation *insert(Operation *op) override; + void notifyOperationInserted(Operation *op) override; /// PatternRewriter hook for updating the root operation in-place. /// Note: These methods only track updates to the top-level operation itself, diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -330,15 +330,18 @@ } //===----------------------------------------------------------------------===// -// OpBuilder. +// OpBuilder //===----------------------------------------------------------------------===// -OpBuilder::~OpBuilder() {} +OpBuilder::Listener::~Listener() {} /// Insert the given operation at the current insertion point and return it.
Operation *OpBuilder::insert(Operation *op) { - if (block) + if (block) { block->getOperations().insert(insertPoint, op); + if (listener) + listener->notifyOperationInserted(op); + } return op; } @@ -355,6 +358,9 @@ b->addArguments(argTypes); parent->getBlocks().insert(insertPt, b); setInsertionPointToEnd(b); + + if (listener) + listener->notifyBlockCreated(b); return b; } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -954,12 +954,8 @@ } -/// PatternRewriter hook for creating a new block with the given arguments. -Block *ConversionPatternRewriter::createBlock(Region *parent, - Region::iterator insertPtr, - TypeRange argTypes) { - Block *block = PatternRewriter::createBlock(parent, insertPtr, argTypes); +/// PatternRewriter hook for notifying about a newly created block. +void ConversionPatternRewriter::notifyBlockCreated(Block *block) { impl->notifyCreatedBlock(block); - return block; } /// PatternRewriter hook for splitting a block into two parts. @@ -1001,13 +997,12 @@ } -/// PatternRewriter hook for creating a new operation. -Operation *ConversionPatternRewriter::insert(Operation *op) { +/// PatternRewriter hook for notifying about a newly inserted operation. +void ConversionPatternRewriter::notifyOperationInserted(Operation *op) { LLVM_DEBUG({ impl->logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); impl->createdOps.push_back(op); - return OpBuilder::insert(op); } /// PatternRewriter hook for updating the root operation in-place. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -77,10 +77,7 @@ protected: // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing.
- Operation *insert(Operation *op) override { - addToWorklist(op); - return OpBuilder::insert(op); - } + void notifyOperationInserted(Operation *op) override { addToWorklist(op); } // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. @@ -266,9 +263,6 @@ bool simplifyLocally(Operation *op, int maxIterations, bool &erased); - /// No additional action needed other than inserting the op. - Operation *insert(Operation *op) override { return OpBuilder::insert(op); } - // These are hooks implemented for PatternRewriter. protected: /// If an operation is about to be removed, mark it so that we can let clients