diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -36,6 +36,9 @@ /// Return which operand this is in the BlockOperand list of the Operation. unsigned getOperandNumber(); + + /// Notify the listeners of the owner operation (if any) of the new block. + void notifyOperandChanged(Block *newBlock); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Listeners.h b/mlir/include/mlir/IR/Listeners.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/Listeners.h @@ -0,0 +1,94 @@ +//===- Listeners.h - Listener for IR modification ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_LISTENERS_H +#define MLIR_IR_LISTENERS_H + +#include "mlir/IR/BlockSupport.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +class Operation; +class Block; +class BlockOperand; +class OpOperand; + +/// This class attaches to an operation and provides a mechanism to listen to IR +/// modifications.
The listeners are notified when an operation is inserted, +/// detached, destroyed, moved, or when one of its operands changes. +class IRListener : public llvm::RefCountedBase<IRListener> { +public: + virtual ~IRListener() = default; + virtual void attachToOperation(Operation *op) {} + virtual void notifyOpInserted(Operation *op, Block *oldBlock, + Block *newBlock) {} + virtual void notifyOpDetached(Operation *op) {} + virtual void notifyOpDestroyed(Operation *op) {} + virtual void notifyOpMoved(Operation *op) {} + virtual void notifyOpOperandChanged(OpOperand &operand, Value newValue) {} + virtual void notifyBlockOperandChanged(BlockOperand &operand, + Block *newBlock) {} +}; + +/// This class wraps a collection of IRListener and provides a convenient +/// mechanism to dispatch notifications to every listener. +class IRListeners { +public: + void attachToOperation(Operation *op) { + for (auto &listener : listeners) + listener->attachToOperation(op); + } + + void notifyOpInserted(Operation *op, Block *oldBlock, Block *newBlock) { + for (auto &listener : listeners) + listener->notifyOpInserted(op, oldBlock, newBlock); + } + void notifyOpDestroyed(Operation *op) { + for (auto &listener : listeners) + listener->notifyOpDestroyed(op); + } + void notifyOpDetached(Operation *op) { + for (auto &listener : listeners) + listener->notifyOpDetached(op); + } + void notifyOpMoved(Operation *op) { + for (auto &listener : listeners) + listener->notifyOpMoved(op); + } + void notifyOperandChanged(detail::IROperandBase &operand, Value newValue) { + for (auto &listener : listeners) + listener->notifyOpOperandChanged(static_cast<OpOperand &>(operand), + newValue); + } + void notifyOperandChanged(detail::IROperandBase &operand, Block *newBlock) { + for (auto &listener : listeners) + listener->notifyBlockOperandChanged(static_cast<BlockOperand &>(operand), + newBlock); + } + void addListener(llvm::IntrusiveRefCntPtr<IRListener> listener) { + listeners.push_back(std::move(listener)); + } + /// Append the listeners of `other`, skipping any already registered here. + void append(const IRListeners &other) {
+ for (const auto &listener : other.listeners) + if (!llvm::is_contained(listeners, listener)) + listeners.push_back(listener); + } + void erase(const llvm::IntrusiveRefCntPtr<IRListener> &listener) { + llvm::erase_value(listeners, listener); + } + +private: + std::vector<llvm::IntrusiveRefCntPtr<IRListener>> listeners; +}; + +} // namespace mlir + +#endif // MLIR_IR_LISTENERS_H diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -16,8 +16,10 @@ #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Listeners.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/Twine.h" #include @@ -715,6 +717,16 @@ /// handlers that may be listening. InFlightDiagnostic emitRemark(const Twine &message = {}); + /// Attach a listener to this operation. + void addListener(llvm::IntrusiveRefCntPtr<IRListener> listener) { + if (!listeners) + listeners = std::make_unique<IRListeners>(); + listener->attachToOperation(this); + listeners->addListener(std::move(listener)); + } + /// Return the listeners attached to this operation, or null if none. + IRListeners *getListeners() { return listeners.get(); } + private: //===--------------------------------------------------------------------===// // Ordering @@ -836,6 +848,10 @@ /// This holds general named attributes for the operation. DictionaryAttr attrs; + /// Optionally defined listeners that register here to capture IR modification + /// events. + std::unique_ptr<IRListeners> listeners; + // allow ilist_traits access to 'block' field. friend struct llvm::ilist_traits<Operation>; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -843,6 +843,9 @@ /// Print users of values as comments. OpPrintingFlags &printValueUsers(); + /// Print only the number of regions instead of their contents. + OpPrintingFlags &elideRegions(bool elide = true); + /// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const; @@ -867,6 +870,9 @@ /// Return if the printer should print users of values. bool shouldPrintValueUsers() const; + /// Return if the printer should elide regions. + bool shouldElideRegions() const; + private: /// Elide large elements attributes if the number of elements is larger than /// the upper limit. @@ -887,6 +893,9 @@ /// Print users of values. bool printValueUsersFlag : 1; + + /// Print only the number of regions instead of their contents. + bool elideRegionsFlag : 1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -140,6 +140,8 @@ void set(IRValueT newValue) { // It isn't worth optimizing for the case of switching operands on a single // value. + // Notify listeners before updating, so they can still query the old value. + static_cast<DerivedT *>(this)->notifyOperandChanged(newValue); removeFromCurrent(); value = newValue; insertIntoCurrent(); diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -262,6 +262,9 @@ /// Return which operand this is in the OpOperand list of the Operation. unsigned getOperandNumber(); + /// Notify the listeners of the owner operation (if any) of the new value. + void notifyOperandChanged(Value newValue); + private: /// Keep the constructor private and accessible to the OperandStorage class /// only to avoid hard-to-debug typo/programming mistakes. diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -185,7 +185,8 @@ OpPrintingFlags::OpPrintingFlags() : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), printGenericOpFormFlag(false), assumeVerifiedFlag(false), - printLocalScope(false), printValueUsersFlag(false) { + printLocalScope(false), printValueUsersFlag(false), + elideRegionsFlag(false) { // Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed()) return; @@ -244,6 +245,12 @@ return *this; } +/// Print only the number of regions instead of their contents. +OpPrintingFlags &OpPrintingFlags::elideRegions(bool elide) { + elideRegionsFlag = elide; + return *this; +} + /// Return if the given ElementsAttr should be elided. bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { return elementsAttrElementLimit && @@ -284,6 +291,9 @@ return printValueUsersFlag; } +/// Return if the printer should elide regions. +bool OpPrintingFlags::shouldElideRegions() const { return elideRegionsFlag; } + /// Returns true if an ElementsAttr with the given number of elements should be /// printed with hex. static bool shouldPrintElementsAttrWithHex(int64_t numElements) { @@ -3339,10 +3349,14 @@ // Print regions. if (op->getNumRegions() != 0) { os << " ("; - interleaveComma(op->getRegions(), [&](Region &region) { - printRegion(region, /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); - }); + if (printerFlags.shouldElideRegions()) { + os << op->getNumRegions() << " elided regions..."; + } else { + interleaveComma(op->getRegions(), [&](Region &region) { + printRegion(region, /*printEntryBlockArgs=*/true, + /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); + }); + } os << ')'; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -162,6 +162,8 @@ /// Destroy this operation or one of its subclasses. void Operation::destroy() { + if (listeners) + listeners->notifyOpDestroyed(this); // Operations may have additional prefixed allocation, which needs to be // accounted for here when computing the address to free. char *rawMem = reinterpret_cast<char *>(this) - @@ -380,8 +382,21 @@ /// keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { assert(!op->getBlock() && "already in an operation block!"); + auto oldBlock = op->block; op->block = getContainingBlock(); - + if (op->listeners) + op->listeners->notifyOpInserted(op, oldBlock, op->block); + else if (auto *parentOp = op->block->getParentOp()) { + if (parentOp->listeners) { + op->walk([&](Operation *childOp) { + if (!childOp->listeners) + childOp->listeners = std::make_unique<IRListeners>(); + childOp->listeners->append(*parentOp->listeners); + parentOp->listeners->attachToOperation(childOp); + }); + op->listeners->notifyOpInserted(op, oldBlock, op->block); + } + } // Invalidate the order on the operation. op->orderIndex = Operation::kInvalidOrderIdx; } @@ -390,6 +405,8 @@ /// We keep the block pointer up to date. void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { assert(op->block && "not already in an operation block!"); + if (op->listeners) + op->listeners->notifyOpDetached(op); op->block = nullptr; } @@ -401,6 +418,9 @@ // Invalidate the ordering of the parent block. curParent->invalidateOpOrder(); + for (auto op = first; op != last; ++op) + if (op->listeners) + op->listeners->notifyOpMoved(&*op); // If we are transferring operations within the same block, the block // pointer doesn't need to be updated.
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -204,6 +204,12 @@ return this - &getOwner()->getBlockOperands()[0]; } +void BlockOperand::notifyOperandChanged(Block *newBlock) { + auto *listeners = getOwner()->getListeners(); + if (listeners) + listeners->notifyOperandChanged(*this, newBlock); +} + //===----------------------------------------------------------------------===// // OpOperand //===----------------------------------------------------------------------===// @@ -212,3 +218,9 @@ unsigned OpOperand::getOperandNumber() { return this - &getOwner()->getOpOperands()[0]; } + +void OpOperand::notifyOperandChanged(Value newValue) { + auto *listeners = getOwner()->getListeners(); + if (listeners) + listeners->notifyOperandChanged(*this, newValue); +} diff --git a/mlir/test/IR/listener.mlir b/mlir/test/IR/listener.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/listener.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt %s -test-ir-listeners -canonicalize 2>&1 | FileCheck %s + +// The test-ir-listeners pass installs a listener that traces all subsequent +// changes to the IR. The trace is printed when the listener is destroyed, that +// is, after the last operation referencing it has been destroyed. + + +// Canonicalization removes and recreates operations, which gives us something interesting to trace.
+func.func @andOfExtSI(%arg0: i8, %arg1: i8) -> i64 { + %ext0 = arith.extsi %arg0 : i8 to i64 + %ext1 = arith.extsi %arg1 : i8 to i64 + %res = arith.andi %ext0, %ext1 : i64 + return %res : i64 +} + +// CHECK: Op inserted: <> = "arith.andi"(%arg0, %arg1) : (i8, i8) -> i8 +// CHECK: Op inserted: <> = "arith.extsi"(%2) : (i8) -> i64 +// CHECK: OpOperand #0 changed on Operation "func.return"(%4) : (i64) -> () to %3 = arith.extsi %2 : i8 to i64 +// CHECK: Op detached: %4 = "arith.andi"(%0, %1) : (i64, i64) -> i64 +// CHECK: Op destroyed: %0 = "arith.andi"(<>, <>) : (i64, i64) -> i64 +// CHECK: Op detached: %1 = "arith.extsi"(%arg1) : (i8) -> i64 +// CHECK: Op destroyed: %0 = "arith.extsi"(<>) : (i8) -> i64 +// CHECK: Op detached: %0 = "arith.extsi"(%arg0) : (i8) -> i64 +// CHECK: Op destroyed: %0 = "arith.extsi"(<>) : (i8) -> i64 +// CHECK: Op destroyed: "builtin.module"() (1 elided regions...) : () -> () +// CHECK: Op detached: "func.func"() (1 elided regions...) {function_type = (i8, i8) -> i64, sym_name = "andOfExtSI"} : () -> () +// CHECK: Op destroyed: "func.func"() (1 elided regions...)
{function_type = (i8, i8) -> i64, sym_name = "andOfExtSI"} : () -> () +// CHECK: Op detached: "func.return"(<>) : (<>) -> () +// CHECK: Op destroyed: "func.return"(<>) : (<>) -> () +// CHECK: Op detached: %0 = "arith.extsi"(<>) : (<>) -> i64 +// CHECK: Op destroyed: %0 = "arith.extsi"(<>) : (<>) -> i64 +// CHECK: Op detached: %0 = "arith.andi"(<>, <>) : (<>, <>) -> i8 diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -6,6 +6,7 @@ TestDominance.cpp TestFunc.cpp TestInterfaces.cpp + TestIRListeners.cpp TestMatchers.cpp TestOpaqueLoc.cpp TestOperationEquals.cpp diff --git a/mlir/test/lib/IR/TestIRListeners.cpp b/mlir/test/lib/IR/TestIRListeners.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestIRListeners.cpp @@ -0,0 +1,90 @@ +//===- TestIRListeners.cpp - Pass to test IRListeners ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { +class Listener final : public IRListener { +public: + ~Listener() final { + llvm::errs() << "===== IRListener Trace =====\n"; + llvm::errs() << traceStream.str(); + llvm::errs() << "===== End IRListener Trace =====\n"; + } + void printOp(Operation *op) { + op->print(traceStream, OpPrintingFlags() + .elideLargeElementsAttrs() + .printGenericOpForm() + .assumeVerified() + .useLocalScope() + .elideRegions()); + } + void notifyOpInserted(Operation *op, Block *oldBlock, Block *newBlock) final { + traceStream << "Op inserted: "; + printOp(op); + traceStream << "\n"; + } + void notifyOpDetached(Operation *op) final { + traceStream << "Op detached: "; + printOp(op); + traceStream << "\n"; + } + void notifyOpDestroyed(Operation *op) final { + traceStream << "Op destroyed: "; + printOp(op); + traceStream << "\n"; + } + void notifyOpMoved(Operation *op) final { + traceStream << "Op moved: "; + printOp(op); + traceStream << "\n"; + } + void notifyOpOperandChanged(OpOperand &operand, Value newValue) final { + traceStream << "OpOperand #" << operand.getOperandNumber() + << " changed on Operation "; + printOp(operand.getOwner()); + traceStream << " to " << newValue << "\n"; + } + void notifyBlockOperandChanged(BlockOperand &operand, Block *newBlock) final { + traceStream << "BlockOperand #" << operand.getOperandNumber() + << " changed on Operation "; + printOp(operand.getOwner()); + traceStream << " to "; + newBlock->printAsOperand(traceStream); + traceStream << "\n"; + } + +private: + std::string trace; + llvm::raw_string_ostream traceStream{trace}; +}; + +struct TestIRListenersPass + : public PassWrapper<TestIRListenersPass, OperationPass<ModuleOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRListenersPass) + + StringRef getArgument() const final { return "test-ir-listeners"; } + StringRef getDescription() const final { + return "Test IR Listeners through keeping a trace of all edits."; + } + void
runOnOperation() override { + llvm::IntrusiveRefCntPtr<IRListener> listener(new Listener); + getOperation()->walk([&](Operation *op) { op->addListener(listener); }); + } +}; + +} // namespace + +namespace mlir { +void registerTestIRListenersPass() { PassRegistration<TestIRListenersPass>(); } +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -43,6 +43,7 @@ void registerTestAllReduceLoweringPass(); void registerTestFunc(); void registerTestGpuMemoryPromotionPass(); +void registerTestIRListenersPass(); void registerTestLoopPermutationPass(); void registerTestMatchers(); void registerTestOperationEqualPass(); @@ -149,6 +150,7 @@ registerTestAllReduceLoweringPass(); registerTestFunc(); registerTestGpuMemoryPromotionPass(); + registerTestIRListenersPass(); registerTestLoopPermutationPass(); registerTestMatchers(); registerTestOperationEqualPass(); diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -5,6 +5,7 @@ InterfaceTest.cpp IRMapping.cpp InterfaceAttachmentTest.cpp + ListenerTest.cpp OperationSupportTest.cpp PatternMatchTest.cpp ShapedTypeTest.cpp diff --git a/mlir/unittests/IR/ListenerTest.cpp b/mlir/unittests/IR/ListenerTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/IR/ListenerTest.cpp @@ -0,0 +1,245 @@ +//===- ListenerTest.cpp - Test IR modification listeners ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "gtest/gtest.h" + +#include "../../test/lib/Dialect/Test/TestDialect.h" + +using namespace mlir; +using namespace test; +namespace { +class Listener : public IRListener { +public: + ~Listener() override { + if (notifyListenerDestroyedFn) + notifyListenerDestroyedFn(); + } + void notifyOpInserted(Operation *op, Block *oldBlock, Block *newBlock) final { + insertionCount++; + if (notifyOpInsertedFn) + notifyOpInsertedFn(op, oldBlock, newBlock); + } + void notifyOpDetached(Operation *op) final { + detachedCount++; + if (notifyOpDetachedFn) + notifyOpDetachedFn(op); + } + void notifyOpDestroyed(Operation *op) final { + destroyCount++; + if (notifyOpDestroyedFn) + notifyOpDestroyedFn(op); + } + void notifyOpMoved(Operation *op) final { + movedCount++; + if (notifyOpMovedFn) + notifyOpMovedFn(op); + } + void notifyOpOperandChanged(OpOperand &operand, Value newValue) final { + opOperandChangedCount++; + if (notifyOpOperandChangedFn) + notifyOpOperandChangedFn(operand, newValue); + } + void notifyBlockOperandChanged(BlockOperand &operand, Block *newBlock) final { + blockOperandChangedCount++; + if (notifyBlockOperandChangedFn) + notifyBlockOperandChangedFn(operand, newBlock); + } + + std::function<void(Operation *, Block *, Block *)> notifyOpInsertedFn; + std::function<void(Operation *)> notifyOpDetachedFn; + std::function<void(Operation *)> notifyOpDestroyedFn; + std::function<void(Operation *)> notifyOpMovedFn; + std::function<void(OpOperand &, Value)> notifyOpOperandChangedFn; + std::function<void(BlockOperand &, Block *)> notifyBlockOperandChangedFn; + std::function<void()> notifyListenerDestroyedFn; + int insertionCount = 0; + int detachedCount = 0; + int destroyCount = 0; + int
movedCount = 0; + int opOperandChangedCount = 0; + int blockOperandChangedCount = 0; +}; + +TEST(ListenersTest, OpInsertion) { + llvm::IntrusiveRefCntPtr<Listener> listener(new Listener); + { + MLIRContext context; + context.loadDialect<TestDialect>(); + OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); + // Set some listener assertions. + listener->notifyOpInsertedFn = [&](Operation *op, Block *oldBlock, + Block *newBlock) { + EXPECT_EQ(op->getName().getStringRef(), "test.side_effect_op"); + EXPECT_EQ(oldBlock, nullptr); + EXPECT_EQ(newBlock, module->getBody()); + }; + listener->notifyOpDetachedFn = [&](Operation *op) { + EXPECT_EQ(op->getName().getStringRef(), "test.side_effect_op"); + EXPECT_EQ(op->getBlock(), module->getBody()); + }; + listener->notifyOpDestroyedFn = [&](Operation *op) { + EXPECT_EQ(op->getName().getStringRef(), "test.side_effect_op"); + }; + listener->notifyOpMovedFn = [&](Operation *op) { + EXPECT_EQ(op->getName().getStringRef(), "test.side_effect_op"); + }; + OpBuilder builder(module->getBody(), module->getBody()->begin()); + + // Adding the listener on the module will ensure any added operation will + // inherit the listener.
+ module.get()->addListener(listener); + + builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + EXPECT_EQ(listener->insertionCount, 1); + EXPECT_EQ(listener->detachedCount, 0); + EXPECT_EQ(listener->destroyCount, 0); + EXPECT_EQ(listener->movedCount, 0); + auto op2 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + EXPECT_EQ(listener->insertionCount, 2); + EXPECT_EQ(listener->detachedCount, 0); + EXPECT_EQ(listener->destroyCount, 0); + EXPECT_EQ(listener->movedCount, 0); + auto op3 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + EXPECT_EQ(listener->insertionCount, 3); + EXPECT_EQ(listener->detachedCount, 0); + EXPECT_EQ(listener->destroyCount, 0); + EXPECT_EQ(listener->movedCount, 0); + op2->moveAfter(op3); + EXPECT_EQ(listener->insertionCount, 3); + EXPECT_EQ(listener->detachedCount, 0); + EXPECT_EQ(listener->destroyCount, 0); + EXPECT_EQ(listener->movedCount, 1); + op3->erase(); + EXPECT_EQ(listener->insertionCount, 3); + EXPECT_EQ(listener->detachedCount, 1); + EXPECT_EQ(listener->destroyCount, 1); + EXPECT_EQ(listener->movedCount, 1); + + // Drop callbacks that reference `module` or reject the module op itself. + listener->notifyOpDetachedFn = nullptr; + listener->notifyOpDestroyedFn = nullptr; + } + EXPECT_EQ(listener->insertionCount, 3); + EXPECT_EQ(listener->detachedCount, 3); + EXPECT_EQ(listener->destroyCount, 4); + EXPECT_EQ(listener->movedCount, 1); + EXPECT_EQ(listener->opOperandChangedCount, 0); + EXPECT_EQ(listener->blockOperandChangedCount, 0); +} + +TEST(ListenersTest, OperandChanged) { + llvm::IntrusiveRefCntPtr<Listener> listener(new Listener); + + MLIRContext context; + context.loadDialect<TestDialect>(); + OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module->getBody(), module->getBody()->begin()); + + // Adding the listener on the module will ensure any added operation will + // inherit the listener.
+ module.get()->addListener(listener); + + Operation *producerOp; + { + OperationState state(UnknownLoc::get(&context), "test.producer_op"); + state.addTypes({builder.getI32Type(), builder.getI32Type()}); + state.addSuccessors(module->getBody()); + state.addRegion(std::make_unique<Region>()); + producerOp = builder.create(state); + producerOp->getRegion(0).push_back(new Block()); + } + // Build a consumer op that uses the results of the producer; we'll check that + // we correctly catch updates to its operands. + Operation *consumerOp = builder.create( + builder.getUnknownLoc(), producerOp->getResult(0), + producerOp->getResult(1)); + EXPECT_EQ(listener->insertionCount, 2); + EXPECT_EQ(listener->detachedCount, 0); + EXPECT_EQ(listener->destroyCount, 0); + EXPECT_EQ(listener->movedCount, 0); + EXPECT_EQ(listener->opOperandChangedCount, 0); + EXPECT_EQ(listener->blockOperandChangedCount, 0); + producerOp->getResult(0).replaceAllUsesWith(producerOp->getResult(1)); + EXPECT_EQ(listener->opOperandChangedCount, 1); + EXPECT_EQ(listener->blockOperandChangedCount, 0); + producerOp->setSuccessor(&producerOp->getRegion(0).front(), 0); + EXPECT_EQ(listener->opOperandChangedCount, 1); + EXPECT_EQ(listener->blockOperandChangedCount, 1); + // Test that we can detach the listener. + module.get()->getListeners()->erase(listener); + consumerOp->getListeners()->erase(listener); + + // Check that the listener no longer fires for ops it was detached from. + // Here the operands of the consumer op are modified but only the producerOp + // has a listener.
+ producerOp->getResult(1).replaceAllUsesWith(producerOp->getResult(0)); + EXPECT_EQ(listener->opOperandChangedCount, 1); + EXPECT_EQ(listener->blockOperandChangedCount, 1); + + bool listenerDestroyed = false; + listener->notifyListenerDestroyedFn = [&]() { listenerDestroyed = true; }; + int opDestroyedCount = 0; + listener->notifyOpDestroyedFn = [&](Operation *op) { + EXPECT_EQ(op, producerOp); + opDestroyedCount++; + }; + + // Drop the test's own reference; the producerOp still holds one. + listener.reset(); + EXPECT_EQ(listenerDestroyed, false); + // Delete the entire module, deleting the producerOp will delete the listener. + module = nullptr; + // Only the producerOp should have notified for destruction here. + EXPECT_EQ(opDestroyedCount, 1); + EXPECT_EQ(listenerDestroyed, true); +} + +TEST(ListenersTest, RecursiveInheritance) { + // Check that when we insert an operation with a region inside a block where + // the parent op has a listener, we propagate the listener to the operation in + // the regions of the current op. + llvm::IntrusiveRefCntPtr<Listener> listener(new Listener); + + MLIRContext context; + context.loadDialect<TestDialect>(); + OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module->getBody(), module->getBody()->begin()); + + // Adding the listener on the module will ensure any added operation will + // inherit the listener.
+ module.get()->addListener(listener); + + Operation *opWithRegion; + { + OperationState state(UnknownLoc::get(&context), "test.producer_op"); + state.addTypes({builder.getI32Type(), builder.getI32Type()}); + state.addSuccessors(module->getBody()); + state.addRegion(std::make_unique<Region>()); + opWithRegion = OpBuilder(&context).create(state); + opWithRegion->getRegion(0).push_back(new Block()); + Block *body = &opWithRegion->getRegion(0).front(); + OpBuilder::atBlockBegin(body).template create( + builder.getUnknownLoc(), builder.getI32Type()); + } + EXPECT_EQ(listener->insertionCount, 0); + module->getBody()->push_back(opWithRegion); + EXPECT_EQ(listener->insertionCount, 1); +} + +} // namespace