diff --git a/mlir/include/mlir/IR/Listeners.h b/mlir/include/mlir/IR/Listeners.h
--- a/mlir/include/mlir/IR/Listeners.h
+++ b/mlir/include/mlir/IR/Listeners.h
@@ -59,38 +59,67 @@
   }
 };
 
-/// A listener that forwards all notifications to another listener. This
+/// A listener that forwards all notifications to one or more listeners. This
 /// struct can be used as a base to create listener chains, so that multiple
 /// listeners can be notified of IR changes.
+///
+/// Listeners are notified in the order in which they were passed to the
+/// constructor. Only the pointers are stored: the listeners are not owned and
+/// must stay alive for as long as notifications are forwarded to them.
 struct ForwardingRewriteListener : public RewriteListener {
-  ForwardingRewriteListener(RewriteListener *listener) : listener(listener) {}
+  /// Creates a listener that forwards to every entry of `listeners`. Entries
+  /// must be non-null; this is only checked when assertions are enabled.
+  ForwardingRewriteListener(ArrayRef<RewriteListener *> listeners)
+      : listeners(listeners) {
+#ifndef NDEBUG
+    for (RewriteListener *listener : listeners)
+      assert(listener && "expected non-null listener");
+#endif // NDEBUG
+  }
 
+  // The notifications below are relayed unchanged to each listener in turn.
   void notifyOperationInserted(Operation *op) override {
-    listener->notifyOperationInserted(op);
+    for (RewriteListener *listener : listeners)
+      listener->notifyOperationInserted(op);
   }
   void notifyBlockCreated(Block *block) override {
-    listener->notifyBlockCreated(block);
+    for (RewriteListener *listener : listeners)
+      listener->notifyBlockCreated(block);
   }
   void notifyOperationModified(Operation *op) override {
-    listener->notifyOperationModified(op);
+    for (RewriteListener *listener : listeners)
+      listener->notifyOperationModified(op);
   }
   void notifyOperationReplaced(Operation *op, Operation *newOp) override {
-    listener->notifyOperationReplaced(op, newOp);
+    for (RewriteListener *listener : listeners)
+      listener->notifyOperationReplaced(op, newOp);
   }
   void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
-    listener->notifyOperationReplaced(op, replacement);
+    for (RewriteListener *listener : listeners)
+      listener->notifyOperationReplaced(op, replacement);
   }
   void notifyOperationRemoved(Operation *op) override {
-    listener->notifyOperationRemoved(op);
+    for (RewriteListener *listener : listeners)
+      listener->notifyOperationRemoved(op);
   }
+  /// Forwards the match failure to every listener. Returns success if at
+  /// least one listener returned success, and failure otherwise (including
+  /// when there are no listeners).
   LogicalResult notifyMatchFailure(
       Location loc,
       function_ref<void(Diagnostic &)> reasonCallback) override {
-    return listener->notifyMatchFailure(loc, reasonCallback);
+    // Do not stop at the first success: every listener in the chain must
+    // see the notification.
+    LogicalResult result = failure();
+    for (RewriteListener *listener : listeners)
+      if (succeeded(listener->notifyMatchFailure(loc, reasonCallback)))
+        result = success();
+    return result;
   }
 
 private:
-  RewriteListener *listener;
+  /// Listeners to notify, in notification order. Not owned.
+  SmallVector<RewriteListener *> listeners;
 };
 
 } // namespace mlir